-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenvironment.py
286 lines (245 loc) · 9.26 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
import itertools
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Descriptors import qed, MolLogP
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.FilterCatalog import FilterCatalogParams, FilterCatalog
import copy
import networkx as nx
import math
import random
import time
#import matplotlib.pyplot as plt
import csv
from contextlib import contextmanager
import sys, os
from sascorer import calculateScore
# Most of the code here is taken from GCPN's open-source implementation.
'''
Environment Usage
plogp: penalized_logp(mol)
logp: MolLogP(mol)
qed: qed(mol)
sa: calculateScore(mol)
mw: rdMolDescriptors.CalcExactMolWt(mol)
'''
def convert_radical_electrons_to_hydrogens(mol):
    """
    Converts radical electrons in a molecule into bonds to hydrogens. Only
    use this if molecule is valid. Returns a new mol object; the input mol
    is deep-copied and never modified.

    Each radical electron on an atom is replaced by one explicit hydrogen.
    Hydrogens that are already explicit on the atom are kept, so the new
    hydrogens are added on top of them rather than replacing them.

    :param mol: rdkit mol object
    :return: rdkit mol object (a copy of mol with no radical electrons)
    """
    m = copy.deepcopy(mol)
    if Chem.Descriptors.NumRadicalElectrons(m) == 0:  # not a radical
        return m

    # a radical
    print('converting radical electrons to H')
    for a in m.GetAtoms():
        num_radical_e = a.GetNumRadicalElectrons()
        if num_radical_e > 0:
            a.SetNumRadicalElectrons(0)
            # Add to the existing explicit H count instead of overwriting
            # it, otherwise hydrogens already present on the atom are lost.
            a.SetNumExplicitHs(a.GetNumExplicitHs() + num_radical_e)
    return m
def check_chemical_validity(mol):
    """
    Checks the chemical validity of the mol object by writing it out as an
    isomeric SMILES string and parsing that string back. Existing mol object
    is not modified. Radicals pass this test.

    :param mol: rdkit mol object
    :return: True if chemically valid, False otherwise
    """
    smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
    # Parsing implicitly performs sanitization; a falsy result means the
    # SMILES could not be turned back into a molecule.
    reparsed = Chem.MolFromSmiles(smiles)
    return bool(reparsed)
def check_valency(mol):
    """
    Checks that no atoms in the mol have exceeded their possible
    valency, by running only the SANITIZE_PROPERTIES sanitization step
    on the given mol.

    :param mol: rdkit mol object (sanitization is invoked directly on it)
    :return: True if no valency issues, False otherwise
    """
    properties_only = Chem.SanitizeFlags.SANITIZE_PROPERTIES
    try:
        Chem.SanitizeMol(mol, sanitizeOps=properties_only)
    except ValueError:
        # sanitization rejected the molecule
        return False
    else:
        return True
def get_final_smiles(mol):
    """
    Returns the isomeric SMILES of the final molecule, after any radical
    electrons have been converted into hydrogens. Works only if molecule
    is valid.

    :param mol: rdkit mol object
    :return: SMILES string
    """
    return Chem.MolToSmiles(
        convert_radical_electrons_to_hydrogens(mol), isomericSmiles=True)
def get_final_mol(mol):
    """
    Returns the final molecule as a rdkit mol object, after any radical
    electrons have been converted into hydrogens. Works only if molecule
    is valid.

    :param mol: rdkit mol object
    :return: rdkit mol object
    """
    return convert_radical_electrons_to_hydrogens(mol)
def add_hydrogen(mol):
    """
    Round-trips the molecule through its isomeric SMILES string and returns
    the re-parsed result.

    NOTE(review): despite the name, no explicit hydrogen-adding call is made
    here; presumably the SMILES re-parse is relied on to recompute hydrogen
    counts -- confirm against callers.

    :param mol: rdkit mol object
    :return: whatever Chem.MolFromSmiles yields for the regenerated SMILES
    """
    return Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True))
def calculate_min_plogp(mol):
    """
    Returns the lowest penalized logP obtained from three representations of
    the same molecule: the mol as given, the mol re-parsed from its isomeric
    SMILES, and the mol re-parsed from its non-isomeric SMILES.

    :param mol: rdkit mol object
    :return: float, the minimum of the three penalized_logp scores
    """
    # score of the molecule exactly as passed in
    scores = [penalized_logp(mol)]

    # isomeric first, then non-isomeric
    smiles_variants = [
        Chem.MolToSmiles(mol, isomericSmiles=keep_stereo)
        for keep_stereo in (True, False)
    ]
    reparsed_mols = [Chem.MolFromSmiles(smi) for smi in smiles_variants]

    for candidate in reparsed_mols:
        scores.append(penalized_logp(candidate))
    return min(scores)
def penalized_logp(mol):
    """
    Reward that consists of log p penalized by SA and # long cycles,
    as described in (Kusner et al. 2017). Each of the three terms is
    normalized with the statistics of the 250k_rndm_zinc_drugs_clean.smi
    dataset before they are summed.

    :param mol: rdkit mol object
    :return: float, normalized logP + normalized SA + normalized cycle score
    """
    # normalization constants, statistics from 250k_rndm_zinc_drugs_clean.smi
    logP_mean, logP_std = 2.4570953396190123, 1.434324401111988
    SA_mean, SA_std = -3.0525811293166134, 0.8335207024513095
    cycle_mean, cycle_std = -0.0485696876403053, 0.2860212110245455

    log_p = MolLogP(mol)
    # the synthetic accessibility score enters with a negative sign
    neg_sa = -calculateScore(mol)

    # cycle score: penalize only the atoms beyond six in the largest cycle
    # of the cycle basis of the molecular graph
    adjacency = Chem.rdmolops.GetAdjacencyMatrix(mol)
    rings = nx.cycle_basis(nx.Graph(adjacency))
    largest_ring = max((len(ring) for ring in rings), default=0)
    cycle_score = -max(largest_ring - 6, 0)

    normalized_log_p = (log_p - logP_mean) / logP_std
    normalized_SA = (neg_sa - SA_mean) / SA_std
    normalized_cycle = (cycle_score - cycle_mean) / cycle_std
    return normalized_log_p + normalized_SA + normalized_cycle
def _count_mmff_angles(mol_h, mmff_props):
    """
    Counts the bond angles of mol_h for which MMFF angle bend parameters
    are available.

    An angle i-j-k only exists when both outer atoms are bonded to the
    central atom j, so only unordered pairs of neighbours around every atom
    are queried. Each angle is therefore visited exactly once.

    :param mol_h: rdkit mol object (with hydrogens) the properties belong to
    :param mmff_props: MMFF molecule properties object for mol_h
    :return: int, number of angles with MMFF angle bend parameters
    """
    num_angles = 0
    for atom in mol_h.GetAtoms():
        center = atom.GetIdx()
        neighbor_indices = [nbr.GetIdx() for nbr in atom.GetNeighbors()]
        for first, last in itertools.combinations(neighbor_indices, 2):
            if mmff_props.GetMMFFAngleBendParams(mol_h, first, center, last):
                num_angles += 1
    return num_angles


def steric_strain_filter(mol, cutoff=0.82, max_attempts_embed=20, max_num_iters=200):
    """
    Flags molecules based on a steric energy cutoff after max_num_iters
    iterations of MMFF94 forcefield minimization. Cutoff is based on average
    angle bend strain energy of molecule

    :param mol: rdkit mol object (not modified; a copy is used)
    :param cutoff: kcal/mol per angle . If minimized energy is above this
    threshold, then molecule fails the steric strain filter
    :param max_attempts_embed: number of attempts to generate initial 3d
    coordinates
    :param max_num_iters: number of iterations of forcefield minimization
    :return: True if molecule could be successfully minimized, and resulting
    energy is below cutoff, otherwise False. Molecules without any angle
    (at most 2 atoms, or no angle with MMFF parameters) return True.
    """
    # check for the trivial cases of a single atom or only 2 atoms, in which
    # case there is no angle bend strain energy (as there are no angles!)
    if mol.GetNumAtoms() <= 2:
        return True

    # make copy of input mol and add hydrogens
    m_h = Chem.AddHs(copy.deepcopy(mol))

    # generate an initial 3d conformer
    try:
        flag = AllChem.EmbedMolecule(m_h, maxAttempts=max_attempts_embed)
    except Exception:
        # to catch error caused by molecules such as
        # C=[SH]1=C2OC21ON(N)OC(=O)NO
        return False
    if flag == -1:
        # unable to generate 3d conformer
        return False

    # set up the forcefield
    AllChem.MMFFSanitizeMolecule(m_h)
    if not AllChem.MMFFHasAllMoleculeParams(m_h):
        # unrecognized atom type
        return False
    mmff_props = AllChem.MMFFGetMoleculeProperties(m_h)
    try:
        # to deal with molecules such as CNN1NS23(=C4C5=C2C(=C53)N4Cl)S1
        ff = AllChem.MMFFGetMoleculeForceField(m_h, mmff_props)
    except Exception:
        # unable to get forcefield or sanitization error
        return False

    # minimize steric energy
    try:
        ff.Minimize(maxIts=max_num_iters)
    except Exception:
        # minimization error
        return False

    # get the angle bend term contribution to the total molecule strain
    # energy: switch every other term off and rebuild the forcefield
    mmff_props.SetMMFFBondTerm(False)
    mmff_props.SetMMFFAngleTerm(True)
    mmff_props.SetMMFFStretchBendTerm(False)
    mmff_props.SetMMFFOopTerm(False)
    mmff_props.SetMMFFTorsionTerm(False)
    mmff_props.SetMMFFVdWTerm(False)
    mmff_props.SetMMFFEleTerm(False)
    ff = AllChem.MMFFGetMoleculeForceField(m_h, mmff_props)
    min_angle_e = ff.CalcEnergy()

    # find number of angles in molecule
    num_angles = _count_mmff_angles(m_h, mmff_props)
    if num_angles == 0:
        # no angles means no angle bend strain; also avoids dividing by zero
        return True

    avr_angle_e = min_angle_e / num_angles
    return avr_angle_e < cutoff
### YES/NO filters ###
# Shared ZINC filter catalog; built lazily on first use and then reused,
# because the catalog does not depend on the molecule being checked.
_ZINC_FILTER_CATALOG = None


def _get_zinc_filter_catalog():
    """Returns the shared ZINC FilterCatalog, building it on the first call."""
    global _ZINC_FILTER_CATALOG
    if _ZINC_FILTER_CATALOG is None:
        params = FilterCatalogParams()
        params.AddCatalog(FilterCatalogParams.FilterCatalogs.ZINC)
        _ZINC_FILTER_CATALOG = FilterCatalog(params)
    return _ZINC_FILTER_CATALOG


def zinc_molecule_filter(mol):
    """
    Flags molecules based on problematic functional groups as
    provided set of ZINC rules from
    http://blaster.docking.org/filtering/rules_default.txt.

    :param mol: rdkit mol object
    :return: Returns True if molecule is okay (ie does not match any of
    the rules), False if otherwise
    """
    return not _get_zinc_filter_catalog().HasMatch(mol)