Louis BECQUEY

wrong line index in datapoint file when 1-hot encoding nts

Showing 1 changed file with 34 additions and 26 deletions
#!/usr/bin/python3.8
import numpy as np
import pandas as pd
import concurrent.futures, Bio.PDB.StructureBuilder, gzip, io, itertools, json, multiprocessing, os, psutil, re, requests, sqlalchemy, subprocess, sys, time, warnings
import concurrent.futures, Bio.PDB.StructureBuilder, gzip, io, itertools, json, os, psutil, re, requests, sqlalchemy, subprocess, sys, time, warnings
from Bio import AlignIO, SeqIO
from Bio.PDB import MMCIFParser, PDBIO
from Bio.PDB.mmcifio import MMCIFIO
......@@ -16,9 +16,10 @@ from collections import OrderedDict
from ftplib import FTP
from functools import partial
from os import path, makedirs
from multiprocessing import Pool, cpu_count, Manager
from multiprocessing import Pool, cpu_count, Manager, freeze_support, RLock, current_process
from time import sleep
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
if path.isdir("/home/ubuntu/"): # this is the IFB-core cloud
path_to_3D_data = "/mnt/Data/RNA/3D/"
......@@ -478,7 +479,7 @@ def read_cpu_number():
# it reads info from /sys wich is not the VM resources but the host resources.
# Read from /proc/cpuinfo instead.
p = subprocess.run(['grep', '-Ec', '(Intel|AMD)', '/proc/cpuinfo'], stdout=subprocess.PIPE)
return int(p.stdout.decode('utf-8')[:-1])/2
return int(int(p.stdout.decode('utf-8')[:-1])/2)
def warn(message, error=False):
if error:
......@@ -510,7 +511,7 @@ def execute_job(j, jobcount):
elif j.func_ is not None:
#print(f"[{running_stats[0]+running_stats[2]}/{jobcount}]\t{j.func_.__name__}({', '.join([str(a) for a in j.args_ if not ((type(a) == list) and len(a)>3)])})")
print(f"[{running_stats[0]+running_stats[2]}/{jobcount}]\t{j.func_.__name__}({', '.join([str(a) for a in j.args_ if not ((type(a) == list) and len(a)>3)])})")
m = -1
monitor = Monitor(os.getpid())
......@@ -771,15 +772,13 @@ def cm_realign(rfam_acc, chains, label):
subprocess.run(["esl-reformat", "afa", path_to_seq_data + f"realigned/{rfam_acc}++.stk"], stdout=f)
f.close()
# subprocess.run(["rm", path_to_seq_data + f"realigned/{rfam_acc}.cm", path_to_seq_data + f"realigned/{rfam_acc}++.fa", path_to_seq_data + f"realigned/{rfam_acc}++.stk"])
else:
# Ribosomal subunits deserve a special treatment.
# They require too much RAM to be aligned with Infernal.
# Then we will use SINA instead.
# Get the seed alignment from Rfam
print(f"\t> Download latest LSU-Ref alignment from SILVA...", end="", flush=True)
print(f"\t> Download latest LSU/SSU-Ref alignment from SILVA...", end="", flush=True)
if rfam_acc in ["RF02540", "RF02541", "RF02543"] and not path.isfile(path_to_seq_data + "realigned/LSU.arb"):
try:
_urlcleanup()
......@@ -819,6 +818,7 @@ def cm_realign(rfam_acc, chains, label):
"-o", path_to_seq_data + f"realigned/{rfam_acc}++.afa",
"-r", path_to_seq_data + arbfile,
"--meta-fmt=csv"])
return 0
def summarize_position(col):
# this function counts the number of nucleotides at a given position, given a "column" from a MSA.
......@@ -840,22 +840,21 @@ def summarize_position(col):
else:
return (0, 0, 0, 0, 0)
def alignment_nt_stats(f, list_of_chains) :
def alignment_nt_stats(f) :
global idxQueue
#print("\t>",f,"... ", flush=True)
list_of_chains = rfam_acc_to_download[f]
chains_ids = [ str(c) for c in list_of_chains ]
thr_idx = idxQueue.get()
# Open the alignment
align = AlignIO.read(path_to_seq_data + f"realigned/{f}++.afa", "fasta")
alilen = align.get_alignment_length()
#print("\t>",f,"... loaded", flush=True)
# Compute statistics per column
pbar = tqdm(total=alilen, position=thr_idx, desc=f"Worker { thr_idx}: {f}", leave=False, )
pbar = tqdm(iterable=range(alilen), position=thr_idx+1, desc=f"Worker {thr_idx+1}: {f}", leave=False)
results = [ summarize_position(align[:,i]) for i in pbar ]
pbar.close()
frequencies = np.array(results).T
#print("\t>",f,"... loaded, computed", flush=True)
for s in align:
if not '[' in s.id: # this is a Rfamseq entry, not PDB
......@@ -868,10 +867,10 @@ def alignment_nt_stats(f, list_of_chains) :
# Save colums in the appropriate positions
i = 0
j = 0
warn_gaps = False
#warn_gaps = False
while i<c.full_length and j<alilen:
# here we try to map c.seq (the sequence of the 3D chain, including gaps when residues are missing),
# with s.seq, the sequence aligned in the MSA, containing any of ACGUacguP and two types of gaps, - and .
# with s.seq, the sequence aligned in the MSA, containing any of ACGUacgu and two types of gaps, - and .
if c.seq[i] == s.seq[j].upper(): # alignment and sequence correspond (incl. gaps)
list_of_chains[idx].frequencies = np.concatenate((list_of_chains[idx].frequencies, frequencies[:,j].reshape(-1,1)), axis=1)
......@@ -899,7 +898,7 @@ def alignment_nt_stats(f, list_of_chains) :
continue
# else, just ignore the gap.
warn_gaps = True
#warn_gaps = True
list_of_chains[idx].frequencies = np.concatenate((list_of_chains[idx].frequencies, np.array([0.0,0.0,0.0,0.0,1.0]).reshape(-1,1)), axis=1)
i += 1
elif s.seq[j] in ['.', '-']: # gap in the alignment, but not in the real chain
......@@ -927,7 +926,7 @@ def alignment_nt_stats(f, list_of_chains) :
# one-hot encoding of the actual sequence
if c.seq[i] in letters[:4]:
point[ letters[:4].index(c.seq[i]), i ] = 1
point[ 1 + letters[:4].index(c.seq[i]), i ] = 1
else:
point[5,i] = 1
......@@ -944,9 +943,8 @@ def alignment_nt_stats(f, list_of_chains) :
line = [str(x) for x in list(point[i,:]) ]
file.write(','.join(line)+'\n')
file.close()
#print("\t\tWritten", c.chain_label, f"to file\t{validsymb}", flush=True)
#print("\t>", f, f"... loaded, computed, saved\t{validsymb}", flush=True)
idxQueue.put(thr_idx) # replace the thread index
return 0
if __name__ == "__main__":
......@@ -1054,15 +1052,25 @@ if __name__ == "__main__":
families = sorted([f for f in rfam_acc_to_download.keys() ])
# Build job list
thr_idx_mgr = multiprocessing.Manager()
thr_idx_mgr = Manager()
idxQueue = thr_idx_mgr.Queue()
for i in range(10):
for i in range(read_cpu_number()):
idxQueue.put(i)
fulljoblist = []
for f in families:
label = f"Save {f} PSSMs"
list_of_chains = rfam_acc_to_download[f]
fulljoblist.append(Job(function=alignment_nt_stats, args=[f, list_of_chains], how_many_in_parallel=10, priority=1, label=label))
execute_joblist(fulljoblist, printstats=False)
# freeze_support()
# r = process_map(alignment_nt_stats, families, max_workers=read_cpu_number())
p = Pool(initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),), processes=read_cpu_number())
fam_pbar = tqdm(total=len(families), desc="RNA families", position=0, leave=True)
for i, _ in enumerate(p.imap_unordered(alignment_nt_stats, families)):
fam_pbar.update(1)
fam_pbar.close()
p.close()
p.join()
# ==========================================================================================
# Do a brief statistics summary of the produced dataset
# ==========================================================================================
#TODO: compute nt frequencies, chain lengths, angle clusters
print("Completed.")
......