-
Notifications
You must be signed in to change notification settings - Fork 8
/
VIBFold_adapted_functions.py
343 lines (295 loc) · 13.2 KB
/
VIBFold_adapted_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import random
from tqdm import tqdm
import dataclasses
import json
import copy
from alphafold.data import feature_processing, msa_pairing, msa_identifiers
from alphafold.data import pipeline as alpha_pipeline
from alphafold.data import parsers
from alphafold.common import residue_constants
from alphafold.data import pipeline_multimer
import numpy as np
import tarfile
import requests
import time
import os
# Copied from AlphaFold. TODO: re-check which upstream function this was taken from.
def process_multimer_features(features_for_chain, min_num_seq=512):
    """Merge per-chain monomer features into one multimer feature dict.

    Each chain is converted with ``pipeline_multimer.convert_monomer_features``,
    assembly features are added, and the chains are then processed, cropped,
    merged and finalised with the AlphaFold ``feature_processing`` helpers.
    The ``*_all_seq`` features of every chain are only passed through
    ``msa_pairing.pad_features``; this function does not reorder their rows.

    Args:
        features_for_chain: Mapping of chain id to that chain's monomer
            feature dict. Must contain at least one chain.
        min_num_seq: Minimum number of MSA rows requested from
            ``pipeline_multimer.pad_msa``. Defaults to 512, the value that
            used to be hard-coded here.

    Returns:
        The feature dict returned by ``pipeline_multimer.pad_msa``.
    """
    all_chain_features = {
        chain_id: pipeline_multimer.convert_monomer_features(chain_features, chain_id)
        for chain_id, chain_features in features_for_chain.items()
    }
    all_chain_features = pipeline_multimer.add_assembly_features(all_chain_features)
    feature_processing.process_unmerged_features(all_chain_features)

    np_chains_list = list(all_chain_features.values())
    pair_msa_sequences = not feature_processing._is_homomer_or_monomer(np_chains_list)

    # The set of "*_all_seq" feature names is taken from the first chain and
    # applied to every chain, so it is computed once instead of per chain.
    all_seq_feature_names = [
        feature_name
        for feature_name in np_chains_list[0]
        if feature_name.endswith("_all_seq")
    ]

    updated_chains = []
    for chain in np_chains_list:
        new_chain = {k: v for k, v in chain.items() if "_all_seq" not in k}
        for feature_name in all_seq_feature_names:
            new_chain[feature_name] = msa_pairing.pad_features(
                chain[feature_name], feature_name
            )
        # Row count of the chain's original (un-padded) "msa_all_seq" feature.
        new_chain["num_alignments_all_seq"] = np.asarray(len(chain["msa_all_seq"]))
        updated_chains.append(new_chain)

    np_chains_list = feature_processing.crop_chains(
        updated_chains,
        msa_crop_size=feature_processing.MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=feature_processing.MAX_TEMPLATES,
    )
    np_example = feature_processing.msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=feature_processing.MAX_TEMPLATES,
    )
    np_example = feature_processing.process_final(np_example)
    return pipeline_multimer.pad_msa(np_example, min_num_seq=min_num_seq)
# Copied from AlphaFold and adapted so that duplicate rows in the MSA are kept:
# deleting them caused problems when pairing MSAs.
def my_make_msa_features_keep_duplicates(msas):
    """Build the MSA feature dict from ``msas`` without dropping duplicate rows.

    Every sequence of every MSA contributes one row, in input order, to the
    ``msa``, ``deletion_matrix_int`` and ``msa_species_identifiers`` features.

    Args:
        msas: Non-empty sequence of MSA objects exposing ``sequences``,
            ``deletion_matrix`` and ``descriptions``.

    Returns:
        Dict with the keys ``deletion_matrix_int``, ``msa``, ``num_alignments``
        and ``msa_species_identifiers``.

    Raises:
        ValueError: If ``msas`` is empty or one of its MSAs is falsy.
    """
    if not msas:
        raise ValueError('At least one MSA must be provided.')

    encoded_rows = []
    deletion_rows = []
    species = []
    for position, alignment in enumerate(msas):
        if not alignment:
            raise ValueError(f'MSA {position} must contain at least one sequence.')
        for row, residues in enumerate(alignment.sequences):
            encoded_rows.append(
                [residue_constants.HHBLITS_AA_TO_ID[aa] for aa in residues])
            deletion_rows.append(alignment.deletion_matrix[row])
            parsed = msa_identifiers.get_identifiers(alignment.descriptions[row])
            species.append(parsed.species_id.encode('utf-8'))

    # The total row count is repeated once per residue of the first sequence.
    query_length = len(msas[0].sequences[0])
    row_count = len(encoded_rows)
    return {
        'deletion_matrix_int': np.array(deletion_rows, dtype=np.int32),
        'msa': np.array(encoded_rows, dtype=np.int32),
        'num_alignments': np.array([row_count] * query_length, dtype=np.int32),
        'msa_species_identifiers': np.array(species, dtype=np.object_),
    }
# Copied from the AlphaFold data pipeline package: the process function is modified to use a
# seq_to_features_cache that is shared across runs (multiple permutations) instead of within a single run.
class Cached_DataPipeline(pipeline_multimer.DataPipeline):
    """Multimer data pipeline with a caller-supplied per-sequence feature cache.

    The cache maps a chain sequence to its chain features. Because the caller
    owns the mapping, entries survive across ``process`` calls.
    """

    def __init__(self, monomer_data_pipeline, jackhmmer_binary_path,
                 uniprot_database_path, use_precomputed_msas,
                 seq_to_features_cache):
        """Set up the parent pipeline and remember the shared cache mapping."""
        super().__init__(
            monomer_data_pipeline=monomer_data_pipeline,
            jackhmmer_binary_path=jackhmmer_binary_path,
            uniprot_database_path=uniprot_database_path,
            use_precomputed_msas=use_precomputed_msas,
        )
        self.seq_to_features_cache = seq_to_features_cache

    def process(self,
                input_fasta_path: str,
                msa_output_dir: str) -> alpha_pipeline.FeatureDict:
        """Create multimer features for the sequences in a FASTA file.

        Writes ``chain_id_map.json`` into ``msa_output_dir``. Chains whose
        sequence is already in ``self.seq_to_features_cache`` get a deep copy
        of the cached features; all other chains are processed and a deep
        copy of the result is stored in the cache.

        Args:
            input_fasta_path: Path of the FASTA file to read.
            msa_output_dir: Directory for ``chain_id_map.json``; also passed
                on to ``_process_single_chain``.

        Returns:
            The merged feature dict, padded by ``pad_msa`` to at least 512
            MSA rows.
        """
        with open(input_fasta_path) as fasta_file:
            fasta_text = fasta_file.read()
        sequences, descriptions = parsers.parse_fasta(fasta_text)

        chain_id_map = pipeline_multimer._make_chain_id_map(
            sequences=sequences, descriptions=descriptions)
        with open(os.path.join(msa_output_dir, 'chain_id_map.json'), 'w') as map_file:
            serialisable_map = {
                chain_id: dataclasses.asdict(fasta_chain)
                for chain_id, fasta_chain in chain_id_map.items()
            }
            json.dump(serialisable_map, map_file, indent=4, sort_keys=True)

        single_unique_sequence = len(set(sequences)) == 1
        all_chain_features = {}
        for chain_id, fasta_chain in chain_id_map.items():
            sequence = fasta_chain.sequence
            if sequence in self.seq_to_features_cache:
                features = copy.deepcopy(self.seq_to_features_cache[sequence])
            else:
                features = self._process_single_chain(
                    chain_id=chain_id,
                    sequence=sequence,
                    description=fasta_chain.description,
                    msa_output_dir=msa_output_dir,
                    is_homomer_or_monomer=single_unique_sequence)
                features = pipeline_multimer.convert_monomer_features(
                    features, chain_id=chain_id)
                # Cache a private copy so later processing of the returned
                # features cannot alter the cached entry.
                self.seq_to_features_cache[sequence] = copy.deepcopy(features)
            all_chain_features[chain_id] = features

        all_chain_features = pipeline_multimer.add_assembly_features(all_chain_features)
        merged = feature_processing.pair_and_merge(all_chain_features=all_chain_features)
        # Pad MSA to avoid zero-sized extra_msa.
        return pad_msa(merged, 512)
# Copied from AlphaFold repo, to use in the pipeline class
def pad_msa(np_example, min_num_seq):
    """Return a copy of ``np_example`` padded to at least ``min_num_seq`` MSA rows.

    When ``np_example['msa']`` has fewer than ``min_num_seq`` rows, zero rows
    are appended to ``msa``, ``deletion_matrix``, ``bert_mask`` and
    ``msa_mask``, and zeros are appended to the 1-D ``cluster_bias_mask``.
    The input mapping itself is not modified.

    Args:
        np_example: Feature mapping holding the arrays listed above.
        min_num_seq: Minimum number of MSA rows in the result.

    Returns:
        A new dict, padded when needed and otherwise a shallow copy.
    """
    padded = dict(np_example)
    missing_rows = min_num_seq - padded['msa'].shape[0]
    if missing_rows <= 0:
        return padded

    row_padding = (0, missing_rows)
    for key in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
        padded[key] = np.pad(padded[key], (row_padding, (0, 0)))
    padded['cluster_bias_mask'] = np.pad(padded['cluster_bias_mask'], (row_padding,))
    return padded
# Exact copy from ColabFold
##########################################
# call mmseqs2
##########################################
# Template passed as ``bar_format`` to the tqdm progress bar in run_mmseqs2.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
def run_mmseqs2(x, prefix, use_env=True, use_filter=True,
                use_templates=False, filter=None, use_pairing=False,
                host_url="https://api.colabfold.com"):
    """Fetch MSAs (and optionally templates) for ``x`` from an MMseqs2 server.

    Results are cached in the directory ``f"{prefix}_{mode}"``. If
    ``out.tar.gz`` already exists there, no job is submitted. The archive is
    only unpacked when one of the expected ``.a3m`` files is missing.

    Args:
        x: A single sequence (``str``) or a list of sequences.
        prefix: Path prefix of the cache directory.
        use_env: Selects the ``env`` / ``env-nofilter`` server modes and
            additionally reads ``bfd.mgnify30.metaeuk30.smag30.a3m``.
        use_filter: Selects the ``env`` / ``all`` modes instead of
            ``env-nofilter`` / ``nofilter``.
        use_templates: Also read ``pdb70.m8``, download up to 20 templates
            per query and return their directories.
        filter: Legacy alias; when not ``None`` it overrides ``use_filter``.
            The name shadows the builtin and is kept for compatibility.
        use_pairing: Submit to ``ticket/pair`` with an empty mode and read
            ``pair.a3m``. Forces ``use_templates`` and ``use_env`` off.
        host_url: Base URL of the server.

    Returns:
        A list with one a3m string per input sequence, in input order
        (duplicate inputs get the same text). With ``use_templates`` in
        effect, a tuple ``(a3m_lines, template_paths)`` where
        ``template_paths`` holds a directory path or ``None`` per input.

    Raises:
        Exception: If the server reports ``ERROR`` or ``MAINTENANCE``.
    """
    # Local import: only needed to quote arguments of the template shell commands.
    import shlex

    submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

    def submit(seqs, mode, N=101):
        """Post the sequences as FASTA numbered from ``N``; return the JSON reply."""
        query = "".join(f">{n}\n{seq}\n" for n, seq in enumerate(seqs, start=N))
        res = requests.post(f'{host_url}/{submission_endpoint}', data={'q': query, 'mode': mode})
        try:
            out = res.json()
        except ValueError:
            print(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def status(ID):
        """Return the server's JSON status reply for ticket ``ID``."""
        res = requests.get(f'{host_url}/ticket/{ID}')
        try:
            out = res.json()
        except ValueError:
            print(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def download(ID, path):
        """Write the result archive of ticket ``ID`` to ``path``."""
        res = requests.get(f'{host_url}/result/download/{ID}')
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    # compatibility to old option
    if filter is not None:
        use_filter = filter

    # setup mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    if use_pairing:
        mode = ""
        use_templates = False
        use_env = False

    # define path
    path = f"{prefix}_{mode}"
    if not os.path.isdir(path):
        os.mkdir(path)

    # call mmseqs2 api
    tar_gz_file = f'{path}/out.tar.gz'
    N, REDO = 101, True

    # Deduplicate while keeping first-seen order; Ms maps every input
    # sequence to the number its unique representative is submitted under.
    seqs_unique = list(dict.fromkeys(seqs))
    unique_index = {seq: i for i, seq in enumerate(seqs_unique)}
    Ms = [N + unique_index[seq] for seq in seqs]

    # lets do it!
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 150 * len(seqs_unique)
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            while REDO:
                pbar.set_description("SUBMIT")

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 5 + random.randint(0, 5)
                    print(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    # resubmit
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out["status"] == "ERROR":
                    raise Exception('MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

                if out["status"] == "MAINTENANCE":
                    raise Exception('MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

                # wait for job to finish
                ID, TIME = out["id"], 0
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    t = 5 + random.randint(0, 5)
                    print(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        TIME += t
                        pbar.update(n=t)

                if out["status"] == "COMPLETE":
                    # Fill the bar when the job finished faster than estimated.
                    if TIME < TIME_ESTIMATE:
                        pbar.update(n=(TIME_ESTIMATE - TIME))
                    REDO = False

                if out["status"] == "ERROR":
                    REDO = False
                    raise Exception('MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

            # Download results
            download(ID, tar_gz_file)

    # prep list of a3m files
    if use_pairing:
        a3m_files = [f"{path}/pair.a3m"]
    else:
        a3m_files = [f"{path}/uniref.a3m"]
        if use_env:
            a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    # extract a3m files
    if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
        with tarfile.open(tar_gz_file) as tar_gz:
            tar_gz.extractall(path)

    # templates
    if use_templates:
        templates = {}
        with open(f"{path}/pdb70.m8", "r") as m8_file:
            for line in m8_file:
                p = line.rstrip().split()
                M, pdb = int(p[0]), p[1]
                templates.setdefault(M, []).append(pdb)

        template_paths = {}
        for k, TMPL in templates.items():
            TMPL_PATH = f"{prefix}_{mode}/templates_{k}"
            if not os.path.isdir(TMPL_PATH):
                os.mkdir(TMPL_PATH)
                TMPL_LINE = ",".join(TMPL[:20])
                # Quote every interpolated argument so paths containing spaces
                # or shell metacharacters do not break (or alter) the commands.
                template_url = shlex.quote(f"{host_url}/template/{TMPL_LINE}")
                target_dir = shlex.quote(f"{TMPL_PATH}/")
                a3m_index = shlex.quote(f"{TMPL_PATH}/pdb70_a3m.ffindex")
                cs219_index = shlex.quote(f"{TMPL_PATH}/pdb70_cs219.ffindex")
                cs219_data = shlex.quote(f"{TMPL_PATH}/pdb70_cs219.ffdata")
                os.system(f"curl -s -L {template_url} | tar xzf - -C {target_dir}")
                os.system(f"cp {a3m_index} {cs219_index}")
                os.system(f"touch {cs219_data}")
            template_paths[k] = TMPL_PATH

    # gather a3m lines
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        with open(a3m_file, "r") as a3m_handle:
            for line in a3m_handle:
                # A NUL byte marks the start of the next query's block; the
                # following ">" header carries that query's number.
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True
                if line.startswith(">") and update_M:
                    M = int(line[1:].rstrip())
                    update_M = False
                    if M not in a3m_lines:
                        a3m_lines[M] = []
                a3m_lines[M].append(line)

    # return results
    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]

    if use_templates:
        # None marks inputs for which pdb70.m8 listed no template.
        template_paths = [template_paths.get(n) for n in Ms]

    return (a3m_lines, template_paths) if use_templates else a3m_lines