Louis BECQUEY

ON CONFLICT clauses for SQL updates

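Note: the theme of this commit is replacing SQLite's INSERT OR REPLACE with explicit upserts (INSERT ... ON CONFLICT ... DO UPDATE, available since SQLite 3.24). The difference matters: OR REPLACE resolves a constraint violation by deleting the conflicting row and inserting a fresh one, which silently resets every column not listed in the INSERT and, with foreign keys enabled, can fire ON DELETE CASCADE on dependent rows. A minimal sketch of both behaviors, using a hypothetical chain_demo table rather than the real schema:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""CREATE TABLE chain_demo (           -- hypothetical, for illustration only
                    structure_id TEXT, chain_name TEXT, issue INT, eq_class TEXT,
                    PRIMARY KEY (structure_id, chain_name))""")
conn.execute("INSERT INTO chain_demo VALUES ('1ABC', 'A', 0, 'class42')")

# INSERT OR REPLACE: the old row is deleted, so the unlisted eq_class column resets to NULL
conn.execute("INSERT OR REPLACE INTO chain_demo (structure_id, chain_name, issue) VALUES ('1ABC', 'A', 1)")
print(conn.execute("SELECT issue, eq_class FROM chain_demo").fetchone())    # (1, None)

# UPSERT: only the columns named in DO UPDATE change, everything else survives
conn.execute("UPDATE chain_demo SET eq_class = 'class42'")
conn.execute("""INSERT INTO chain_demo (structure_id, chain_name, issue) VALUES ('1ABC', 'A', 0)
                ON CONFLICT(structure_id, chain_name) DO UPDATE SET issue = excluded.issue""")
print(conn.execute("SELECT issue, eq_class FROM chain_demo").fetchone())    # (0, 'class42')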
@@ -127,7 +127,7 @@ class Chain:
self.pdb_chain_id = pdb_chain_id # chain ID (mmCIF), multiple letters
self.pdb_start = pdb_start # if portion of chain, the start number (relative to the chain, not residue numbers)
self.pdb_end = pdb_end # if portion of chain, the end number (relative to the chain, not residue numbers)
self.reversed = (pdb_start > pdb_end) # whether pdb_start > pdb_end in the Rfam mapping
self.reversed = (pdb_start > pdb_end) if pdb_start is not None else False # whether pdb_start > pdb_end in the Rfam mapping
self.chain_label = chain_label # chain pretty name
self.file = "" # path to the 3D PDB file
self.rfam_fam = rfam # mapping to an RNA family
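Note: this guard is needed because chains with no Rfam mapping carry pdb_start = pdb_end = None (see the else branch in the registration code below), and Python 3 refuses to order None (None > None raises a TypeError), so reversed now simply defaults to False for unmapped chains.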
@@ -257,7 +257,8 @@ class Chain:
df = df.drop_duplicates("index_chain", keep="first") # drop duplicates in index_chain
while (len(df.index_chain) and df.iloc[[-1]].nt_name.tolist()[0] not in ["A", "C", "G", "U"] and
((df.iloc[[-1]][["alpha", "beta", "gamma", "delta", "epsilon", "zeta", "v0", "v1", "v2", "v3", "v4"]].isna().values).all()
or (df.iloc[[-1]].puckering=='').any())):
or (df.iloc[[-1]].puckering=='').any())
or (len(df.index_chain) >= 2 and df.iloc[[-1]].nt_resnum.iloc[0] > 50 + df.iloc[[-2]].nt_resnum.iloc[0])):
df = df.head(-1)
# Assert some nucleotides exist
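Note: the clause added to this trimming loop also drops a trailing row whose nt_resnum jumps more than 50 past its predecessor's, which typically indicates a ligand or ion numbered far away from the RNA residues rather than a genuine 3' nucleotide. A self-contained toy illustration of just that clause (invented data, not real annotation output):

import pandas as pd

df = pd.DataFrame({"index_chain": [1, 2, 3],
                   "nt_name":     ["A", "G", "MG"],   # a stray magnesium ion
                   "nt_resnum":   [1, 2, 101]})
while len(df.index_chain) >= 2 and df.iloc[-1].nt_resnum > 50 + df.iloc[-2].nt_resnum:
    df = df.head(-1)                                  # drop the trailing outlier
print(df)                                             # the MG row at resnum 101 is gone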
@@ -413,20 +414,22 @@ class Chain:
with sqlite3.connect(runDir+"/results/RNANet.db", timeout=10.0) as conn:
# Register the chain in table chain
if self.pdb_start is not None:
sql_execute(conn, f""" INSERT OR REPLACE INTO chain
sql_execute(conn, f""" INSERT INTO chain
(structure_id, chain_name, pdb_start, pdb_end, reversed, rfam_acc, inferred, issue)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?);""",
(?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(structure_id, chain_name, rfam_acc) DO
UPDATE SET pdb_start=excluded.pdb_start, pdb_end=excluded.pdb_end, reversed=excluded.reversed, inferred=excluded.inferred, issue=excluded.issue;""",
data=(str(self.pdb_id), str(self.pdb_chain_id), int(self.pdb_start), int(self.pdb_end), int(self.reversed), str(self.rfam_fam), int(self.inferred), int(self.delete_me)))
# get the chain id
self.db_chain_id = sql_ask_database(conn, f"SELECT (chain_id) FROM chain WHERE structure_id='{self.pdb_id}' AND chain_name='{self.pdb_chain_id}' AND rfam_acc='{self.rfam_fam}';")[0][0]
else:
sql_execute(conn, "INSERT OR REPLACE INTO chain (structure_id, chain_name, issue) VALUES (?, ?, ?);", data=(str(self.pdb_id), int(self.pdb_chain_id), int(self.delete_me)))
sql_execute(conn, "INSERT INTO chain (structure_id, chain_name, issue) VALUES (?, ?, ?) ON CONFLICT(structure_id, chain_name) DO UPDATE SET issue=excluded.issue;", data=(str(self.pdb_id), int(self.pdb_chain_id), int(self.delete_me)))
self.db_chain_id = sql_ask_database(conn, f"SELECT (chain_id) FROM chain WHERE structure_id='{self.pdb_id}' AND chain_name='{self.pdb_chain_id}' AND rfam_acc IS NULL;")[0][0]
# Add the nucleotides
sql_execute(conn, f"""
INSERT OR REPLACE INTO nucleotide
INSERT OR IGNORE INTO nucleotide
(chain_id, index_chain, nt_resnum, nt_name, nt_code, dbn, alpha, beta, gamma, delta, epsilon, zeta,
epsilon_zeta, bb_type, chi, glyco_bond, form, ssZp, Dp, eta, theta, eta_prime, theta_prime, eta_base, theta_base,
v0, v1, v2, v3, v4, amplitude, phase_angle, puckering, nt_align_code, is_A, is_C, is_G, is_U, is_other, nt_position,
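Note: for nucleotides, OR IGNORE rather than a full upsert is deliberate: on a conflict over the new (chain_id, index_chain) primary key the existing row is simply kept, which is cheaper than the delete-and-reinsert that OR REPLACE performed and leaves previously computed values untouched.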
@@ -1251,7 +1254,7 @@ class Pipeline:
r = sql_ask_database(conn, """SELECT COUNT(chain.chain_id) as Count, rfam_acc
FROM chain LEFT JOIN re_mapping
ON chain.chain_id = re_mapping.chain_id
WHERE remapping_id IS NULL GROUP BY rfam_acc;""")
WHERE index_ali IS NULL GROUP BY rfam_acc;""")
if len(r) and r[0][0] is not None:
warn("Structures were not remapped:")
for x in r:
@@ -1263,7 +1266,7 @@ class Pipeline:
NATURAL JOIN re_mapping
LEFT JOIN align_column
ON re_mapping.index_ali=align_column.index_ali AND c.rfam_acc=align_column.rfam_acc
WHERE column_id IS NULL;""")
WHERE freq_A IS NULL;""")
if len(r) and r[0][0] is not None:
warn("Structures were not remapped:")
for x in r:
@@ -1363,7 +1366,6 @@ def sql_define_tables(conn):
FOREIGN KEY(rfam_acc) REFERENCES family(rfam_acc)
);
CREATE TABLE IF NOT EXISTS nucleotide (
nt_id INTEGER PRIMARY KEY NOT NULL,
chain_id INT,
index_chain SMALLINT,
nt_resnum SMALLINT,
@@ -1390,15 +1392,14 @@ def sql_define_tables(conn):
phase_angle REAL,
amplitude REAL,
puckering VARCHAR(20),
UNIQUE (chain_id, index_chain),
PRIMARY KEY (chain_id, index_chain),
FOREIGN KEY(chain_id) REFERENCES chain(chain_id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS re_mapping (
remapping_id INTEGER PRIMARY KEY NOT NULL,
chain_id INT NOT NULL,
index_chain INT NOT NULL,
index_ali INT NOT NULL,
UNIQUE (chain_id, index_chain),
PRIMARY KEY (chain_id, index_chain),
FOREIGN KEY(chain_id) REFERENCES chain(chain_id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS family (
@@ -1413,7 +1414,6 @@ def sql_define_tables(conn):
idty_percent REAL
);
CREATE TABLE IF NOT EXISTS align_column (
column_id INTEGER PRIMARY KEY NOT NULL,
rfam_acc CHAR(7) NOT NULL,
index_ali INT NOT NULL,
freq_A REAL,
@@ -1421,14 +1421,14 @@ def sql_define_tables(conn):
freq_G REAL,
freq_U REAL,
freq_other REAL,
UNIQUE (rfam_acc, index_ali),
PRIMARY KEY (rfam_acc, index_ali),
FOREIGN KEY(rfam_acc) REFERENCES family(rfam_acc)
);
""")
conn.commit()
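Note: the schema edits above drop the surrogate autoincrement keys (nt_id, remapping_id, column_id) and promote the former UNIQUE constraints to composite primary keys. That is what makes the ON CONFLICT(...) targets legal: SQLite only accepts a conflict target that matches an existing UNIQUE or PRIMARY KEY constraint. A tiny sketch of the pattern on a generic table (not the real schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (a INT, b INT, v INT, PRIMARY KEY (a, b))")
conn.execute("INSERT INTO t VALUES (1, 1, 10)")
# (a, b) matches the composite primary key, so SQLite accepts it as a conflict target:
conn.execute("INSERT INTO t VALUES (1, 1, 99) ON CONFLICT(a, b) DO UPDATE SET v = excluded.v")
print(conn.execute("SELECT v FROM t").fetchone())   # (99,)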
@trace_unhandled_exceptions
def sql_ask_database(conn, sql, warn_every = 0):
def sql_ask_database(conn, sql, warn_every = 10):
"""
Reads the SQLite database.
Returns a list of tuples.
@@ -1447,7 +1447,7 @@ def sql_ask_database(conn, sql, warn_every = 0):
return []
@trace_unhandled_exceptions
def sql_execute(conn, sql, many=False, data=None, warn_every=0):
def sql_execute(conn, sql, many=False, data=None, warn_every=10):
for _ in range(100): # retry 100 times if it fails
try:
if many:
@@ -1772,7 +1772,7 @@ def work_build_chain(c, extract, khetatm, retrying=False):
# Small check
if not c.delete_me:
with sqlite3.connect(runDir+"/results/RNANet.db", timeout=10.0) as conn:
nnts = sql_ask_database(conn, f"SELECT COUNT(nt_id) FROM nucleotide WHERE chain_id={c.db_chain_id};", warn_every=10)[0][0]
nnts = sql_ask_database(conn, f"SELECT COUNT(index_chain) FROM nucleotide WHERE chain_id={c.db_chain_id};", warn_every=10)[0][0]
if not(nnts):
warn(f"Nucleotides not inserted: {c.error_messages}")
c.delete_me = True
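Note: with nt_id gone from the schema, COUNT(nt_id) would now fail with "no such column"; COUNT(index_chain) returns the same per-chain row count, since every nucleotide row carries an index_chain.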
@@ -2016,19 +2016,22 @@ def work_pssm(f, fill_gaps):
pbar.update(1)
pbar.close()
# Check we found something
if not len(re_mappings):
warn(f"Chains were not found in {f}++.afa file: {chains_ids}", error=True)
return 1
# Save the re_mappings
conn = sqlite3.connect(runDir + '/results/RNANet.db', timeout=20.0)
sql_execute(conn, "INSERT OR REPLACE INTO re_mapping (chain_id, index_chain, index_ali) VALUES (?, ?, ?);", many=True, data=re_mappings)
sql_execute(conn, "INSERT INTO re_mapping (chain_id, index_chain, index_ali) VALUES (?, ?, ?) ON CONFLICT(chain_id, index_chain) DO UPDATE SET index_ali=excluded.index_ali;", many=True, data=re_mappings)
# Save the useful columns in the database
try:
data = [ (f, j) + frequencies[j-1] for j in sorted(columns_to_save) ]
except:
print(f, align.get_alignment_length(), len(frequencies), columns_to_save)
sql_execute(conn, """INSERT OR REPLACE INTO align_column (rfam_acc, index_ali, freq_A, freq_C, freq_G, freq_U, freq_other)
VALUES (?, ?, ?, ?, ?, ?, ?);""", many=True, data=data)
data = [ (f, j) + frequencies[j-1] for j in sorted(columns_to_save) ]
sql_execute(conn, """INSERT INTO align_column (rfam_acc, index_ali, freq_A, freq_C, freq_G, freq_U, freq_other)
VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT(rfam_acc, index_ali) DO
UPDATE SET freq_A=excluded.freq_A, freq_C=excluded.freq_C, freq_G=excluded.freq_G, freq_U=excluded.freq_U, freq_other=excluded.freq_other;""", many=True, data=data)
# Add an unknown values column, with index_ali 0
sql_execute(conn, f"""INSERT OR REPLACE INTO align_column (rfam_acc, index_ali, freq_A, freq_C, freq_G, freq_U, freq_other)
sql_execute(conn, f"""INSERT OR IGNORE INTO align_column (rfam_acc, index_ali, freq_A, freq_C, freq_G, freq_U, freq_other)
VALUES (?, 0, 0.0, 0.0, 0.0, 0.0, 1.0);""", data=(f,))
# Replace gaps by consensus
@@ -2039,11 +2042,10 @@ def work_pssm(f, fill_gaps):
try:
idx = chains_ids.index(s.id)
list_of_chains[idx].replace_gaps(conn)
except ValueError:
pass # We already printed a warning just above
list_of_chains[idx].replace_gaps(conn)
conn.close()
idxQueue.put(thr_idx) # replace the thread index in the queue
return 0
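Note: the same pattern appears three times in work_pssm: re_mapping and align_column get true upserts keyed on their new composite primary keys, while the per-family sentinel column (index_ali 0, freq_other = 1.0) uses OR IGNORE because its values never change once inserted.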
@@ -2094,38 +2096,38 @@ if __name__ == "__main__":
sql_define_tables(conn)
print("> Storing results into", runDir + "/results/RNANet.db")
# compute an update compared to what is in the table "chain"
#DEBUG: list everything
pp.REUSE_ALL = True
pp.list_available_mappings()
# # compute an update compared to what is in the table "chain"
# #DEBUG: list everything
# pp.REUSE_ALL = True
# pp.list_available_mappings()
# # ===========================================================================
# # 3D information
# # ===========================================================================
# # Download and annotate new RNA 3D chains (Chain objects in pp.update)
# pp.dl_and_annotate(coeff_ncores=0.75)
# # At this point, the structure table is up to date
# pp.build_chains(coeff_ncores=2.0)
# if len(pp.to_retry):
# # Redownload and re-annotate
# print("> Retrying to annotate some structures which just failed.", flush=True)
# pp.dl_and_annotate(retry=True, coeff_ncores=0.5) #
# pp.build_chains(retry=True, coeff_ncores=1.0) # Use half the cores to reduce required amount of memory
# print(f"> Loaded {len(pp.loaded_chains)} RNA chains ({len(pp.update) - len(pp.loaded_chains)} errors).")
# pp.checkpoint_save_chains()
# if not pp.HOMOLOGY:
# # Save chains to file
# for c in pp.loaded_chains:
# work_save(c, homology=False)
# print("Completed.")
# exit()
# ===========================================================================
# 3D information
# ===========================================================================
# Download and annotate new RNA 3D chains (Chain objects in pp.update)
pp.dl_and_annotate(coeff_ncores=0.75)
# At this point, the structure table is up to date
pp.build_chains()
if len(pp.to_retry):
# Redownload and re-annotate
print("> Retrying to annotate some structures which just failed.", flush=True)
pp.dl_and_annotate(retry=True, coeff_ncores=0.5) #
pp.build_chains(retry=True, coeff_ncores=0.5) # Use half the cores to reduce required amount of memory
print(f"> Loaded {len(pp.loaded_chains)} RNA chains ({len(pp.update) - len(pp.loaded_chains)} errors).")
pp.checkpoint_save_chains()
if not pp.HOMOLOGY:
# Save chains to file
for c in pp.loaded_chains:
work_save(c, homology=False)
print("Completed.")
exit()
# At this point, structure, chain and nucleotide tables of the database are up to date.
# (Modulo some statistics computed by statistics.py)
# # At this point, structure, chain and nucleotide tables of the database are up to date.
# # (Modulo some statistics computed by statistics.py)
# ===========================================================================
# Homology information
@@ -2143,8 +2145,8 @@ if __name__ == "__main__":
print(f"> Identified {len(rfam_acc_to_download.keys())} families to update and re-align with the crystals' sequences:")
pp.fam_list = sorted(rfam_acc_to_download.keys())
# pp.prepare_sequences()
# pp.realign()
pp.prepare_sequences()
pp.realign()
# At this point, the family table is up to date
@@ -2152,7 +2154,7 @@
idxQueue = thr_idx_mgr.Queue()
pp.remap()
# At this point, the align_column and re_mapping tables are up-to-date.
# ==========================================================================================
The remaining hunks are from the statistics script (statistics.py):
@@ -175,7 +175,7 @@ def stats_len():
cols.append("orange")
else:
cols.append("grey")
l = [ x[0] for x in sql_ask_database(conn, f"SELECT COUNT(nt_id) FROM (SELECT chain_id FROM chain WHERE rfam_acc='{f}') NATURAL JOIN nucleotide GROUP BY chain_id;") ]
l = [ x[0] for x in sql_ask_database(conn, f"SELECT COUNT(index_chain) FROM (SELECT chain_id FROM chain WHERE rfam_acc='{f}') NATURAL JOIN nucleotide GROUP BY chain_id;") ]
lengths.append(l)
notify(f"[{i+1}/{len(fam_list)}] Computed {f} chains lengths")
conn.close()
@@ -245,32 +245,83 @@ def parallel_stats_pairs(f):
REQUIRES tables chain, nucleotide up-to-date."""
with sqlite3.connect("results/RNANet.db") as conn:
# Get comma separated lists of basepairs per nucleotide
interactions = pd.read_sql(f"SELECT paired, pair_type_LW FROM (SELECT chain_id FROM chain WHERE rfam_acc='{f}') NATURAL JOIN nucleotide WHERE nb_interact>0;", conn)
# expand the comma-separated lists in real lists
expanded_list = pd.concat([ pd.DataFrame({ 'paired':row['paired'].split(','), 'pair_type_LW':row['pair_type_LW'].split(',') })
for _, row in interactions.iterrows() ]).reset_index(drop=True)
# keep only intra-chain interactions
expanded_list = expanded_list[ expanded_list.paired != '0' ].pair_type_LW
chain_id_list = mappings_list[f]
data = []
for cid in chain_id_list:
with sqlite3.connect("results/RNANet.db") as conn:
# Get comma separated lists of basepairs per nucleotide
interactions = pd.read_sql(f"SELECT nt_code as nt1, index_chain, paired, pair_type_LW FROM (SELECT chain_id FROM chain WHERE chain_id='{cid}') NATURAL JOIN nucleotide;", conn)
# expand the comma-separated lists in real lists
expanded_list = pd.concat([ pd.DataFrame({ 'nt1':[ row["nt1"] for x in row["paired"].split(',') ],
'index_chain':[ row['index_chain'] for x in row["paired"].split(',') ],
'paired':row['paired'].split(','),
'pair_type_LW':row['pair_type_LW'].split(',')
})
for _, row in interactions.iterrows()
]).reset_index(drop=True)
# Add second nucleotide
nt2 = []
for _, row in expanded_list.iterrows():
if row.paired in ['', '0']:
nt2.append('')
else:
try:
n = expanded_list[expanded_list.index_chain == int(row.paired)].nt1.tolist()[0]
nt2.append(n)
except IndexError:
print(cid, flush=True)
try:
expanded_list["nt2"] = nt2
except ValueError:
print(cid, flush=True)
print(expanded_list, flush=True)
return 0,0
# keep only intra-chain interactions
expanded_list = expanded_list[ ~expanded_list.paired.isin(['0','']) ]
expanded_list["nts"] = expanded_list["nt1"] + expanded_list["nt2"]
# Get basepair type
expanded_list["basepair"] = np.where(expanded_list.nts.isin(["AU","UA"]), "AU",
np.where(expanded_list.nts.isin(["GC","CG"]), "GC",
np.where(expanded_list.nts.isin(["GU","UG"]), "Wobble","Other")
)
)
# checks
# ct = pd.crosstab(expanded_list.pair_type_LW, expanded_list.basepair)
# ct = ct.loc[[ x for x in ["cWW","cHH","cSS","tWW","tHH","tSS"] if x in ct.index ]]
# for _, symmetric_type in ct.iterrows():
# for x in symmetric_type:
# if x%2:
# print("Odd number found for", symmetric_type.name, "in chain", cid, flush=True)
# print(expanded_list, flush=True)
# exit()
expanded_list = expanded_list[["basepair", "pair_type_LW"]]
data.append(expanded_list)
# merge all the dataframes from all chains of the family
expanded_list = pd.concat(data)
# Count each pair type
vcnts = expanded_list.value_counts()
vcnts = expanded_list.pair_type_LW.value_counts()
# Add these new counts to the family's counter
cnt = Counter()
cnt.update(dict(vcnts))
# Create an output DataFrame
return pd.DataFrame([[ x for x in cnt.values() ]], columns=list(cnt), index=[f])
f_df = pd.DataFrame([[ x for x in cnt.values() ]], columns=list(cnt), index=[f])
return expanded_list, f_df
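Note: the rewritten parallel_stats_pairs resolves each partner nucleotide (nt2) by looking up its index_chain within the same chain, then buckets the two-letter pair string into AU / GC / Wobble / Other with nested np.where calls. A self-contained toy version of that classification step (invented pairs, not real chain data):

import numpy as np
import pandas as pd

pairs = pd.DataFrame({"nts": ["AU", "CG", "UG", "AG"]})
pairs["basepair"] = np.where(pairs.nts.isin(["AU", "UA"]), "AU",
                    np.where(pairs.nts.isin(["GC", "CG"]), "GC",
                    np.where(pairs.nts.isin(["GU", "UG"]), "Wobble", "Other")))
print(pairs)   # basepair column reads: AU, GC, Wobble, Other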
def stats_pairs():
"""Counts occurrences of intra-chain base-pair types in RNA families
Creates a temporary results file in data/pair_counts.csv, and a results file in results/pairings.csv.
REQUIRES tables chain, nucleotide up-to-date."""
def line_format(family_data):
return family_data.apply(partial(format_percentage, sum(family_data)))
@@ -279,9 +330,23 @@ def stats_pairs():
try:
fam_pbar = tqdm(total=len(fam_list), desc="Pair-types in families", position=0, leave=True)
results = []
for i, fam_df in enumerate(p.imap_unordered(parallel_stats_pairs, fam_list)):
allpairs = []
for i, _ in enumerate(p.imap_unordered(parallel_stats_pairs, fam_list)):
newpairs, fam_df = _
fam_pbar.update(1)
results.append(fam_df)
allpairs.append(newpairs)
# Checks
vlcnts = newpairs.pair_type_LW.value_counts()
identical = [fam_df[i][0] == vlcnts.at[i] for i in fam_df.columns]
if False in identical:
print(fam_df)
print(vlcnts)
print("Dataframes differ for",fam_df.index[0], flush=True)
for x in ["cWW","cHH","cSS","tWW","tHH","tSS"]:
if x in vlcnts.index and vlcnts[x] % 2:
print("Trouvé un nombre impair de",x,"dans",fam_df.index[0], flush=True)
fam_pbar.close()
p.close()
p.join()
@@ -292,24 +357,36 @@ def stats_pairs():
p.join()
exit(1)
all_pairs = pd.concat(allpairs)
df = pd.concat(results).fillna(0)
vlcnts = all_pairs.pair_type_LW.value_counts()
for x in ["cWW","cHH","cSS","tWW","tHH","tSS"]:
if x in vlcnts.index and vlcnts[x] % 2:
print("Trouvé un nombre impair de",x,"après le merge !", flush=True)
df.to_csv("data/pair_counts.csv")
all_pairs.to_csv("data/all_pairs.csv")
else:
df = pd.read_csv("data/pair_counts.csv", index_col=0)
all_pairs = pd.read_csv("data/all_pairs.csv", index_col=0)
print(df)
# Remove not very well defined pair types (not in the 12 LW types)
crosstab = pd.crosstab(all_pairs.pair_type_LW, all_pairs.basepair)
col_list = [ x for x in df.columns if '.' in x ]
# Remove not very well defined pair types (not in the 12 LW types)
df['other'] = df[col_list].sum(axis=1)
df.drop(col_list, axis=1, inplace=True)
print(df)
crosstab = crosstab.append(crosstab.loc[col_list].sum(axis=0).rename("Other"))
# drop duplicate types
# The twelve Leontis-Westhof types are
# cWW cWH cWS cHH cHS cSS (do not count cHW cSW and cSH, they are the same as their opposites)
# tWW tWH tWS tHH tHS tSS (do not count tHW tSW and tSH, they are the same as their opposites)
df.drop([ "cHW", "tHW", "cSW", "tSW", "cHS", "tHS"], axis=1)
df.loc[ ["cWW", "tWW", "cHH", "tHH", "cSS", "tSS", "other"] ] /= 2.0
df = df.drop([ x for x in [ "cHW", "tHW", "cSW", "tSW", "cHS", "tHS"] if x in df.columns], axis=1)
crosstab = crosstab.loc[[ x for x in ["cWW","cWH","cWS","cHH","cHS","cSS","tWW","tWH","tWS","tHH","tHS","tSS","Other"] if x in crosstab.index]]
df.loc[:,[x for x in ["cWW", "tWW", "cHH", "tHH", "cSS", "tSS", "other"] if x in df.columns] ] /= 2
# crosstab.loc[["cWW", "tWW", "cHH", "tHH", "cSS", "tSS", "Other"]] /= 2
print(crosstab)
print(df)
# Compute total row
total_series = df.sum(numeric_only=True).rename("TOTAL")
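Note: the halving above works because every intra-chain pair is observed twice, once from each partner's row. Symmetric LW types keep the same name from both sides (one cWW pair yields two "cWW" counts, hence the even-count checks earlier), so those columns are divided by 2; asymmetric types appear once under each mirrored name (cWH from one side, cHW from the other), so the mirror columns are dropped instead of halved.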
@@ -326,7 +403,6 @@ def stats_pairs():
# Plot barplot of overall types
total_series.sort_values(ascending=False, inplace=True)
total_series.apply(lambda x: x/2.0) # each interaction was counted twice, once per extremity
ax = total_series.plot(figsize=(5,3), kind='bar', log=True, ylim=(1e4,5000000) )
ax.set_ylabel("Number of observations")
plt.subplots_adjust(bottom=0.2, right=0.99)
@@ -445,11 +521,11 @@ if __name__ == "__main__":
# Define threads for the tasks
threads = [
# th.Thread(target=reproduce_wadley_results, kwargs={'carbon': 1}),
# th.Thread(target=reproduce_wadley_results, kwargs={'carbon': 4}),
# th.Thread(target=stats_len),
# th.Thread(target=stats_freq),
# th.Thread(target=seq_idty),
th.Thread(target=reproduce_wadley_results, kwargs={'carbon': 1}),
th.Thread(target=reproduce_wadley_results, kwargs={'carbon': 4}),
th.Thread(target=stats_len),
th.Thread(target=stats_freq),
th.Thread(target=seq_idty),
th.Thread(target=per_chain_stats)
]