Louis BECQUEY

wrong line index in datapoint file when 1-hot encoding nts

Showing 1 changed file with 34 additions and 26 deletions
1 #!/usr/bin/python3.8 1 #!/usr/bin/python3.8
2 import numpy as np 2 import numpy as np
3 import pandas as pd 3 import pandas as pd
4 -import concurrent.futures, Bio.PDB.StructureBuilder, gzip, io, itertools, json, multiprocessing, os, psutil, re, requests, sqlalchemy, subprocess, sys, time, warnings 4 +import concurrent.futures, Bio.PDB.StructureBuilder, gzip, io, itertools, json, os, psutil, re, requests, sqlalchemy, subprocess, sys, time, warnings
5 from Bio import AlignIO, SeqIO 5 from Bio import AlignIO, SeqIO
6 from Bio.PDB import MMCIFParser, PDBIO 6 from Bio.PDB import MMCIFParser, PDBIO
7 from Bio.PDB.mmcifio import MMCIFIO 7 from Bio.PDB.mmcifio import MMCIFIO
...@@ -16,9 +16,10 @@ from collections import OrderedDict ...@@ -16,9 +16,10 @@ from collections import OrderedDict
16 from ftplib import FTP 16 from ftplib import FTP
17 from functools import partial 17 from functools import partial
18 from os import path, makedirs 18 from os import path, makedirs
19 -from multiprocessing import Pool, cpu_count, Manager 19 +from multiprocessing import Pool, cpu_count, Manager, freeze_support, RLock, current_process
20 from time import sleep 20 from time import sleep
21 from tqdm import tqdm 21 from tqdm import tqdm
22 +from tqdm.contrib.concurrent import process_map
22 23
23 if path.isdir("/home/ubuntu/"): # this is the IFB-core cloud 24 if path.isdir("/home/ubuntu/"): # this is the IFB-core cloud
24 path_to_3D_data = "/mnt/Data/RNA/3D/" 25 path_to_3D_data = "/mnt/Data/RNA/3D/"
...@@ -478,7 +479,7 @@ def read_cpu_number(): ...@@ -478,7 +479,7 @@ def read_cpu_number():
478 # it reads info from /sys which is not the VM resources but the host resources. 479 # it reads info from /sys which is not the VM resources but the host resources.
479 # Read from /proc/cpuinfo instead. 480 # Read from /proc/cpuinfo instead.
480 p = subprocess.run(['grep', '-Ec', '(Intel|AMD)', '/proc/cpuinfo'], stdout=subprocess.PIPE) 481 p = subprocess.run(['grep', '-Ec', '(Intel|AMD)', '/proc/cpuinfo'], stdout=subprocess.PIPE)
481 - return int(p.stdout.decode('utf-8')[:-1])/2 482 + return int(int(p.stdout.decode('utf-8')[:-1])/2)
482 483
483 def warn(message, error=False): 484 def warn(message, error=False):
484 if error: 485 if error:
...@@ -510,7 +511,7 @@ def execute_job(j, jobcount): ...@@ -510,7 +511,7 @@ def execute_job(j, jobcount):
510 511
511 elif j.func_ is not None: 512 elif j.func_ is not None:
512 513
513 - #print(f"[{running_stats[0]+running_stats[2]}/{jobcount}]\t{j.func_.__name__}({', '.join([str(a) for a in j.args_ if not ((type(a) == list) and len(a)>3)])})") 514 + print(f"[{running_stats[0]+running_stats[2]}/{jobcount}]\t{j.func_.__name__}({', '.join([str(a) for a in j.args_ if not ((type(a) == list) and len(a)>3)])})")
514 515
515 m = -1 516 m = -1
516 monitor = Monitor(os.getpid()) 517 monitor = Monitor(os.getpid())
...@@ -771,15 +772,13 @@ def cm_realign(rfam_acc, chains, label): ...@@ -771,15 +772,13 @@ def cm_realign(rfam_acc, chains, label):
771 subprocess.run(["esl-reformat", "afa", path_to_seq_data + f"realigned/{rfam_acc}++.stk"], stdout=f) 772 subprocess.run(["esl-reformat", "afa", path_to_seq_data + f"realigned/{rfam_acc}++.stk"], stdout=f)
772 f.close() 773 f.close()
773 # subprocess.run(["rm", path_to_seq_data + f"realigned/{rfam_acc}.cm", path_to_seq_data + f"realigned/{rfam_acc}++.fa", path_to_seq_data + f"realigned/{rfam_acc}++.stk"]) 774 # subprocess.run(["rm", path_to_seq_data + f"realigned/{rfam_acc}.cm", path_to_seq_data + f"realigned/{rfam_acc}++.fa", path_to_seq_data + f"realigned/{rfam_acc}++.stk"])
774 -
775 -
776 else: 775 else:
777 # Ribosomal subunits deserve a special treatment. 776 # Ribosomal subunits deserve a special treatment.
778 # They require too much RAM to be aligned with Infernal. 777 # They require too much RAM to be aligned with Infernal.
779 # Then we will use SINA instead. 778 # Then we will use SINA instead.
780 779
781 # Get the seed alignment from Rfam 780 # Get the seed alignment from Rfam
782 - print(f"\t> Download latest LSU-Ref alignment from SILVA...", end="", flush=True) 781 + print(f"\t> Download latest LSU/SSU-Ref alignment from SILVA...", end="", flush=True)
783 if rfam_acc in ["RF02540", "RF02541", "RF02543"] and not path.isfile(path_to_seq_data + "realigned/LSU.arb"): 782 if rfam_acc in ["RF02540", "RF02541", "RF02543"] and not path.isfile(path_to_seq_data + "realigned/LSU.arb"):
784 try: 783 try:
785 _urlcleanup() 784 _urlcleanup()
...@@ -819,6 +818,7 @@ def cm_realign(rfam_acc, chains, label): ...@@ -819,6 +818,7 @@ def cm_realign(rfam_acc, chains, label):
819 "-o", path_to_seq_data + f"realigned/{rfam_acc}++.afa", 818 "-o", path_to_seq_data + f"realigned/{rfam_acc}++.afa",
820 "-r", path_to_seq_data + arbfile, 819 "-r", path_to_seq_data + arbfile,
821 "--meta-fmt=csv"]) 820 "--meta-fmt=csv"])
821 + return 0
822 822
823 def summarize_position(col): 823 def summarize_position(col):
824 # this function counts the number of nucleotides at a given position, given a "column" from a MSA. 824 # this function counts the number of nucleotides at a given position, given a "column" from a MSA.
...@@ -840,22 +840,21 @@ def summarize_position(col): ...@@ -840,22 +840,21 @@ def summarize_position(col):
840 else: 840 else:
841 return (0, 0, 0, 0, 0) 841 return (0, 0, 0, 0, 0)
842 842
843 -def alignment_nt_stats(f, list_of_chains) : 843 +def alignment_nt_stats(f) :
844 global idxQueue 844 global idxQueue
845 - #print("\t>",f,"... ", flush=True) 845 + list_of_chains = rfam_acc_to_download[f]
846 chains_ids = [ str(c) for c in list_of_chains ] 846 chains_ids = [ str(c) for c in list_of_chains ]
847 thr_idx = idxQueue.get() 847 thr_idx = idxQueue.get()
848 848
849 # Open the alignment 849 # Open the alignment
850 align = AlignIO.read(path_to_seq_data + f"realigned/{f}++.afa", "fasta") 850 align = AlignIO.read(path_to_seq_data + f"realigned/{f}++.afa", "fasta")
851 alilen = align.get_alignment_length() 851 alilen = align.get_alignment_length()
852 - #print("\t>",f,"... loaded", flush=True)
853 852
854 # Compute statistics per column 853 # Compute statistics per column
855 - pbar = tqdm(total=alilen, position=thr_idx, desc=f"Worker { thr_idx}: {f}", leave=False, ) 854 + pbar = tqdm(iterable=range(alilen), position=thr_idx+1, desc=f"Worker {thr_idx+1}: {f}", leave=False)
856 results = [ summarize_position(align[:,i]) for i in pbar ] 855 results = [ summarize_position(align[:,i]) for i in pbar ]
856 + pbar.close()
857 frequencies = np.array(results).T 857 frequencies = np.array(results).T
858 - #print("\t>",f,"... loaded, computed", flush=True)
859 858
860 for s in align: 859 for s in align:
861 if not '[' in s.id: # this is a Rfamseq entry, not PDB 860 if not '[' in s.id: # this is a Rfamseq entry, not PDB
...@@ -868,10 +867,10 @@ def alignment_nt_stats(f, list_of_chains) : ...@@ -868,10 +867,10 @@ def alignment_nt_stats(f, list_of_chains) :
868 # Save columns in the appropriate positions 867 # Save columns in the appropriate positions
869 i = 0 868 i = 0
870 j = 0 869 j = 0
871 - warn_gaps = False 870 + #warn_gaps = False
872 while i<c.full_length and j<alilen: 871 while i<c.full_length and j<alilen:
873 # here we try to map c.seq (the sequence of the 3D chain, including gaps when residues are missing), 872 # here we try to map c.seq (the sequence of the 3D chain, including gaps when residues are missing),
874 - # with s.seq, the sequence aligned in the MSA, containing any of ACGUacguP and two types of gaps, - and . 873 + # with s.seq, the sequence aligned in the MSA, containing any of ACGUacgu and two types of gaps, - and .
875 874
876 if c.seq[i] == s.seq[j].upper(): # alignment and sequence correspond (incl. gaps) 875 if c.seq[i] == s.seq[j].upper(): # alignment and sequence correspond (incl. gaps)
877 list_of_chains[idx].frequencies = np.concatenate((list_of_chains[idx].frequencies, frequencies[:,j].reshape(-1,1)), axis=1) 876 list_of_chains[idx].frequencies = np.concatenate((list_of_chains[idx].frequencies, frequencies[:,j].reshape(-1,1)), axis=1)
...@@ -899,7 +898,7 @@ def alignment_nt_stats(f, list_of_chains) : ...@@ -899,7 +898,7 @@ def alignment_nt_stats(f, list_of_chains) :
899 continue 898 continue
900 899
901 # else, just ignore the gap. 900 # else, just ignore the gap.
902 - warn_gaps = True 901 + #warn_gaps = True
903 list_of_chains[idx].frequencies = np.concatenate((list_of_chains[idx].frequencies, np.array([0.0,0.0,0.0,0.0,1.0]).reshape(-1,1)), axis=1) 902 list_of_chains[idx].frequencies = np.concatenate((list_of_chains[idx].frequencies, np.array([0.0,0.0,0.0,0.0,1.0]).reshape(-1,1)), axis=1)
904 i += 1 903 i += 1
905 elif s.seq[j] in ['.', '-']: # gap in the alignment, but not in the real chain 904 elif s.seq[j] in ['.', '-']: # gap in the alignment, but not in the real chain
...@@ -927,7 +926,7 @@ def alignment_nt_stats(f, list_of_chains) : ...@@ -927,7 +926,7 @@ def alignment_nt_stats(f, list_of_chains) :
927 926
928 # one-hot encoding of the actual sequence 927 # one-hot encoding of the actual sequence
929 if c.seq[i] in letters[:4]: 928 if c.seq[i] in letters[:4]:
930 - point[ letters[:4].index(c.seq[i]), i ] = 1 929 + point[ 1 + letters[:4].index(c.seq[i]), i ] = 1
931 else: 930 else:
932 point[5,i] = 1 931 point[5,i] = 1
933 932
...@@ -944,9 +943,8 @@ def alignment_nt_stats(f, list_of_chains) : ...@@ -944,9 +943,8 @@ def alignment_nt_stats(f, list_of_chains) :
944 line = [str(x) for x in list(point[i,:]) ] 943 line = [str(x) for x in list(point[i,:]) ]
945 file.write(','.join(line)+'\n') 944 file.write(','.join(line)+'\n')
946 file.close() 945 file.close()
947 - #print("\t\tWritten", c.chain_label, f"to file\t{validsymb}", flush=True)
948 946
949 - #print("\t>", f, f"... loaded, computed, saved\t{validsymb}", flush=True) 947 + idxQueue.put(thr_idx) # replace the thread index
950 return 0 948 return 0
951 949
952 if __name__ == "__main__": 950 if __name__ == "__main__":
...@@ -1054,15 +1052,25 @@ if __name__ == "__main__": ...@@ -1054,15 +1052,25 @@ if __name__ == "__main__":
1054 families = sorted([f for f in rfam_acc_to_download.keys() ]) 1052 families = sorted([f for f in rfam_acc_to_download.keys() ])
1055 1053
1056 # Build job list 1054 # Build job list
1057 - thr_idx_mgr = multiprocessing.Manager() 1055 + thr_idx_mgr = Manager()
1058 idxQueue = thr_idx_mgr.Queue() 1056 idxQueue = thr_idx_mgr.Queue()
1059 - for i in range(10): 1057 + for i in range(read_cpu_number()):
1060 idxQueue.put(i) 1058 idxQueue.put(i)
1061 - fulljoblist = [] 1059 +
1062 - for f in families: 1060 + # freeze_support()
1063 - label = f"Save {f} PSSMs" 1061 + # r = process_map(alignment_nt_stats, families, max_workers=read_cpu_number())
1064 - list_of_chains = rfam_acc_to_download[f] 1062 + p = Pool(initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),), processes=read_cpu_number())
1065 - fulljoblist.append(Job(function=alignment_nt_stats, args=[f, list_of_chains], how_many_in_parallel=10, priority=1, label=label)) 1063 + fam_pbar = tqdm(total=len(families), desc="RNA families", position=0, leave=True)
1066 - execute_joblist(fulljoblist, printstats=False) 1064 + for i, _ in enumerate(p.imap_unordered(alignment_nt_stats, families)):
1065 + fam_pbar.update(1)
1066 + fam_pbar.close()
1067 + p.close()
1068 + p.join()
1069 +
1070 + # ==========================================================================================
1071 + # Do a brief statistics summary of the produced dataset
1072 + # ==========================================================================================
1073 +
1074 + #TODO: compute nt frequencies, chain lengths, angle clusters
1067 1075
1068 print("Completed.") 1076 print("Completed.")
......