multiligand_inference.py · forked from HannesStark/EquiBind
#!/usr/bin/env python
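"""Run EquiBind inference for many ligands against a single receptor.

All ligands are read from one SDF file (--ligands_sdf) and docked against the
receptor given as a PDB file (--rec_pdb). Predicted poses are streamed to
<output_directory>/output.sdf, with per-ligand success and failure logs written
alongside so that an interrupted run can be resumed.

Example invocation (paths are illustrative):
    python multiligand_inference.py -l ligands.sdf -r receptor.pdb -o results/
"""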
import argparse
import sys
from copy import deepcopy
import os
from rdkit import Chem
from rdkit.Geometry import Point3D
from commons.geometry_utils import rigid_transform_Kabsch_3D, get_torsions, get_dihedral_vonMises, apply_changes
from commons.process_mols import get_rec_graph, get_receptor_inference
#from train import load_model
from commons.utils import seed_all
import yaml
from datasets.custom_collate import * # do not remove
from models import * # do not remove
from torch.nn import * # do not remove
from torch.optim import * # do not remove
from commons.losses import * # do not remove
from torch.optim.lr_scheduler import * # do not remove
from torch.utils.data import DataLoader
# turn on for debugging C code like Segmentation Faults
import faulthandler
from datasets import multiple_ligands
faulthandler.enable()
from models.equibind import EquiBind
def parse_arguments(arglist = None):
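    """Parse command-line (or programmatically supplied) arguments.

    Returns the parsed namespace together with the set of argument names that
    were explicitly supplied, which is later used to decide which values may be
    overridden by the config file or the checkpoint's training arguments.
    """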
p = argparse.ArgumentParser()
p.add_argument('--config', type=argparse.FileType(mode='r'), default=None)
p.add_argument('--checkpoint', '--model', dest = "checkpoint",
type=str, help='path to .pt file containing the model used for inference. '
'Defaults to runs/flexible_self_docking/best_checkpoint.pt in the same directory as the file being run')
p.add_argument('--train_args', type = str, help = "Path to a yaml file containing the parameters that were used to train the model. "
"If not supplied, it is assumed that a file named 'train_arguments.yaml' is located in the same directory as the model checkpoint")
    p.add_argument('-o', '--output_directory', type=str, default=None, help='path to the directory where the predicted results are written')
p.add_argument('--batch_size', type=int, default=8, help='samples that will be processed in parallel')
p.add_argument('--seed', type=int, default=1, help='seed for reproducibility')
    p.add_argument('--device', type=str, default=None, help='Device to run inference on: cuda or cpu')
    p.add_argument('--num_confs', type=int, default=1, help='number of conformers to generate if rdkit conformers are used')
    p.add_argument('--use_rdkit_coords', action="store_true", default = None, help='override the behavior of the trained model and use RDKit-generated conformer coordinates as ligand input')
    p.add_argument('--no_use_rdkit_coords', dest = "use_rdkit_coords", action="store_false", help='override the behavior of the trained model and use the ligand coordinates from the input file directly')
    p.add_argument('--addH', action = "store_true", default = None, help = "Add hydrogens to the ligands before running EquiBind. If this flag is not enabled, every ligand should have hydrogens already.")
    p.add_argument('--no_addH', dest = "addH", action = "store_false", help = "Do not add hydrogens to the ligands before running EquiBind. Every ligand should then already contain hydrogens.")
    p.add_argument('--no_skip', dest = "skip_in_output", action = "store_false", help = 'Do not skip input ligands that already have entries in the output directory. By default, previously processed ligands are skipped so that a large interrupted computation can be resumed.')
    p.add_argument("--no_run_corrections", dest = "run_corrections", action = "store_false", help = "Turn off the fast point cloud ligand fitting corrections that are applied after inference")
p.add_argument("-l", "--ligands_sdf", type=str, help = "A single sdf file containing all ligands to be screened when running in screening mode")
p.add_argument("-r", "--rec_pdb", type = str, help = "The receptor to dock the ligands in --ligands_sdf against")
p.add_argument("--n_workers_data_load", type = int, default = 4, help = "The number of cores used for loading the ligands and generating the graphs used as input to the model")
p.add_argument("--lig_name", default = None, help = "Choose the name of a property from the mol file to interpret as the name of the molecule")
# p.add_argument("--lig_slice", help = "Run only a slice of the provided ligand file. Like in python, this slice is HALF-OPEN. Should be provided in the format --lig_slice start,end")
p.add_argument("--lazy_dataload", dest = "lazy_dataload", action="store_true", default = None, help = "Turns on lazy dataloading. If on, will postpone rdkit parsing of each ligand until it is requested.")
# p.add_argument("--no_lazy_dataload", dest = "lazy_dataload", action="store_false", default = None, help = "Turns off lazy dataloading. If on, will postpone rdkit parsing of each ligand until it is requested.")
cmdline_parser = deepcopy(p)
args = p.parse_args(arglist)
clear_defaults = {key: argparse.SUPPRESS for key in args.__dict__}
cmdline_parser.set_defaults(**clear_defaults)
cmdline_parser._defaults = {}
cmdline_args = cmdline_parser.parse_args(arglist)
return p.parse_args(arglist), set(cmdline_args.__dict__.keys())
def get_default_args(args, cmdline_args):
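    """Fill in unspecified arguments.

    Values from --config (if given) are applied first, then any argument not set
    on the command line or in the config file is taken from the training
    arguments stored next to the checkpoint (or at --train_args), except for the
    device. Also resolves the default checkpoint path and device, and forces the
    model's initial noise to zero for inference.
    """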
if args.config:
config_dict = yaml.load(args.config, Loader=yaml.FullLoader)
arg_dict = args.__dict__
for key, value in config_dict.items():
if isinstance(value, list):
for v in value:
arg_dict[key].append(v)
else:
arg_dict[key] = value
args.config = args.config.name
else:
config_dict = {}
if args.checkpoint is None:
args.checkpoint = os.path.join(os.path.dirname(__file__), "runs/flexible_self_docking/best_checkpoint.pt")
if args.device is None:
args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
config_dict['checkpoint'] = args.checkpoint
# overwrite args with args from checkpoint except for the args that were contained in the config file or provided directly in the commandline
arg_dict = args.__dict__
if args.train_args is None:
with open(os.path.join(os.path.dirname(args.checkpoint), 'train_arguments.yaml'), 'r') as arg_file:
checkpoint_dict = yaml.load(arg_file, Loader=yaml.FullLoader)
else:
with open(args.train_args, 'r') as arg_file:
checkpoint_dict = yaml.load(arg_file, Loader=yaml.FullLoader)
for key, value in checkpoint_dict.items():
if (key not in config_dict.keys()) and (key not in cmdline_args) and not key == "device":
if isinstance(value, list) and key in arg_dict:
for v in value:
arg_dict[key].append(v)
else:
arg_dict[key] = value
args.model_parameters['noise_initial'] = 0
return args
def load_rec(args):
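    """Parse the receptor PDB given in --rec_pdb and build the receptor graph
    used as model input, using the graph construction parameters the model was
    trained with (args.dataset_params)."""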
dp = args.dataset_params
rec_path = args.rec_pdb
rec, rec_coords, c_alpha_coords, n_coords, c_coords = get_receptor_inference(rec_path)
rec_graph = get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords,
use_rec_atoms=dp['use_rec_atoms'], rec_radius=dp['rec_graph_radius'],
surface_max_neighbors=dp['surface_max_neighbors'],
surface_graph_cutoff=dp['surface_graph_cutoff'],
surface_mesh_cutoff=dp['surface_mesh_cutoff'],
c_alpha_max_neighbors=dp['c_alpha_max_neighbors'])
return rec_graph
def load_model(args):
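    """Instantiate EquiBind with the trained hyperparameters, load the weights
    from --checkpoint, and return the model in eval mode on args.device."""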
device = args.device
print(f"device = {device}")
checkpoint = torch.load(args.checkpoint, map_location=device)
model = EquiBind(device = device, lig_input_edge_feats_dim = 15, rec_input_edge_feats_dim = 27, **args.model_parameters)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
return model
def load_rec_and_model(args):
return load_rec(args), load_model(args)
def run_batch(model, ligs, lig_coords, lig_graphs, rec_graphs, geometry_graphs, true_indices):
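    """Run the model on one batch of ligand/receptor graphs.

    If the batched forward pass raises an AssertionError, fall back to running
    each ligand individually so that a single problematic ligand does not
    invalidate the whole batch. Returns the successfully processed ligands,
    their input coordinates, the predicted coordinates, and (index, name) pairs
    for the successes and failures.
    """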
try:
predictions = model(lig_graphs, rec_graphs, geometry_graphs)[0]
out_ligs = ligs
out_lig_coords = lig_coords
names = [multiple_ligands.safe_get_name(lig) for lig in ligs]
successes = list(zip(true_indices, names))
failures = []
except AssertionError:
lig_graphs, rec_graphs, geometry_graphs = (dgl.unbatch(lig_graphs),
dgl.unbatch(rec_graphs), dgl.unbatch(geometry_graphs))
predictions = []
out_ligs = []
out_lig_coords = []
successes = []
failures = []
for lig, lig_coord, lig_graph, rec_graph, geometry_graph, true_index in zip(ligs, lig_coords, lig_graphs, rec_graphs, geometry_graphs, true_indices):
try:
output = model(lig_graph, rec_graph, geometry_graph)
except AssertionError as e:
failures.append((true_index, multiple_ligands.safe_get_name(lig)))
print(f"Failed for {multiple_ligands.safe_get_name(lig)}")
else:
out_ligs.append(lig)
out_lig_coords.append(lig_coord)
predictions.append(output[0][0])
successes.append((true_index, multiple_ligands.safe_get_name(lig)))
assert len(predictions) == len(out_ligs)
return out_ligs, out_lig_coords, predictions, successes, failures
def run_corrections(lig, lig_coord, ligs_coords_pred_untuned):
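    """Fit the input conformer to the raw EquiBind prediction.

    The rotatable-bond dihedrals of the input conformer are set to best match
    the predicted point cloud, and the resulting conformer is rigidly aligned to
    the prediction with a Kabsch fit. Returns the corrected RDKit molecule, or
    None if the Kabsch step fails.
    """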
input_coords = lig_coord.detach().cpu()
prediction = ligs_coords_pred_untuned.detach().cpu()
lig_input = deepcopy(lig)
conf = lig_input.GetConformer()
for i in range(lig_input.GetNumAtoms()):
x, y, z = input_coords.numpy()[i]
conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
lig_equibind = deepcopy(lig)
conf = lig_equibind.GetConformer()
for i in range(lig_equibind.GetNumAtoms()):
x, y, z = prediction.numpy()[i]
conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
coords_pred = lig_equibind.GetConformer().GetPositions()
Z_pt_cloud = coords_pred
rotable_bonds = get_torsions([lig_input])
new_dihedrals = np.zeros(len(rotable_bonds))
for idx, r in enumerate(rotable_bonds):
new_dihedrals[idx] = get_dihedral_vonMises(lig_input, lig_input.GetConformer(), r, Z_pt_cloud)
optimized_mol = apply_changes(lig_input, new_dihedrals, rotable_bonds)
optimized_conf = optimized_mol.GetConformer()
coords_pred_optimized = optimized_conf.GetPositions()
try:
R, t = rigid_transform_Kabsch_3D(coords_pred_optimized.T, coords_pred.T)
except np.linalg.LinAlgError:
return None
coords_pred_optimized = (R @ (coords_pred_optimized).T).T + t.squeeze()
for i in range(optimized_mol.GetNumAtoms()):
x, y, z = coords_pred_optimized[i]
optimized_conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
return optimized_mol
def write_while_inferring(dataloader, model, args):
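    """Run inference over the whole dataloader and stream results to disk.

    Corrected poses are written to output.sdf in the output directory, while
    success.txt and failed.txt record the index and name of every processed
    ligand. Returns True if at least one molecule was written.
    """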
full_output_path = os.path.join(args.output_directory, "output.sdf")
full_failed_path = os.path.join(args.output_directory, "failed.txt")
full_success_path = os.path.join(args.output_directory, "success.txt")
has_any_work_been_done = False
w_or_a = "a" if args.skip_in_output else "w"
with torch.no_grad(), open(full_output_path, w_or_a) as file, open(
full_failed_path, "a") as failed_file, open(full_success_path, w_or_a) as success_file:
with Chem.SDWriter(file) as writer:
i = 0
total_ligs = len(dataloader.dataset)
for batch in dataloader:
i += args.batch_size
print(f"Entering batch ending in index {min(i, total_ligs)}/{len(dataloader.dataset)}")
ligs, lig_coords, lig_graphs, rec_graphs, geometry_graphs, true_indices, failed_in_batch = batch
for failure in failed_in_batch:
if failure[1] == "Skipped":
continue
failed_file.write(f"{failure[0]} {failure[1]}")
failed_file.write("\n")
if ligs is None:
continue
lig_graphs = lig_graphs.to(args.device)
rec_graphs = rec_graphs.to(args.device)
geometry_graphs = geometry_graphs.to(args.device)
out_ligs, out_lig_coords, predictions, successes, failures = run_batch(model, ligs, lig_coords,
lig_graphs, rec_graphs,
geometry_graphs, true_indices)
opt_mols = [run_corrections(lig, lig_coord, prediction) for lig, lig_coord, prediction in zip(out_ligs, out_lig_coords, predictions)]
for mol, success in zip(opt_mols, successes):
if mol is not None:
writer.write(mol)
success_file.write(f"{success[0]} {success[1]}")
success_file.write("\n")
if not has_any_work_been_done:
has_any_work_been_done = True
else:
failures.append(success)
# print(f"written {mol.GetProp('_Name')} to output")
failures.sort(key = lambda x: x[0])
for failure in failures:
failed_file.write(f"{failure[0]} {failure[1]}")
failed_file.write("\n")
failed_file.flush()
success_file.flush()
writer.flush()
return has_any_work_been_done
def find_previous_work(args, create_output_dir = True):
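    """Return the set of ligand indices already listed in success.txt and
    failed.txt of the output directory (used to resume an interrupted run),
    or None if resuming is disabled or no previous results exist."""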
if create_output_dir:
os.makedirs(args.output_directory, exist_ok = True)
success_path = os.path.join(args.output_directory, "success.txt")
failed_path = os.path.join(args.output_directory, "failed.txt")
if os.path.exists(success_path) and os.path.exists(failed_path) and args.skip_in_output:
with open(success_path) as successes, open(failed_path) as failures:
previous_work = successes.readlines()
previous_work += failures.readlines()
previous_work = set(map(lambda tup: int(tup.split(" ")[0]), previous_work))
print(f"Found {len(previous_work)} previously calculated ligands")
else:
previous_work = None
return previous_work
def main(arglist = None, lig_dataset = None, model = None, rec_graph = None, args = None, keeps = None):
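    """Entry point: resolve arguments, load the receptor graph, model and ligand
    dataset (unless they are passed in), and write predictions to the output
    directory. Returns True if any ligand was successfully processed."""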
if args is None:
args, cmdline_args = parse_arguments(arglist)
args = get_default_args(args, cmdline_args)
# print(args)
# sys.exit()
assert args.output_directory, "An output directory should be specified"
assert args.ligands_sdf, "No ligand sdf specified"
assert args.rec_pdb, "No protein specified"
seed_all(args.seed)
previous_work = find_previous_work(args, create_output_dir = True)
# if args.lig_slice is not None:
# lig_slice = tuple(map(int, args.lig_slice.split(",")))
# else:
# lig_slice = None
if rec_graph is None:
rec_graph = load_rec(args)
if model is None:
model = load_model(args)
if lig_dataset is None:
lig_dataset = multiple_ligands.Ligands(
args.ligands_sdf,
rec_graph,
args,
skips = previous_work,
# addH = args.addH,
# generate_conformer = args.use_rdkit_coords,
rdkit_seed = args.seed,
lig_load_workers = args.n_workers_data_load,
keeps = keeps,
lazy = args.lazy_dataload
)
if args.lazy_dataload:
args.n_workers_data_load = 0
lig_loader = DataLoader(lig_dataset, batch_size = args.batch_size, collate_fn = lig_dataset.collate, num_workers = args.n_workers_data_load)
return write_while_inferring(lig_loader, model, args)
if __name__ == '__main__':
main()