# multiple_ligands.py — forked from HannesStark/EquiBind.
from torch.utils.data import Dataset
from commons.process_mols import get_geometry_graph, get_lig_graph_revised, get_rdkit_coords
from dgl import batch
from rdkit.Chem import SDMolSupplier, SanitizeMol, SanitizeFlags, PropertyMol, SmilesMolSupplier, AddHs, MultithreadedSmilesMolSupplier, MultithreadedSDMolSupplier
def safe_get_name(lig):
    """Return the molecule's ``_Name`` property, or None when it is unavailable.

    Args:
        lig: an RDKit molecule (anything exposing ``GetProp``), or None.

    Returns:
        The ``_Name`` value, or None when ``lig`` is None or ``GetProp``
        raises KeyError because the property is not set.
    """
    if lig is None:
        return None
    try:
        return lig.GetProp("_Name")
    except KeyError:
        return None
class Ligands(Dataset):
    """Map-style dataset over every ligand record in one SDF or SMILES file.

    All records are read eagerly in ``__init__``. Per-ligand work happens in
    ``__getitem__``: sanitization, optional hydrogen addition and conformer
    generation, and graph construction. Successful items are 6-tuples that
    all share the receptor graph ``rec_graph``. Skipped or failed items are
    ``(idx, name_or_message)`` 2-tuples, which ``collate`` separates from the
    successful ones.
    """

    def __init__(
        self, ligpath, rec_graph, args,
        skips=None, ext=None, addH=None,
        rdkit_seed=None, lig_load_workers=0,
        generate_conformer=None,
    ):
        """Read all ligand records from ``ligpath``.

        Args:
            ligpath: path of the ligand file.
            rec_graph: receptor graph returned with every successful item.
            args: run configuration. Reads ``dataset_params``,
                ``use_rdkit_coords`` and ``device`` here, and ``lig_name`` in
                ``_process``.
            skips: optional collection of indices that ``__getitem__``
                reports as ``"Skipped"`` instead of processing.
            ext: file format, ``"sdf"`` or ``"smi"`` (case-insensitive). When
                None it is taken from the text after the last ``.`` in
                ``ligpath``. It falls back to ``"sdf"`` when ``ligpath`` has
                no ``split`` method.
            addH: add explicit hydrogens in ``_process``. Defaults to True
                only for ``"smi"``.
            rdkit_seed: seed forwarded to ``get_rdkit_coords``.
            lig_load_workers: when > 0, read with RDKit's multithreaded
                suppliers using this many writer threads.
            generate_conformer: generate coordinates in ``_process``.
                Defaults to True only for ``"smi"``.

        Raises:
            KeyError: if the extension is neither ``"sdf"`` nor ``"smi"``.
        """
        self.ligpath = ligpath
        self.rec_graph = rec_graph
        self.args = args
        self.dp = args.dataset_params
        self.use_rdkit_coords = args.use_rdkit_coords
        self.device = args.device
        self.rdkit_seed = rdkit_seed
        # Default argument handling.
        self.skips = skips
        extensions_requiring_conformer_generation = ["smi"]
        if ext is None:
            try:
                ext = ligpath.split(".")[-1]
            except AttributeError:
                # ligpath without str.split (e.g. a pathlib.Path): keep the historical default.
                ext = "sdf"
        # Case-insensitive, so "ligs.SDF" selects the same supplier as "ligs.sdf".
        ext = ext.lower()
        if addH is None:
            addH = ext == "smi"
        self.addH = addH
        if generate_conformer is None:
            generate_conformer = ext in extensions_requiring_conformer_generation
        self.generate_conformer = generate_conformer
        if lig_load_workers > 0:
            self.ligs = self._load_multithreaded(ligpath, ext, lig_load_workers)
        else:
            self.ligs = self._load_sequential(ligpath, ext)
        self._len = len(self.ligs)

    @staticmethod
    def _open_supplier(suppliers, supp_kwargs, ext, ligpath):
        """Instantiate the supplier class registered for ``ext`` on ``ligpath``."""
        if ext not in suppliers:
            raise KeyError(
                f"Unsupported ligand file extension {ext!r}; expected one of {sorted(suppliers)}"
            )
        return suppliers[ext](ligpath, **supp_kwargs[ext])

    def _load_sequential(self, ligpath, ext):
        """Read every record with RDKit's single-threaded suppliers (None for unparseable records)."""
        suppliers = {"sdf": SDMolSupplier, "smi": SmilesMolSupplier}
        supp_kwargs = {"sdf": dict(sanitize=False, removeHs=False),
                       "smi": dict(sanitize=False, titleLine=False)}
        self.supplier = self._open_supplier(suppliers, supp_kwargs, ext, ligpath)
        return list(self.supplier)

    def _load_multithreaded(self, ligpath, ext, workers):
        """Read every record with RDKit's multithreaded suppliers, ordered by record id."""
        suppliers = {"sdf": MultithreadedSDMolSupplier, "smi": MultithreadedSmilesMolSupplier}
        supp_kwargs = {"sdf": dict(sanitize=False, removeHs=False, numWriterThreads=workers),
                       "smi": dict(sanitize=False, titleLine=False, numWriterThreads=workers)}
        self.supplier = self._open_supplier(suppliers, supp_kwargs, ext, ligpath)
        print("start loading ligs")
        # Read the record id immediately after each molecule is yielded so the pair stays matched.
        records = [(lig, self.supplier.GetLastRecordId()) for lig in self.supplier]
        ligs = self._order_by_record_id(records)
        print("finish loading ligs")
        return ligs

    @staticmethod
    def _order_by_record_id(records):
        """Return the molecules of ``(mol, record_id)`` pairs sorted by record id.

        Worker threads hand molecules back out of order, so they are re-sorted
        by the record id the supplier reported for each (presumably that
        follows file order). A ``None`` whose record id was already seen is
        dropped: this is the spurious extra ``None`` the multithreaded
        supplier can emit at the end of iteration, which repeats the last
        record's id. NOTE(review): confirm this against the pinned RDKit
        version.

        The previous code always discarded the last element after sorting.
        That lost the final ligand whenever the extra ``None`` did not
        appear, and it raised IndexError on an empty file. Genuine parse
        failures (``None`` with a fresh record id) are kept so they surface
        as failed items.
        """
        seen_ids = set()
        kept = []
        for mol, rec_id in records:
            if mol is None and rec_id in seen_ids:
                continue
            seen_ids.add(rec_id)
            kept.append((mol, rec_id))
        # Sort on the id only; RDKit molecules are not orderable.
        kept.sort(key=lambda pair: pair[1])
        return [mol for mol, _ in kept]

    @staticmethod
    def _sanitize(mol):
        """Sanitize ``mol`` in place; return True when RDKit reports no failed step."""
        # Compare enum values with ==, not identity.
        return SanitizeMol(mol, catchErrors=True) == SanitizeFlags.SANITIZE_NONE

    def _process(self, lig):
        """Sanitize and optionally protonate / embed one raw molecule.

        Returns:
            ``(mol, name)`` on success. ``(None, name)`` when sanitization
            fails or conformer generation raises ValueError. ``(None, None)``
            when ``lig`` is None.
        """
        if lig is None:
            return None, None
        if self.addH:
            if not self._sanitize(lig):
                return None, safe_get_name(lig)
            lig = AddHs(lig)
        if self.generate_conformer:
            try:
                get_rdkit_coords(lig, self.rdkit_seed)
            except ValueError:
                return None, safe_get_name(lig)
        sanitize_succeeded = self._sanitize(lig)
        name_prop = self.args.lig_name
        # Rename only when the property exists. A single ligand missing it used to
        # raise KeyError and abort the run; it now keeps its existing _Name.
        if name_prop is not None and lig.HasProp(name_prop):
            lig.SetProp("_Name", lig.GetProp(name_prop))
        if sanitize_succeeded:
            return lig, safe_get_name(lig)
        return None, safe_get_name(lig)

    def __len__(self):
        """Number of records read from ``ligpath``, including ones that fail to parse."""
        return self._len

    def __getitem__(self, idx):
        """Process the ligand at ``idx``.

        Returns:
            ``(idx, "Skipped")`` when ``idx`` is in ``skips``. ``(idx, name)``
            when processing fails (``name`` may be None). Otherwise
            ``(lig, coords, lig_graph, rec_graph, geometry_graph_or_None, idx)``.
        """
        lig = self.ligs[idx]
        if self.skips is not None and idx in self.skips:
            return idx, "Skipped"
        lig, name = self._process(lig)
        if lig is None:
            return idx, name
        lig = PropertyMol.PropertyMol(lig)
        try:
            # NOTE(review): self.use_rdkit_coords is stored but not consulted here.
            lig_graph = get_lig_graph_revised(
                lig, safe_get_name(lig),
                max_neighbors=self.dp['lig_max_neighbors'],
                use_rdkit_coords=False,
                radius=self.dp['lig_graph_radius'],
            )
        except AssertionError:
            return idx, safe_get_name(lig)
        geometry_graph = get_geometry_graph(lig) if self.dp['geometry_regularization'] else None
        lig_graph.ndata["new_x"] = lig_graph.ndata["x"]
        return lig, lig_graph.ndata["new_x"], lig_graph, self.rec_graph, geometry_graph, idx

    @staticmethod
    def collate(_batch):
        """Split ``__getitem__`` results into a batched success part and the failures.

        Failure samples are recognised by an ``int`` first element (the
        ``(idx, message)`` 2-tuples).

        Returns:
            ``(ligs, lig_coords, batched_lig_graphs, batched_rec_graphs,
            batched_geometry_graphs_or_None, true_indices, failed_in_batch)``.
            The first six are None when every sample failed.
        """
        clean_batch = []
        failed = []
        for sample in _batch:
            (failed if isinstance(sample[0], int) else clean_batch).append(sample)
        failed_in_batch = tuple(failed)
        if not clean_batch:
            return None, None, None, None, None, None, failed_in_batch
        ligs, lig_coords, lig_graphs, rec_graphs, geometry_graphs, true_indices = map(list, zip(*clean_batch))
        return (
            ligs,
            lig_coords,
            batch(lig_graphs),
            batch(rec_graphs),
            batch(geometry_graphs) if geometry_graphs[0] is not None else None,
            true_indices,
            failed_in_batch,
        )