Louis BECQUEY

Better parallel statistics computation

...@@ -1638,6 +1638,7 @@ def sql_ask_database(conn, sql, warn_every = 10): ...@@ -1638,6 +1638,7 @@ def sql_ask_database(conn, sql, warn_every = 10):
1638 1638
1639 @trace_unhandled_exceptions 1639 @trace_unhandled_exceptions
1640 def sql_execute(conn, sql, many=False, data=None, warn_every=10): 1640 def sql_execute(conn, sql, many=False, data=None, warn_every=10):
1641 + conn.execute('pragma journal_mode=wal') # Allow multiple other readers to ask things while we execute this writing query
1641 for _ in range(100): # retry 100 times if it fails 1642 for _ in range(100): # retry 100 times if it fails
1642 try: 1643 try:
1643 if many: 1644 if many:
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
3 # Run RNANet 3 # Run RNANet
4 cd /home/lbecquey/Projects/RNANet; 4 cd /home/lbecquey/Projects/RNANet;
5 rm -f stdout.txt stderr.txt errors.txt; 5 rm -f stdout.txt stderr.txt errors.txt;
6 -time './RNAnet.py --3d-folder /home/lbequey/Data/RNA/3D/ --seq-folder /home/lbecquey/Data/RNA/sequences/ -s -r 20.0' > stdout.txt 2> stderr.txt; 6 +time './RNAnet.py --3d-folder /home/lbequey/Data/RNA/3D/ --seq-folder /home/lbecquey/Data/RNA/sequences/ -s -r 20.0 --archive' > stdout.txt 2> stderr.txt;
7 7
8 # Sync in Seafile 8 # Sync in Seafile
9 seaf-cli start; 9 seaf-cli start;
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
5 # in the database. 5 # in the database.
6 # This should be run from the folder where the file is (to access the database with path "results/RNANet.db") 6 # This should be run from the folder where the file is (to access the database with path "results/RNANet.db")
7 7
8 -import os, pickle, sqlite3, sys 8 +import os, pickle, sqlite3, shlex, subprocess, sys
9 import numpy as np 9 import numpy as np
10 import pandas as pd 10 import pandas as pd
11 import threading as th 11 import threading as th
...@@ -16,14 +16,13 @@ import matplotlib.patches as mpatches ...@@ -16,14 +16,13 @@ import matplotlib.patches as mpatches
16 import scipy.cluster.hierarchy as sch 16 import scipy.cluster.hierarchy as sch
17 from scipy.spatial.distance import squareform 17 from scipy.spatial.distance import squareform
18 from mpl_toolkits.mplot3d import axes3d 18 from mpl_toolkits.mplot3d import axes3d
19 -from Bio.Phylo.TreeConstruction import DistanceCalculator
20 from Bio import AlignIO, SeqIO 19 from Bio import AlignIO, SeqIO
21 from functools import partial 20 from functools import partial
22 -from multiprocessing import Pool 21 +from multiprocessing import Pool, Manager
23 from os import path 22 from os import path
24 from tqdm import tqdm 23 from tqdm import tqdm
25 from collections import Counter 24 from collections import Counter
26 -from RNAnet import read_cpu_number, sql_ask_database, sql_execute, warn, notify, init_worker 25 +from RNAnet import Job, read_cpu_number, sql_ask_database, sql_execute, warn, notify, init_worker
27 26
28 # This sets the paths 27 # This sets the paths
29 if len(sys.argv) > 1: 28 if len(sys.argv) > 1:
...@@ -37,7 +36,7 @@ else: ...@@ -37,7 +36,7 @@ else:
37 LSU_set = ("RF00002", "RF02540", "RF02541", "RF02543", "RF02546") # From Rfam CLAN 00112 36 LSU_set = ("RF00002", "RF02540", "RF02541", "RF02543", "RF02546") # From Rfam CLAN 00112
38 SSU_set = ("RF00177", "RF02542", "RF02545", "RF01959", "RF01960") # From Rfam CLAN 00111 37 SSU_set = ("RF00177", "RF02542", "RF02545", "RF01959", "RF01960") # From Rfam CLAN 00111
39 38
40 -def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)): 39 +def reproduce_wadley_results(carbon=4, show=False, sd_range=(1,4)):
41 """ 40 """
42 Plot the joint distribution of pseudotorsion angles, in a Ramachandran-style graph. 41 Plot the joint distribution of pseudotorsion angles, in a Ramachandran-style graph.
43 See Wadley & Pyle (2007) 42 See Wadley & Pyle (2007)
...@@ -68,6 +67,12 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)): ...@@ -68,6 +67,12 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)):
68 67
69 68
70 if not path.isfile(f"data/wadley_kernel_{angle}.npz"): 69 if not path.isfile(f"data/wadley_kernel_{angle}.npz"):
70 +
71 + # Get a worker number to position the progress bar
72 + global idxQueue
73 + thr_idx = idxQueue.get()
74 + pbar = tqdm(total=2, desc=f"Worker {thr_idx+1}: eta/theta C{carbon} kernels", position=thr_idx+1, leave=False)
75 +
71 # Extract the angle values of c2'-endo and c3'-endo nucleotides 76 # Extract the angle values of c2'-endo and c3'-endo nucleotides
72 with sqlite3.connect("results/RNANet.db") as conn: 77 with sqlite3.connect("results/RNANet.db") as conn:
73 df = pd.read_sql(f"""SELECT {angle}, th{angle} FROM nucleotide WHERE puckering="C2'-endo" AND {angle} IS NOT NULL AND th{angle} IS NOT NULL;""", conn) 78 df = pd.read_sql(f"""SELECT {angle}, th{angle} FROM nucleotide WHERE puckering="C2'-endo" AND {angle} IS NOT NULL AND th{angle} IS NOT NULL;""", conn)
...@@ -89,13 +94,17 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)): ...@@ -89,13 +94,17 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)):
89 xx, yy = np.mgrid[0:2*np.pi:100j, 0:2*np.pi:100j] 94 xx, yy = np.mgrid[0:2*np.pi:100j, 0:2*np.pi:100j]
90 positions = np.vstack([xx.ravel(), yy.ravel()]) 95 positions = np.vstack([xx.ravel(), yy.ravel()])
91 f_c3 = np.reshape(kernel_c3(positions).T, xx.shape) 96 f_c3 = np.reshape(kernel_c3(positions).T, xx.shape)
97 + pbar.update(1)
92 f_c2 = np.reshape(kernel_c2(positions).T, xx.shape) 98 f_c2 = np.reshape(kernel_c2(positions).T, xx.shape)
99 + pbar.update(1)
93 100
94 # Save the data to an archive for later use without the need to recompute 101 # Save the data to an archive for later use without the need to recompute
95 np.savez(f"data/wadley_kernel_{angle}.npz", 102 np.savez(f"data/wadley_kernel_{angle}.npz",
96 c3_endo_e=c3_endo_etas, c3_endo_t=c3_endo_thetas, 103 c3_endo_e=c3_endo_etas, c3_endo_t=c3_endo_thetas,
97 c2_endo_e=c2_endo_etas, c2_endo_t=c2_endo_thetas, 104 c2_endo_e=c2_endo_etas, c2_endo_t=c2_endo_thetas,
98 kernel_c3=f_c3, kernel_c2=f_c2) 105 kernel_c3=f_c3, kernel_c2=f_c2)
106 + pbar.close()
107 + idxQueue.put(thr_idx)
99 else: 108 else:
100 f = np.load(f"data/wadley_kernel_{angle}.npz") 109 f = np.load(f"data/wadley_kernel_{angle}.npz")
101 c2_endo_etas = f["c2_endo_e"] 110 c2_endo_etas = f["c2_endo_e"]
...@@ -106,7 +115,7 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)): ...@@ -106,7 +115,7 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)):
106 f_c2 = f["kernel_c2"] 115 f_c2 = f["kernel_c2"]
107 xx, yy = np.mgrid[0:2*np.pi:100j, 0:2*np.pi:100j] 116 xx, yy = np.mgrid[0:2*np.pi:100j, 0:2*np.pi:100j]
108 117
109 - notify(f"Kernel computed for {angle}/th{angle} (or loaded from file).") 118 + # notify(f"Kernel computed for {angle}/th{angle} (or loaded from file).")
110 119
111 # exact counts: 120 # exact counts:
112 hist_c2, xedges, yedges = np.histogram2d(c2_endo_etas, c2_endo_thetas, bins=int(2*np.pi/0.1), 121 hist_c2, xedges, yedges = np.histogram2d(c2_endo_etas, c2_endo_thetas, bins=int(2*np.pi/0.1),
...@@ -139,7 +148,7 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)): ...@@ -139,7 +148,7 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)):
139 fig.savefig(f"results/figures/wadley_plots/wadley_hist_{angle}_{l}.png") 148 fig.savefig(f"results/figures/wadley_plots/wadley_hist_{angle}_{l}.png")
140 if show: 149 if show:
141 fig.show() 150 fig.show()
142 - fig.close() 151 + plt.close()
143 152
144 # Smoothed joint distribution 153 # Smoothed joint distribution
145 fig = plt.figure() 154 fig = plt.figure()
...@@ -150,7 +159,7 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)): ...@@ -150,7 +159,7 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)):
150 fig.savefig(f"results/figures/wadley_plots/wadley_distrib_{angle}_{l}.png") 159 fig.savefig(f"results/figures/wadley_plots/wadley_distrib_{angle}_{l}.png")
151 if show: 160 if show:
152 fig.show() 161 fig.show()
153 - fig.close() 162 + plt.close()
154 163
155 # 2D Wadley plot 164 # 2D Wadley plot
156 fig = plt.figure(figsize=(5,5)) 165 fig = plt.figure(figsize=(5,5))
...@@ -163,7 +172,7 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)): ...@@ -163,7 +172,7 @@ def reproduce_wadley_results(show=False, carbon=4, sd_range=(1,4)):
163 fig.savefig(f"results/figures/wadley_plots/wadley_{angle}_{l}.png") 172 fig.savefig(f"results/figures/wadley_plots/wadley_{angle}_{l}.png")
164 if show: 173 if show:
165 fig.show() 174 fig.show()
166 - fig.close() 175 + plt.close()
167 # print(f"[{worker_nbr}]\tComputed joint distribution of angles (C{carbon}) and saved the figures.") 176 # print(f"[{worker_nbr}]\tComputed joint distribution of angles (C{carbon}) and saved the figures.")
168 177
169 def stats_len(): 178 def stats_len():
...@@ -171,11 +180,15 @@ def stats_len(): ...@@ -171,11 +180,15 @@ def stats_len():
171 180
172 REQUIRES tables chain, nucleotide up to date. 181 REQUIRES tables chain, nucleotide up to date.
173 """ 182 """
183 +
184 + # Get a worker number to position the progress bar
185 + global idxQueue
186 + thr_idx = idxQueue.get()
174 187
175 cols = [] 188 cols = []
176 lengths = [] 189 lengths = []
177 - conn = sqlite3.connect("results/RNANet.db") 190 +
178 - for i,f in enumerate(fam_list): 191 + for i,f in enumerate(tqdm(fam_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Average chain lengths", leave=False)):
179 192
180 # Define a color for that family in the plot 193 # Define a color for that family in the plot
181 if f in LSU_set: 194 if f in LSU_set:
...@@ -190,11 +203,11 @@ def stats_len(): ...@@ -190,11 +203,11 @@ def stats_len():
190 cols.append("grey") 203 cols.append("grey")
191 204
192 # Get the lengths of chains 205 # Get the lengths of chains
193 - l = [ x[0] for x in sql_ask_database(conn, f"SELECT COUNT(index_chain) FROM (SELECT chain_id FROM chain WHERE rfam_acc='{f}') NATURAL JOIN nucleotide GROUP BY chain_id;") ] 206 + with sqlite3.connect("results/RNANet.db") as conn:
207 + l = [ x[0] for x in sql_ask_database(conn, f"SELECT COUNT(index_chain) FROM (SELECT chain_id FROM chain WHERE rfam_acc='{f}') NATURAL JOIN nucleotide GROUP BY chain_id;", warn_every=0) ]
194 lengths.append(l) 208 lengths.append(l)
195 209
196 - notify(f"[{i+1}/{len(fam_list)}] Computed {f} chains lengths") 210 + # notify(f"[{i+1}/{len(fam_list)}] Computed {f} chains lengths")
197 - conn.close()
198 211
199 # Plot the figure 212 # Plot the figure
200 fig = plt.figure(figsize=(10,3)) 213 fig = plt.figure(figsize=(10,3))
...@@ -223,7 +236,8 @@ def stats_len(): ...@@ -223,7 +236,8 @@ def stats_len():
223 236
224 # Save the figure 237 # Save the figure
225 fig.savefig("results/figures/lengths.png") 238 fig.savefig("results/figures/lengths.png")
226 - notify("Computed sequence length statistics and saved the figure.") 239 + idxQueue.put(thr_idx) # replace the thread index in the queue
240 + # notify("Computed sequence length statistics and saved the figure.")
227 241
228 def format_percentage(tot, x): 242 def format_percentage(tot, x):
229 if not tot: 243 if not tot:
...@@ -242,40 +256,54 @@ def stats_freq(): ...@@ -242,40 +256,54 @@ def stats_freq():
242 256
243 Outputs results/frequencies.csv 257 Outputs results/frequencies.csv
244 REQUIRES tables chain, nucleotide up to date.""" 258 REQUIRES tables chain, nucleotide up to date."""
259 +
260 + # Get a worker number to position the progress bar
261 + global idxQueue
262 + thr_idx = idxQueue.get()
263 +
245 # Initialize a Counter object for each family 264 # Initialize a Counter object for each family
246 freqs = {} 265 freqs = {}
247 for f in fam_list: 266 for f in fam_list:
248 freqs[f] = Counter() 267 freqs[f] = Counter()
249 268
250 # List all nt_names happening within a RNA family and store the counts in the Counter 269 # List all nt_names happening within a RNA family and store the counts in the Counter
251 - conn = sqlite3.connect("results/RNANet.db") 270 + for i,f in enumerate(tqdm(fam_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Base frequencies", leave=False)):
252 - for i,f in enumerate(fam_list): 271 + with sqlite3.connect("results/RNANet.db") as conn:
253 - counts = dict(sql_ask_database(conn, f"SELECT nt_name, COUNT(nt_name) FROM (SELECT chain_id from chain WHERE rfam_acc='{f}') NATURAL JOIN nucleotide GROUP BY nt_name;")) 272 + counts = dict(sql_ask_database(conn, f"SELECT nt_name, COUNT(nt_name) FROM (SELECT chain_id from chain WHERE rfam_acc='{f}') NATURAL JOIN nucleotide GROUP BY nt_name;", warn_every=0))
254 freqs[f].update(counts) 273 freqs[f].update(counts)
255 - notify(f"[{i+1}/{len(fam_list)}] Computed {f} nucleotide frequencies.") 274 + # notify(f"[{i+1}/{len(fam_list)}] Computed {f} nucleotide frequencies.")
256 - conn.close()
257 275
258 # Create a pandas DataFrame, and save it to CSV. 276 # Create a pandas DataFrame, and save it to CSV.
259 df = pd.DataFrame() 277 df = pd.DataFrame()
260 - for f in fam_list: 278 + for f in tqdm(fam_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Base frequencies", leave=False):
261 tot = sum(freqs[f].values()) 279 tot = sum(freqs[f].values())
262 df = pd.concat([ df, pd.DataFrame([[ format_percentage(tot, x) for x in freqs[f].values() ]], columns=list(freqs[f]), index=[f]) ]) 280 df = pd.concat([ df, pd.DataFrame([[ format_percentage(tot, x) for x in freqs[f].values() ]], columns=list(freqs[f]), index=[f]) ])
263 df = df.fillna(0) 281 df = df.fillna(0)
264 df.to_csv("results/frequencies.csv") 282 df.to_csv("results/frequencies.csv")
265 - notify("Saved nucleotide frequencies to CSV file.") 283 + idxQueue.put(thr_idx) # replace the thread index in the queue
284 + # notify("Saved nucleotide frequencies to CSV file.")
266 285
267 def parallel_stats_pairs(f): 286 def parallel_stats_pairs(f):
268 """Counts occurrences of intra-chain base-pair types in one RNA family 287 """Counts occurrences of intra-chain base-pair types in one RNA family
269 288
270 REQUIRES tables chain, nucleotide up-to-date.""" 289 REQUIRES tables chain, nucleotide up-to-date."""
271 290
291 + # Get a worker number to position the progress bar
292 + global idxQueue
293 + thr_idx = idxQueue.get()
294 +
272 chain_id_list = mappings_list[f] 295 chain_id_list = mappings_list[f]
273 data = [] 296 data = []
274 - for cid in chain_id_list: 297 + sqldata = []
298 + for cid in tqdm(chain_id_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: {f} basepair types", leave=False):
275 with sqlite3.connect("results/RNANet.db") as conn: 299 with sqlite3.connect("results/RNANet.db") as conn:
276 # Get comma separated lists of basepairs per nucleotide 300 # Get comma separated lists of basepairs per nucleotide
277 - interactions = pd.read_sql(f"SELECT nt_code as nt1, index_chain, paired, pair_type_LW FROM (SELECT chain_id FROM chain WHERE chain_id='{cid}') NATURAL JOIN nucleotide;", conn) 301 + interactions = pd.DataFrame(
278 - 302 + sql_ask_database(conn,
303 + f"SELECT nt_code as nt1, index_chain, paired, pair_type_LW FROM (SELECT chain_id FROM chain WHERE chain_id='{cid}') NATURAL JOIN nucleotide;",
304 + warn_every=0),
305 + columns = ["nt1", "index_chain", "paired", "pair_type_LW"]
306 + )
279 # expand the comma-separated lists in real lists 307 # expand the comma-separated lists in real lists
280 expanded_list = pd.concat([ pd.DataFrame({ 'nt1':[ row["nt1"] for x in row["paired"].split(',') ], 308 expanded_list = pd.concat([ pd.DataFrame({ 'nt1':[ row["nt1"] for x in row["paired"].split(',') ],
281 'index_chain':[ row['index_chain'] for x in row["paired"].split(',') ], 309 'index_chain':[ row['index_chain'] for x in row["paired"].split(',') ],
...@@ -317,27 +345,29 @@ def parallel_stats_pairs(f): ...@@ -317,27 +345,29 @@ def parallel_stats_pairs(f):
317 345
318 # Update the database 346 # Update the database
319 vlcnts = expanded_list.pair_type_LW.value_counts() 347 vlcnts = expanded_list.pair_type_LW.value_counts()
320 - sqldata = ( vlcnts.at["cWW"]/2 if "cWW" in vlcnts.index else 0, 348 + sqldata.append( ( vlcnts.at["cWW"]/2 if "cWW" in vlcnts.index else 0,
321 - vlcnts.at["cWH"] if "cWH" in vlcnts.index else 0, 349 + vlcnts.at["cWH"] if "cWH" in vlcnts.index else 0,
322 - vlcnts.at["cWS"] if "cWS" in vlcnts.index else 0, 350 + vlcnts.at["cWS"] if "cWS" in vlcnts.index else 0,
323 - vlcnts.at["cHH"]/2 if "cHH" in vlcnts.index else 0, 351 + vlcnts.at["cHH"]/2 if "cHH" in vlcnts.index else 0,
324 - vlcnts.at["cHS"] if "cHS" in vlcnts.index else 0, 352 + vlcnts.at["cHS"] if "cHS" in vlcnts.index else 0,
325 - vlcnts.at["cSS"]/2 if "cSS" in vlcnts.index else 0, 353 + vlcnts.at["cSS"]/2 if "cSS" in vlcnts.index else 0,
326 - vlcnts.at["tWW"]/2 if "tWW" in vlcnts.index else 0, 354 + vlcnts.at["tWW"]/2 if "tWW" in vlcnts.index else 0,
327 - vlcnts.at["tWH"] if "tWH" in vlcnts.index else 0, 355 + vlcnts.at["tWH"] if "tWH" in vlcnts.index else 0,
328 - vlcnts.at["tWS"] if "tWS" in vlcnts.index else 0, 356 + vlcnts.at["tWS"] if "tWS" in vlcnts.index else 0,
329 - vlcnts.at["tHH"]/2 if "tHH" in vlcnts.index else 0, 357 + vlcnts.at["tHH"]/2 if "tHH" in vlcnts.index else 0,
330 - vlcnts.at["tHS"] if "tHS" in vlcnts.index else 0, 358 + vlcnts.at["tHS"] if "tHS" in vlcnts.index else 0,
331 - vlcnts.at["tSS"]/2 if "tSS" in vlcnts.index else 0, 359 + vlcnts.at["tSS"]/2 if "tSS" in vlcnts.index else 0,
332 - int(sum(vlcnts.loc[[ str(x) for x in vlcnts.index if "." in str(x)]])/2), 360 + int(sum(vlcnts.loc[[ str(x) for x in vlcnts.index if "." in str(x)]])/2),
333 - cid) 361 + cid) )
334 - with sqlite3.connect("results/RNANet.db") as conn:
335 - sql_execute(conn, """UPDATE chain SET pair_count_cWW = ?, pair_count_cWH = ?, pair_count_cWS = ?, pair_count_cHH = ?,
336 - pair_count_cHS = ?, pair_count_cSS = ?, pair_count_tWW = ?, pair_count_tWH = ?, pair_count_tWS = ?,
337 - pair_count_tHH = ?, pair_count_tHS = ?, pair_count_tSS = ?, pair_count_other = ? WHERE chain_id = ?;""", data=sqldata)
338 362
339 data.append(expanded_list) 363 data.append(expanded_list)
340 364
365 + # Update the database
366 + with sqlite3.connect("results/RNANet.db") as conn:
367 + conn.execute('pragma journal_mode=wal') # Allow multiple other readers to ask things while we execute this writing query
368 + sql_execute(conn, """UPDATE chain SET pair_count_cWW = ?, pair_count_cWH = ?, pair_count_cWS = ?, pair_count_cHH = ?,
369 + pair_count_cHS = ?, pair_count_cSS = ?, pair_count_tWW = ?, pair_count_tWH = ?, pair_count_tWS = ?,
370 + pair_count_tHH = ?, pair_count_tHS = ?, pair_count_tSS = ?, pair_count_other = ? WHERE chain_id = ?;""", many=True, data=sqldata, warn_every=0)
341 371
342 # merge all the dataframes from all chains of the family 372 # merge all the dataframes from all chains of the family
343 expanded_list = pd.concat(data) 373 expanded_list = pd.concat(data)
...@@ -351,7 +381,106 @@ def parallel_stats_pairs(f): ...@@ -351,7 +381,106 @@ def parallel_stats_pairs(f):
351 381
352 # Create an output DataFrame 382 # Create an output DataFrame
353 f_df = pd.DataFrame([[ x for x in cnt.values() ]], columns=list(cnt), index=[f]) 383 f_df = pd.DataFrame([[ x for x in cnt.values() ]], columns=list(cnt), index=[f])
354 - return expanded_list, f_df 384 + f_df.to_csv(f"data/{f}_counts.csv")
385 + expanded_list.to_csv(f"data/{f}_pairs.csv")
386 +
387 + idxQueue.put(thr_idx) # replace the thread index in the queue
388 +
389 +def to_dist_matrix(f):
390 + if path.isfile("data/"+f+".npy"):
391 + # notify(f"Computed {f} distance matrix", "loaded from file")
392 + return 0
393 +
394 + # Get a worker number to position the progress bar
395 + global idxQueue
396 + thr_idx = idxQueue.get()
397 +
398 + # notify(f"Computing {f} distance matrix from alignment...")
399 + command = f"esl-alipid --rna --noheader --informat stockholm {f}_3d_only.stk"
400 +
401 + # Prepare a file
402 + with open(path_to_seq_data+f"/realigned/{f}++.afa") as al_file:
403 + al = AlignIO.read(al_file, "fasta")
404 + names = [ x.id for x in al if '[' in x.id ]
405 + al = al[-len(names):]
406 + with open(f + "_3d_only.stk", "w") as only_3d:
407 + only_3d.write(al.format("stockholm"))
408 + del al
409 +
410 + # Prepare the job
411 + process = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE)
412 + id_matrix = np.zeros((len(names), len(names)))
413 +
414 + pbar = tqdm(total = len(names)*(len(names)-1)*0.5, position=thr_idx+1, desc=f"Worker {thr_idx+1}: {f} idty matrix", leave=False)
415 + while process.poll() is None:
416 + output = process.stdout.readline()
417 + if output:
418 + lines = output.strip().split(b'\n')
419 + for l in lines:
420 + line = l.split()
421 + s1 = line[0].decode('utf-8')
422 + s2 = line[1].decode('utf-8')
423 + score = line[2].decode('utf-8')
424 + id1 = names.index(s1)
425 + id2 = names.index(s2)
426 + id_matrix[id1, id2] = float(score)
427 + pbar.update(1)
428 + pbar.close()
429 +
430 + subprocess.run(["rm", "-f", f + "_3d_only.stk"])
431 + np.save("data/"+f+".npy", id_matrix)
432 + idxQueue.put(thr_idx) # replace the thread index in the queue
433 + return 0
434 +
435 +def seq_idty():
436 + """Computes identity matrices for each of the RNA families.
437 +
438 + REQUIRES temporary results files in data/*.npy
 439 + REQUIRES tables chain, family up to date."""
440 +
441 + # load distance matrices
442 + fam_arrays = []
443 + for f in famlist:
444 + if path.isfile("data/"+f+".npy"):
445 + fam_arrays.append(np.load("data/"+f+".npy"))
446 + else:
447 + fam_arrays.append([])
448 +
449 + # Update database with identity percentages
450 + conn = sqlite3.connect("results/RNANet.db")
451 + for f, D in zip(famlist, fam_arrays):
452 + if not len(D): continue
453 + a = 1.0 - np.average(D + D.T) # Get symmetric matrix instead of lower triangle + convert from distance matrix to identity matrix
454 + conn.execute(f"UPDATE family SET idty_percent = {round(float(a),2)} WHERE rfam_acc = '{f}';")
455 + conn.commit()
456 + conn.close()
457 +
458 + # Plots plots plots
459 + fig, axs = plt.subplots(4,17, figsize=(17,5.75))
460 + axs = axs.ravel()
461 + [axi.set_axis_off() for axi in axs]
462 + im = "" # Just to declare the variable, it will be set in the loop
463 + for f, D, ax in zip(famlist, fam_arrays, axs):
464 + if not len(D): continue
465 + if D.shape[0] > 2: # Cluster only if there is more than 2 sequences to organize
 466 + D = D + D.T # Copy the lower triangle to upper, to get a symmetrical matrix
467 + condensedD = squareform(D)
468 +
469 + # Compute basic dendrogram by Ward's method
470 + Y = sch.linkage(condensedD, method='ward')
471 + Z = sch.dendrogram(Y, orientation='left', no_plot=True)
472 +
473 + # Reorganize rows and cols
474 + idx1 = Z['leaves']
475 + D = D[idx1,:]
476 + D = D[:,idx1[::-1]]
477 + im = ax.matshow(1.0 - D, vmin=0, vmax=1, origin='lower') # convert to identity matrix 1 - D from distance matrix D
478 + ax.set_title(f + "\n(" + str(len(mappings_list[f]))+ " chains)", fontsize=10)
479 + fig.tight_layout()
480 + fig.subplots_adjust(wspace=0.1, hspace=0.3)
481 + fig.colorbar(im, ax=axs[-1], shrink=0.8)
482 + fig.savefig(f"results/figures/distances.png")
483 + notify("Computed all identity matrices and saved the figure.")
355 484
356 def stats_pairs(): 485 def stats_pairs():
357 """Counts occurrences of intra-chain base-pair types in RNA families 486 """Counts occurrences of intra-chain base-pair types in RNA families
...@@ -363,26 +492,15 @@ def stats_pairs(): ...@@ -363,26 +492,15 @@ def stats_pairs():
363 return family_data.apply(partial(format_percentage, sum(family_data))) 492 return family_data.apply(partial(format_percentage, sum(family_data)))
364 493
365 if not path.isfile("data/pair_counts.csv"): 494 if not path.isfile("data/pair_counts.csv"):
366 - p = Pool(initializer=init_worker, initargs=(tqdm.get_lock(),), processes=read_cpu_number(), maxtasksperchild=5) 495 + results = []
367 - try: 496 + allpairs = []
368 - fam_pbar = tqdm(total=len(fam_list), desc="Pair-types in families", position=0, leave=True) 497 + for f in fam_list:
369 - results = [] 498 + newpairs = pd.read_csv(f"data/{f}_pairs.csv", index_col=0)
370 - allpairs = [] 499 + fam_df = pd.read_csv(f"data/{f}_counts.csv", index_col=0)
371 - for _, newp_famdf in enumerate(p.imap_unordered(parallel_stats_pairs, fam_list)): 500 + results.append(fam_df)
372 - newpairs, fam_df = newp_famdf 501 + allpairs.append(newpairs)
373 - fam_pbar.update(1) 502 + subprocess.run(["rm", "-f", f"data/{f}_pairs.csv"])
374 - results.append(fam_df) 503 + subprocess.run(["rm", "-f", f"data/{f}_counts.csv"])
375 - allpairs.append(newpairs)
376 - fam_pbar.close()
377 - p.close()
378 - p.join()
379 - except KeyboardInterrupt:
380 - warn("KeyboardInterrupt, terminating workers.", error=True)
381 - fam_pbar.close()
382 - p.terminate()
383 - p.join()
384 - exit(1)
385 -
386 all_pairs = pd.concat(allpairs) 504 all_pairs = pd.concat(allpairs)
387 df = pd.concat(results).fillna(0) 505 df = pd.concat(results).fillna(0)
388 df.to_csv("data/pair_counts.csv") 506 df.to_csv("data/pair_counts.csv")
...@@ -431,86 +549,6 @@ def stats_pairs(): ...@@ -431,86 +549,6 @@ def stats_pairs():
431 549
432 notify("Computed nucleotide statistics and saved CSV and PNG file.") 550 notify("Computed nucleotide statistics and saved CSV and PNG file.")
433 551
434 -def to_dist_matrix(f):
435 - if path.isfile("data/"+f+".npy"):
436 - notify(f"Computed {f} distance matrix", "loaded from file")
437 - return 0
438 -
439 - notify(f"Computing {f} distance matrix from alignment...")
440 - dm = DistanceCalculator('identity')
441 - with open(path_to_seq_data+"/realigned/"+f+"++.afa") as al_file:
442 - al = AlignIO.read(al_file, "fasta")[-len(mappings_list[f]):]
443 - idty = dm.get_distance(al).matrix # list of lists
444 - del al
445 - l = len(idty)
446 - np.save("data/"+f+".npy", np.array([ idty[i] + [0]*(l-1-i) if i<l-1 else idty[i] for i in range(l) ], dtype=object))
447 - del idty
448 - notify(f"Computed {f} distance matrix")
449 - return 0
450 -
451 -def seq_idty():
452 - """Computes identity matrices for each of the RNA families.
453 -
454 - Creates temporary results files in data/*.npy
455 - REQUIRES tables chain, family un to date."""
456 -
457 - # List the families for which we will compute sequence identity matrices
458 - conn = sqlite3.connect("results/RNANet.db")
459 - famlist = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from (SELECT rfam_acc, COUNT(chain_id) as n_chains FROM family NATURAL JOIN chain GROUP BY rfam_acc) WHERE n_chains > 1 ORDER BY rfam_acc ASC;") ]
460 - ignored = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from (SELECT rfam_acc, COUNT(chain_id) as n_chains FROM family NATURAL JOIN chain GROUP BY rfam_acc) WHERE n_chains < 2 ORDER BY rfam_acc ASC;") ]
461 - if len(ignored):
462 - print(f"Idty matrices: Ignoring {len(ignored)} families with only one chain:", " ".join(ignored)+'\n')
463 -
464 - # compute distance matrices (or ignore if data/RF0****.npy exists)
465 - p = Pool(processes=8)
466 - p.map(to_dist_matrix, famlist)
467 - p.close()
468 - p.join()
469 -
470 - # load them
471 - fam_arrays = []
472 - for f in famlist:
473 - if path.isfile("data/"+f+".npy"):
474 - fam_arrays.append(np.load("data/"+f+".npy"))
475 - else:
476 - fam_arrays.append([])
477 -
478 - # Update database with identity percentages
479 - conn = sqlite3.connect("results/RNANet.db")
480 - for f, D in zip(famlist, fam_arrays):
481 - if not len(D): continue
482 - a = 1.0 - np.average(D + D.T) # Get symmetric matrix instead of lower triangle + convert from distance matrix to identity matrix
483 - conn.execute(f"UPDATE family SET idty_percent = {round(float(a),2)} WHERE rfam_acc = '{f}';")
484 - conn.commit()
485 - conn.close()
486 -
487 - # Plots plots plots
488 - fig, axs = plt.subplots(4,17, figsize=(17,5.75))
489 - axs = axs.ravel()
490 - [axi.set_axis_off() for axi in axs]
491 - im = "" # Just to declare the variable, it will be set in the loop
492 - for f, D, ax in zip(famlist, fam_arrays, axs):
493 - if not len(D): continue
494 - if D.shape[0] > 2: # Cluster only if there is more than 2 sequences to organize
495 - D = D + D.T # Copy the lower triangle to upper, to get a symetrical matrix
496 - condensedD = squareform(D)
497 -
498 - # Compute basic dendrogram by Ward's method
499 - Y = sch.linkage(condensedD, method='ward')
500 - Z = sch.dendrogram(Y, orientation='left', no_plot=True)
501 -
502 - # Reorganize rows and cols
503 - idx1 = Z['leaves']
504 - D = D[idx1,:]
505 - D = D[:,idx1[::-1]]
506 - im = ax.matshow(1.0 - D, vmin=0, vmax=1, origin='lower') # convert to identity matrix 1 - D from distance matrix D
507 - ax.set_title(f + "\n(" + str(len(mappings_list[f]))+ " chains)", fontsize=10)
508 - fig.tight_layout()
509 - fig.subplots_adjust(wspace=0.1, hspace=0.3)
510 - fig.colorbar(im, ax=axs[-1], shrink=0.8)
511 - fig.savefig(f"results/figures/distances.png")
512 - notify("Computed all identity matrices and saved the figure.")
513 -
514 def per_chain_stats(): 552 def per_chain_stats():
515 """Computes per-chain frequencies and base-pair type counts. 553 """Computes per-chain frequencies and base-pair type counts.
516 554
...@@ -524,39 +562,71 @@ def per_chain_stats(): ...@@ -524,39 +562,71 @@ def per_chain_stats():
524 df = df.drop("total", axis=1) 562 df = df.drop("total", axis=1)
525 563
526 # Set the values 564 # Set the values
565 + conn.execute('pragma journal_mode=wal')
527 sql_execute(conn, "UPDATE chain SET chain_freq_A = ?, chain_freq_C = ?, chain_freq_G = ?, chain_freq_U = ?, chain_freq_other = ? WHERE chain_id= ?;", 566 sql_execute(conn, "UPDATE chain SET chain_freq_A = ?, chain_freq_C = ?, chain_freq_G = ?, chain_freq_U = ?, chain_freq_other = ? WHERE chain_id= ?;",
528 many=True, data=list(df.to_records(index=False)), warn_every=10) 567 many=True, data=list(df.to_records(index=False)), warn_every=10)
529 notify("Updated the database with per-chain base frequencies") 568 notify("Updated the database with per-chain base frequencies")
530 569
570 +def log_to_pbar(pbar):
571 + def update(r):
572 + pbar.update(1)
573 + return update
574 +
531 if __name__ == "__main__": 575 if __name__ == "__main__":
532 576
533 os.makedirs("results/figures/wadley_plots/", exist_ok=True) 577 os.makedirs("results/figures/wadley_plots/", exist_ok=True)
534 578
535 print("Loading mappings list...") 579 print("Loading mappings list...")
536 - conn = sqlite3.connect("results/RNANet.db") 580 + with sqlite3.connect("results/RNANet.db") as conn:
537 - fam_list = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from family ORDER BY rfam_acc ASC;") ] 581 + fam_list = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from family ORDER BY rfam_acc ASC;") ]
538 - mappings_list = {} 582 + mappings_list = {}
539 - for k in fam_list: 583 + for k in fam_list:
540 - mappings_list[k] = [ x[0] for x in sql_ask_database(conn, f"SELECT chain_id from chain WHERE rfam_acc='{k}';") ] 584 + mappings_list[k] = [ x[0] for x in sql_ask_database(conn, f"SELECT chain_id from chain WHERE rfam_acc='{k}' and issue=0;") ]
541 - conn.close()
542 -
543 - # stats_pairs()
544 -
545 - # Define threads for the tasks
546 - threads = [
547 - th.Thread(target=reproduce_wadley_results, kwargs={'carbon': 1}),
548 - th.Thread(target=reproduce_wadley_results, kwargs={'carbon': 4}),
549 - th.Thread(target=stats_len), # computes figures
550 - th.Thread(target=stats_freq), # Updates the database
551 - th.Thread(target=seq_idty), # produces .npy files and seq idty figures
552 - th.Thread(target=per_chain_stats) # Updates the database
553 - ]
554 -
555 - # Start the threads
556 - for t in threads:
557 - t.start()
558 -
559 - # Wait for the threads to complete
560 - for t in threads:
561 - t.join()
562 585
586 + # List the families for which we will compute sequence identity matrices
587 + with sqlite3.connect("results/RNANet.db") as conn:
588 + famlist = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from (SELECT rfam_acc, COUNT(chain_id) as n_chains FROM family NATURAL JOIN chain GROUP BY rfam_acc) WHERE n_chains > 0 ORDER BY rfam_acc ASC;") ]
589 + ignored = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from (SELECT rfam_acc, COUNT(chain_id) as n_chains FROM family NATURAL JOIN chain GROUP BY rfam_acc) WHERE n_chains < 2 ORDER BY rfam_acc ASC;") ]
590 + if len(ignored):
591 + print(f"Idty matrices: Ignoring {len(ignored)} families with only one chain:", " ".join(ignored)+'\n')
592 +
593 + # Prepare the multiprocessing execution environment
594 + nworkers = max(read_cpu_number()-1, 32)
595 + thr_idx_mgr = Manager()
596 + idxQueue = thr_idx_mgr.Queue()
597 + for i in range(nworkers):
598 + idxQueue.put(i)
599 +
600 + # Define the tasks
601 + joblist = []
602 + joblist.append(Job(function=reproduce_wadley_results, args=(1,)))
603 + joblist.append(Job(function=reproduce_wadley_results, args=(4,)))
604 + joblist.append(Job(function=stats_len)) # Computes figures
605 + joblist.append(Job(function=stats_freq)) # updates the database
606 + for f in famlist:
607 + joblist.append(Job(function=parallel_stats_pairs, args=(f,))) # updates the database
608 + if f not in ignored:
 609 + joblist.append(Job(function=to_dist_matrix, args=(f,))) # saves distance matrices to data/*.npy
610 +
611 + p = Pool(initializer=init_worker, initargs=(tqdm.get_lock(),), processes=nworkers)
612 + pbar = tqdm(total=len(joblist), desc="Stat jobs", position=0, leave=True)
613 +
614 + try:
615 + for j in joblist:
616 + p.apply_async(j.func_, args=j.args_, callback=log_to_pbar(pbar))
617 + p.close()
618 + p.join()
619 + pbar.close()
620 + except KeyboardInterrupt:
621 + warn("KeyboardInterrupt, terminating workers.", error=True)
622 + p.terminate()
623 + p.join()
624 + pbar.close()
625 + exit(1)
626 + except:
627 + print("Something went wrong")
628 +
629 + # finish the work after the parallel portions
630 + per_chain_stats()
631 + seq_idty()
632 + stats_pairs()
......