Author: Louis BECQUEY

more resolution-specific statistics

--- a/RNAnet.py
+++ b/RNAnet.py
#!/usr/bin/python3.8
import Bio
import Bio.PDB as pdb
import concurrent.futures
import getopt
import gzip
@@ -25,7 +26,8 @@ from multiprocessing import Pool, Manager
from time import sleep
from tqdm import tqdm
from setproctitle import setproctitle
+from Bio import AlignIO, SeqIO
+from Bio.Align import AlignInfo
def trace_unhandled_exceptions(func):
@wraps(func)
@@ -112,7 +114,7 @@ class SelectivePortionSelector(object):
return 1
-class BufferingSummaryInfo(Bio.Align.AlignInfo.SummaryInfo):
+class BufferingSummaryInfo(AlignInfo.SummaryInfo):
def get_pssm(self, family, index):
"""Create a position specific score matrix object for the alignment.
@@ -139,7 +141,7 @@ class BufferingSummaryInfo(Bio.Align.AlignInfo.SummaryInfo):
score_dict[this_residue] = 1.0
pssm_info.append(('*', score_dict))
-return Bio.Align.AlignInfo.PSSM(pssm_info)
+return AlignInfo.PSSM(pssm_info)
class Chain:
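For orientation, a minimal usage sketch of the BufferingSummaryInfo subclass above (illustrative only, not part of the commit; the family accession and path are made up):

    # Build a PSSM from a family alignment using the buffering subclass
    from Bio import AlignIO
    align = AlignIO.read(path_to_seq_data + "realigned/RF00005++.afa", "fasta")
    pssm = BufferingSummaryInfo(align).get_pssm("RF00005", 1)   # (family, worker index), per the signature above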
@@ -198,11 +200,11 @@ class Chain:
with warnings.catch_warnings():
# Ignore the PDB problems. This mostly warns that some chain is discontinuous.
-warnings.simplefilter('ignore', Bio.PDB.PDBExceptions.PDBConstructionWarning)
-warnings.simplefilter('ignore', Bio.PDB.PDBExceptions.BiopythonWarning)
+warnings.simplefilter('ignore', pdb.PDBExceptions.PDBConstructionWarning)
+warnings.simplefilter('ignore', pdb.PDBExceptions.BiopythonWarning)
# Load the whole mmCIF into a Biopython structure object:
-mmcif_parser = Bio.PDB.MMCIFParser()
+mmcif_parser = pdb.MMCIFParser()
try:
s = mmcif_parser.get_structure(self.pdb_id, path_to_3D_data + "RNAcifs/"+self.pdb_id+".cif")
except ValueError as e:
@@ -223,7 +225,7 @@ class Chain:
sel = SelectivePortionSelector(model_idx, self.pdb_chain_id, valid_set, khetatm)
# Save that selection on the mmCIF object s to file
-ioobj = Bio.PDB.mmcifio.MMCIFIO()
+ioobj = pdb.MMCIFIO()
ioobj.set_structure(s)
ioobj.save(self.file, sel)
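The parse-select-save pattern used here follows Biopython's Select interface; a self-contained sketch, with a hypothetical selector standing in for SelectivePortionSelector and an illustrative structure id and path:

    import Bio.PDB as pdb

    class OneChain(pdb.Select):               # hypothetical stand-in for SelectivePortionSelector
        def __init__(self, chain_id):
            self.chain_id = chain_id
        def accept_chain(self, chain):
            return chain.get_id() == self.chain_id

    parser = pdb.MMCIFParser(QUIET=True)
    s = parser.get_structure("1ffk", "RNAcifs/1ffk.cif")
    io = pdb.MMCIFIO()
    io.set_structure(s)
    io.save("1ffk_0.cif", OneChain("0"))      # writes only the selected chain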
@@ -1115,7 +1117,7 @@ class Pipeline:
print(f"nohup bash -c 'time {fileDir}/RNAnet.py --3d-folder ~/Data/RNA/3D/ --seq-folder ~/Data/RNA/sequences -s' &")
sys.exit()
elif opt == '--version':
print("RNANet 1.1 beta")
print("RNANet 1.2, parallelized, Dockerized")
sys.exit()
elif opt == "-r" or opt == "--resolution":
assert float(arg) > 0.0 and float(arg) <= 20.0
@@ -1445,7 +1447,7 @@ class Pipeline:
# Update the database
data = []
for r in results:
-align = Bio.AlignIO.read(path_to_seq_data + "realigned/" + r[0] + "++.afa", "fasta")
+align = AlignIO.read(path_to_seq_data + "realigned/" + r[0] + "++.afa", "fasta")
nb_3d_chains = len([1 for r in align if '[' in r.id])
if r[0] in SSU_set: # SSU v138 is used
nb_homologs = 2225272 # source: https://www.arb-silva.de/documentation/release-138/
@@ -1535,9 +1537,9 @@ class Pipeline:
# Run statistics
if self.RUN_STATS:
# Remove previous precomputed data
subprocess.run(["rm", "-f", runDir + "/data/wadley_kernel_eta.npz",
runDir + "/data/wadley_kernel_eta_prime.npz",
runDir + "/data/pair_counts.csv"])
subprocess.run(["rm", "-f", runDir + f"/data/wadley_kernel_eta_{self.CRYSTAL_RES}.npz",
runDir + f"/data/wadley_kernel_eta_prime_{self.CRYSTAL_RES}.npz",
runDir + f"/data/pair_counts_{self.CRYSTAL_RES}.csv"])
for f in self.fam_list:
subprocess.run(["rm", "-f", runDir + f"/data/{f}.npy",
runDir + f"/data/{f}_pairs.csv",
@@ -2124,7 +2126,7 @@ def work_mmcif(pdb_id):
# if not, read the CIF header and register the structure
if not len(r):
# Load the MMCIF file with Biopython
-mmCif_info = Bio.PDB.MMCIF2Dict.MMCIF2Dict(final_filepath)
+mmCif_info = pdb.MMCIF2Dict.MMCIF2Dict(final_filepath)
# Get info about that structure
try:
@@ -2218,7 +2220,7 @@ def work_prepare_sequences(dl, rfam_acc, chains):
if rfam_acc in LSU_set | SSU_set: # rRNA
if os.path.isfile(path_to_seq_data + f"realigned/{rfam_acc}++.afa"):
# Detect doublons and remove them
-existing_afa = Bio.AlignIO.read(path_to_seq_data + f"realigned/{rfam_acc}++.afa", "fasta")
+existing_afa = AlignIO.read(path_to_seq_data + f"realigned/{rfam_acc}++.afa", "fasta")
existing_ids = [r.id for r in existing_afa]
del existing_afa
new_ids = [str(c) for c in chains]
@@ -2227,7 +2229,7 @@ def work_prepare_sequences(dl, rfam_acc, chains):
if len(doublons):
warn(f"Removing {len(doublons)} doublons from existing {rfam_acc}++.fa and using their newest version")
fasta = path_to_seq_data + f"realigned/{rfam_acc}++.fa"
-seqfile = Bio.SeqIO.parse(fasta, "fasta")
+seqfile = SeqIO.parse(fasta, "fasta")
# remove it and rewrite it with its own content filtered
os.remove(fasta)
with open(fasta, 'w') as f:
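The filtering loop that follows is cut off at the hunk boundary; a plausible completion, consistent with the FASTA-writing style used elsewhere in this file (hypothetical, the actual lines are not shown):

    for record in seqfile:                # handle was opened before os.remove(), so iteration still works on POSIX
        if record.id not in doublons:
            f.write('> ' + record.description + '\n' + str(record.seq) + '\n')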
@@ -2268,7 +2270,7 @@ def work_prepare_sequences(dl, rfam_acc, chains):
with open(path_to_seq_data + f"realigned/{rfam_acc}++.fa", "w") as plusplus:
ids = set()
# Remove doublons from the Rfam hits
-for r in Bio.SeqIO.parse(path_to_seq_data + f"realigned/{rfam_acc}.fa", "fasta"):
+for r in SeqIO.parse(path_to_seq_data + f"realigned/{rfam_acc}.fa", "fasta"):
if r.id not in ids:
ids.add(r.id)
plusplus.write('> '+r.description+'\n'+str(r.seq)+'\n')
@@ -2343,10 +2345,10 @@ def work_realign(rfam_acc):
notify("Aligned new sequences together")
# Detect doublons and remove them
-existing_stk = Bio.AlignIO.read(existing_ali_path, "stockholm")
+existing_stk = AlignIO.read(existing_ali_path, "stockholm")
existing_ids = [r.id for r in existing_stk]
del existing_stk
-new_stk = Bio.AlignIO.read(new_ali_path, "stockholm")
+new_stk = AlignIO.read(new_ali_path, "stockholm")
new_ids = [r.id for r in new_stk]
del new_stk
doublons = [i for i in existing_ids if i in new_ids]
@@ -2447,7 +2449,7 @@ def work_pssm(f, fill_gaps):
# Open the alignment
try:
-align = Bio.AlignIO.read(path_to_seq_data + f"realigned/{f}++.afa", "fasta")
+align = AlignIO.read(path_to_seq_data + f"realigned/{f}++.afa", "fasta")
except:
warn(f"{f}'s alignment is wrong. Recompute it and retry.", error=True)
with open(runDir + "/errors.txt", "a") as errf:
--- a/statistics.py
+++ b/statistics.py
@@ -70,7 +70,7 @@ def reproduce_wadley_results(carbon=4, show=False, sd_range=(1,4), res=2.0):
thr_idx = idxQueue.get()
setproctitle(f"RNANet statistics.py Worker {thr_idx+1} reproduce_wadley_results(carbon={carbon})")
-pbar = tqdm(total=2, desc=f"Worker {thr_idx+1}: eta/theta C{carbon} kernels", position=thr_idx+1, leave=False)
+pbar = tqdm(total=2, desc=f"Worker {thr_idx+1}: eta/theta C{carbon} kernels", unit="kernel", position=thr_idx+1, leave=False)
# Extract the angle values of c2'-endo and c3'-endo nucleotides
with sqlite3.connect(runDir + "/results/RNANet.db") as conn:
@@ -203,25 +203,10 @@ def stats_len():
global idxQueue
thr_idx = idxQueue.get()
# sort the RNA families so that the plot is readable
-def family_order(f):
-if f in LSU_set:
-return 4
-elif f in SSU_set:
-return 3
-elif f in ["RF00001"]: #
-return 1 # put tRNAs and 5S rRNAs first,
-elif f in ["RF00005"]: # because of the logarithmic scale, otherwise, they look tiny
-return 0 #
-else:
-return 2
-fam_list.sort(key=family_order)
cols = []
lengths = []
-for f in tqdm(fam_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Average chain lengths", leave=False):
+for f in tqdm(famlist, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Average chain lengths", unit="family", leave=False):
# Define a color for that family in the plot
if f in LSU_set:
@@ -249,7 +234,7 @@ def stats_len():
# Plot the figure
fig = plt.figure(figsize=(10,3))
ax = fig.gca()
-ax.hist(lengths, bins=100, stacked=True, log=True, color=cols, label=fam_list)
+ax.hist(lengths, bins=100, stacked=True, log=True, color=cols, label=famlist)
ax.set_xlabel("Sequence length (nucleotides)", fontsize=8)
ax.set_ylabel("Number of 3D chains", fontsize=8)
ax.set_xlim(left=-150)
@@ -303,18 +288,18 @@ def stats_freq():
# Initialize a Counter object for each family
freqs = {}
-for f in fam_list:
+for f in famlist:
freqs[f] = Counter()
# List all nt_names happening within a RNA family and store the counts in the Counter
-for f in tqdm(fam_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Base frequencies", leave=False):
+for f in tqdm(famlist, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Base frequencies", unit="family", leave=False):
with sqlite3.connect(runDir + "/results/RNANet.db") as conn:
counts = dict(sql_ask_database(conn, f"SELECT nt_name, COUNT(nt_name) FROM (SELECT chain_id from chain WHERE rfam_acc='{f}') NATURAL JOIN nucleotide GROUP BY nt_name;", warn_every=0))
freqs[f].update(counts)
# Create a pandas DataFrame, and save it to CSV.
df = pd.DataFrame()
-for f in tqdm(fam_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Base frequencies", leave=False):
+for f in tqdm(famlist, position=thr_idx+1, desc=f"Worker {thr_idx+1}: Base frequencies", unit="family", leave=False):
tot = sum(freqs[f].values())
df = pd.concat([ df, pd.DataFrame([[ format_percentage(tot, x) for x in freqs[f].values() ]], columns=list(freqs[f]), index=[f]) ])
df = df.fillna(0)
@@ -322,12 +307,13 @@ def stats_freq():
idxQueue.put(thr_idx) # replace the thread index in the queue
# notify("Saved nucleotide frequencies to CSV file.")
+@trace_unhandled_exceptions
def parallel_stats_pairs(f):
"""Counts occurrences of intra-chain base-pair types in one RNA family
REQUIRES tables chain, nucleotide up-to-date."""
if path.isfile("data/"+f+"_pairs.csv") and path.isfile("data/"+f+"_counts.csv"):
if path.isfile(runDir + "/data/"+f+"_pairs.csv") and path.isfile(runDir + "/data/"+f+"_counts.csv"):
return
# Get a worker number to position the progress bar
@@ -339,7 +325,7 @@ def parallel_stats_pairs(f):
chain_id_list = mappings_list[f]
data = []
sqldata = []
-for cid in tqdm(chain_id_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: {f} basepair types", leave=False):
+for cid in tqdm(chain_id_list, position=thr_idx+1, desc=f"Worker {thr_idx+1}: {f} basepair types", unit="chain", leave=False):
with sqlite3.connect(runDir + "/results/RNANet.db") as conn:
# Get comma separated lists of basepairs per nucleotide
interactions = pd.DataFrame(
@@ -430,16 +416,19 @@ def parallel_stats_pairs(f):
idxQueue.put(thr_idx) # replace the thread index in the queue
-def to_dist_matrix(f):
+def to_id_matrix(f):
+"""
+Extracts sequences of 3D chains from the family alignments to a distinct STK file,
+then runs esl-alipid on it to get an identity matrix
+"""
if path.isfile("data/"+f+".npy"):
# notify(f"Computed {f} distance matrix", "loaded from file")
return 0
# Get a worker number to position the progress bar
global idxQueue
thr_idx = idxQueue.get()
setproctitle(f"RNANet statistics.py Worker {thr_idx+1} to_dist_matrix({f})")
setproctitle(f"RNANet statistics.py Worker {thr_idx+1} to_id_matrix({f})")
# Prepare a file
with open(path_to_seq_data+f"/realigned/{f}++.afa") as al_file:
@@ -452,14 +441,16 @@ def to_dist_matrix(f):
except ValueError as e:
warn(e)
del al
subprocess.run(["esl-reformat", "--informat", "stockholm", "--mingap", "-o", path_to_seq_data+f"/realigned/{f}_3d_only.stk", "stockholm", path_to_seq_data+f"/realigned/{f}_3d_only_tmp.stk"])
subprocess.run(["esl-reformat", "--informat", "stockholm", "--mingap", #
"-o", path_to_seq_data+f"/realigned/{f}_3d_only.stk", # This run just deletes columns of gaps
"stockholm", path_to_seq_data+f"/realigned/{f}_3d_only_tmp.stk"]) #
subprocess.run(["rm", "-f", f + "_3d_only_tmp.stk"])
# Prepare the job
process = subprocess.Popen(shlex.split(f"esl-alipid --rna --noheader --informat stockholm {path_to_seq_data}realigned/{f}_3d_only.stk"),
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
id_matrix = np.zeros((len(names), len(names)))
-pbar = tqdm(total = len(names)*(len(names)-1)*0.5, position=thr_idx+1, desc=f"Worker {thr_idx+1}: {f} idty matrix", leave=False)
+pbar = tqdm(total = len(names)*(len(names)-1)*0.5, position=thr_idx+1, desc=f"Worker {thr_idx+1}: {f} idty matrix", unit="comparisons", leave=False)
cnt = 0
while not cnt or process.poll() is None:
output = process.stdout.read()
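The loop body that folds the esl-alipid stream into the matrix is elided by the hunk boundary. As a sketch, assuming the usual esl-alipid column layout (seqname1, seqname2, %id, ...), the parsing could look like this:

    idx = {name: i for i, name in enumerate(names)}        # names come from the 3d_only STK file above
    for line in output.decode('utf-8').splitlines():
        fields = line.split()
        if len(fields) >= 3:
            i, j = idx[fields[0]], idx[fields[1]]
            id_matrix[i, j] = id_matrix[j, i] = float(fields[2])   # the %id column
            pbar.update(1)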
@@ -482,8 +473,8 @@ def to_dist_matrix(f):
warn("\n".join([ line.decode('utf-8') for line in l ]))
pbar.close()
subprocess.run(["rm", "-f", f + "_3d_only_tmp.stk"])
np.save("data/"+f+".npy", id_matrix)
idxQueue.put(thr_idx) # replace the thread index in the queue
return 0
@@ -545,7 +536,7 @@ def seq_idty():
fig.tight_layout()
fig.subplots_adjust(hspace=0.3, wspace=0.1)
fig.colorbar(im, ax=axs[-4], shrink=0.8)
-fig.savefig(runDir + f"/results/figures/distances.png")
+fig.savefig(runDir + f"/results/figures/distances_{res_thr}.png")
print("> Computed all identity matrices and saved the figure.", flush=True)
def stats_pairs():
@@ -559,10 +550,10 @@ def stats_pairs():
def line_format(family_data):
return family_data.apply(partial(format_percentage, sum(family_data)))
if not path.isfile("data/pair_counts.csv"):
if not path.isfile("data/pair_counts_{res_thr}.csv"):
results = []
allpairs = []
-for f in fam_list:
+for f in famlist:
newpairs = pd.read_csv(runDir + f"/data/{f}_pairs.csv", index_col=0)
fam_df = pd.read_csv(runDir + f"/data/{f}_counts.csv", index_col=0)
results.append(fam_df)
@@ -571,11 +562,11 @@ def stats_pairs():
subprocess.run(["rm", "-f", runDir + f"/data/{f}_counts.csv"])
all_pairs = pd.concat(allpairs)
df = pd.concat(results).fillna(0)
df.to_csv("data/pair_counts.csv")
all_pairs.to_csv("data/all_pairs.csv")
df.to_csv(runDir + f"/data/pair_counts_{res_thr}.csv")
all_pairs.to_csv(runDir + f"/data/all_pairs_{res_thr}.csv")
else:
df = pd.read_csv("data/pair_counts.csv", index_col=0)
all_pairs = pd.read_csv("data/all_pairs.csv", index_col=0)
df = pd.read_csv(runDir + f"/data/pair_counts_{res_thr}.csv", index_col=0)
all_pairs = pd.read_csv(runDir + f"/data/all_pairs_{res_thr}.csv", index_col=0)
crosstab = pd.crosstab(all_pairs.pair_type_LW, all_pairs.basepair)
col_list = [ x for x in df.columns if '.' in x ]
@@ -613,7 +604,7 @@ def stats_pairs():
ax.set_ylabel("Number of observations (millions)", fontsize=13)
ax.set_xlabel(None)
plt.subplots_adjust(left=0.1, bottom=0.16, top=0.95, right=0.99)
plt.savefig(runDir + "/results/figures/pairings.png")
plt.savefig(runDir + f"/results/figures/pairings_{res_thr}.png")
notify("Computed nucleotide statistics and saved CSV and PNG file.")
@@ -916,8 +907,24 @@ def log_to_pbar(pbar):
pbar.update(1)
return update
+def family_order(f):
+# sort the RNA families so that the plots are readable
+if f in LSU_set:
+return 4
+elif f in SSU_set:
+return 3
+elif f in ["RF00001"]: #
+return 1 # put tRNAs and 5S rRNAs first,
+elif f in ["RF00005"]: # because of the logarithmic scale of the lengths' figure, otherwise, they look tiny
+return 0 #
+else:
+return 2
if __name__ == "__main__":
os.makedirs(runDir + "/results/figures/", exist_ok=True)
# parse options
DELETE_OLD_DATA = False
DO_WADLEY_ANALYSIS = False
@@ -943,7 +950,7 @@ if __name__ == "__main__":
print("--from-scratch\t\t\tDo not use precomputed results from past runs, recompute everything")
sys.exit()
elif opt == '--version':
print("RNANet statistics 1.1 beta")
print("RNANet statistics 1.2")
sys.exit()
elif opt == "-r" or opt == "--resolution":
assert float(arg) > 0.0 and float(arg) <= 20.0
@@ -959,31 +966,38 @@ if __name__ == "__main__":
elif opt=='--from-scratch':
DELETE_OLD_DATA = True
DO_WADLEY_ANALYSIS = True
subprocess.run(["rm","-f", "data/wadley_kernel_eta.npz", "data/wadley_kernel_eta_prime.npz", "data/pair_counts.csv"])
elif opt=='--wadley':
DO_WADLEY_ANALYSIS = True
-# Load mappings
+# Load mappings. famlist will contain only families with structures at this resolution threshold.
print("Loading mappings list...")
with sqlite3.connect(runDir + "/results/RNANet.db") as conn:
-fam_list = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from family ORDER BY rfam_acc ASC;") ]
-mappings_list = {}
-for k in fam_list:
-mappings_list[k] = [ x[0] for x in sql_ask_database(conn, f"SELECT chain_id from chain JOIN structure ON chain.structure_id=structure.pdb_id WHERE rfam_acc='{k}' AND issue=0 AND resolution <= {res_thr};") ]
-# List the families for which we will compute sequence identity matrices
-with sqlite3.connect(runDir + "/results/RNANet.db") as conn:
-famlist = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from (SELECT rfam_acc, COUNT(chain_id) as n_chains FROM family NATURAL JOIN chain WHERE issue = 0 GROUP BY rfam_acc) WHERE n_chains > 0 ORDER BY rfam_acc ASC;") ]
-ignored = [ x[0] for x in sql_ask_database(conn, "SELECT rfam_acc from (SELECT rfam_acc, COUNT(chain_id) as n_chains FROM family NATURAL JOIN chain WHERE issue = 0 GROUP BY rfam_acc) WHERE n_chains < 3 ORDER BY rfam_acc ASC;") ]
n_unmapped_chains = sql_ask_database(conn, "SELECT COUNT(*) FROM chain WHERE rfam_acc='unmappd' AND issue=0;")[0][0]
+families = pd.read_sql(f"""SELECT rfam_acc, count(*) as n_chains
+FROM chain JOIN structure
+ON chain.structure_id = structure.pdb_id
+WHERE issue = 0 AND resolution <= {res_thr} AND rfam_acc != 'unmappd'
+GROUP BY rfam_acc;
+""", conn)
+families.drop(families[families.n_chains == 0].index, inplace=True)
+mappings_list = {}
+for k in families.rfam_acc:
+mappings_list[k] = [ x[0] for x in sql_ask_database(conn, f"""SELECT chain_id
+FROM chain JOIN structure ON chain.structure_id=structure.pdb_id
+WHERE rfam_acc='{k}' AND issue=0 AND resolution <= {res_thr};""") ]
+famlist = families.rfam_acc.tolist()
+ignored = families[families.n_chains < 3].rfam_acc.tolist()
+famlist.sort(key=family_order)
+print(f"Found {len(famlist)} families with chains of resolution {res_thr}A or better.")
if len(ignored):
print(f"Idty matrices: Ignoring {len(ignored)} families with only one chain:", " ".join(ignored)+'\n')
if DELETE_OLD_DATA:
-for f in fam_list:
+for f in famlist:
subprocess.run(["rm","-f", runDir + f"/data/{f}.npy", runDir + f"/data/{f}_pairs.csv", runDir + f"/data/{f}_counts.csv"])
+if DO_WADLEY_ANALYSIS:
+subprocess.run(["rm","-f", runDir + f"/data/wadley_kernel_eta_{res_thr}.npz", runDir + f"/data/wadley_kernel_eta_prime_{res_thr}.npz", runDir + f"/data/pair_counts_{res_thr}.csv"])
# Prepare the multiprocessing execution environment
nworkers = min(read_cpu_number()-1, 32)
@@ -995,17 +1009,17 @@ if __name__ == "__main__":
# Define the tasks
joblist = []
if n_unmapped_chains and DO_WADLEY_ANALYSIS:
-joblist.append(Job(function=reproduce_wadley_results, args=(1, False, (1,4), 20.0))) # res threshold is 4.0 Angstroms by default
-joblist.append(Job(function=reproduce_wadley_results, args=(4, False, (1,4), 20.0))) #
+joblist.append(Job(function=reproduce_wadley_results, args=(1, False, (1,4), res_thr)))
+joblist.append(Job(function=reproduce_wadley_results, args=(4, False, (1,4), res_thr)))
joblist.append(Job(function=stats_len)) # Computes figures
joblist.append(Job(function=stats_freq)) # updates the database
for f in famlist:
joblist.append(Job(function=parallel_stats_pairs, args=(f,))) # updates the database
if f not in ignored:
-joblist.append(Job(function=to_dist_matrix, args=(f,))) # updates the database
+joblist.append(Job(function=to_id_matrix, args=(f,))) # updates the database
p = Pool(initializer=init_worker, initargs=(tqdm.get_lock(),), processes=nworkers)
-pbar = tqdm(total=len(joblist), desc="Stat jobs", position=0, leave=True)
+pbar = tqdm(total=len(joblist), desc="Stat jobs", position=0, unit="job", leave=True)
try:
for j in joblist:
......
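init_worker is not shown in this diff; a minimal implementation compatible with the initializer call above (an assumption based on standard tqdm + multiprocessing practice, not the project's actual code) would be:

    import signal
    from tqdm import tqdm

    def init_worker(tqdm_lock):
        signal.signal(signal.SIGINT, signal.SIG_IGN)   # workers ignore Ctrl-C; the parent cleans up
        tqdm.set_lock(tqdm_lock)                       # share one lock so concurrent progress bars render cleanly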