Skip to content

Commit

Permalink
fix use with 3D conformer featurizers
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom authored and Tom committed Aug 9, 2024
1 parent 62c4d77 commit 37ac692
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 10 deletions.
26 changes: 19 additions & 7 deletions comptox_ai/chemical_featurizer/generate_vectors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from comptox_ai.db.graph_db import GraphDB
from molfeat.trans.fp import FPVecTransformer
from molfeat.trans import MoleculeTransformer
from molfeat.trans.pretrained.hf_transformers import PretrainedHFTransformer
from molfeat.trans.pretrained import PretrainedDGLTransformer
from rdkit import Chem, RDLogger
Expand All @@ -9,6 +9,7 @@
from collections import defaultdict
from itertools import chain


RDLogger.DisableLog("rdApp.*") # Disable rdkit warnings


Expand Down Expand Up @@ -201,6 +202,7 @@ def create_vector_table(
sanitize_smiles_flag=True,
rdkit_descriptors=True,
molfeat_descriptors=[],
dtype=np.float32,
use_original_chemical_ids_for_df_index=True,
):
"""
Expand All @@ -217,6 +219,8 @@ def create_vector_table(
Whether sanitize_smiles() should be run on the retrieved SMILES strings.
rdkit_descriptors : bool
Whether the full set of RDKit descriptors should be calculated and incorporated as vectors.
dtype : type
Data type of the output DataFrame (e.g. float or np.float32).
molfeat_descriptors : List[str]
List of features to generate. For possible features, see https://molfeat.datamol.io/featurizers.
use_original_chemical_ids_for_df_index : bool
Expand Down Expand Up @@ -254,28 +258,36 @@ def create_vector_table(

vectors = []
df_column_names = []
conformer_3D_flag = False

if molfeat_descriptors:

for feature in molfeat_descriptors:
print(f"Calculating {feature} descriptors")

if feature in {"Roberta-Zinc480M-102M", "GPT2-Zinc480M-87M", "ChemGPT-1.2B", "ChemGPT-19M", "ChemGPT-4.7M", "MolT5", "ChemBERTa-77M-MTR", "ChemBERTa-77M-MLM"}:
featurizer = PretrainedHFTransformer(kind=feature, notation='smiles', dtype=float)
featurizer = PretrainedHFTransformer(kind=feature, notation='smiles', dtype=dtype)

elif feature in {"gin_supervised_masking", "gin_supervised_infomax", "gin_supervised_edgepred", "jtvae_zinc_no_kl", "gin_supervised_contextpred"}:
featurizer = PretrainedDGLTransformer(kind=feature, dtype=float)
featurizer = PretrainedDGLTransformer(kind=feature, dtype=dtype)

else:
mol_list = generate_3d_conformers(smiles_list)
featurizer = FPVecTransformer(kind=feature, dtype=np.float32, verbose=True)
vectors.append(featurizer(mol_list).tolist())
featurizer = MoleculeTransformer(featurizer=feature, dtype=dtype, verbose=True)
if feature in {"desc3D", "desc2D", "electroshape", "usrcat", "usr", "cats3d", "pharm3D-cats", "pharm3D-gobbi", "pharm3D-pmapper"}:
mol_list = generate_3d_conformers(smiles_list)
conformer_3D_flag = True

chemical_list = mol_list if conformer_3D_flag else smiles_list
vectors.append(featurizer(chemical_list).tolist())

df_column_names.append(feature)

if rdkit_descriptors:
print(f"Calculating rdkit descriptors")
mols = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]
rdkit_features = np.array(
[
np.array(list(Descriptors.CalcMolDescriptors(mol).values()))
np.array(list(Descriptors.CalcMolDescriptors(mol).values()), dtype=dtype)
for mol in mols
]
)
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies:
- scipy=1.7.1
- tqdm=4.62.2
- conda-forge::molfeat
- rdkit=2024.03.3
- conda-forge::rdkit
- conda-forge::ipdb
- transformers
- dgllife
Expand Down
Binary file modified tests/example_vector_table_original_chemical_ids_as_index.pkl
Binary file not shown.
Binary file modified tests/example_vector_table_smiles_as_index.pkl
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/test_chemical_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_create_vector_table_original_chemical_ids_as_index(self):
)
expected_output_df_original_chemical_ids_as_index = pd.read_pickle(df_file_path)
assert create_vector_table(
["Hydroxychloroquine", "Warfarin"], molfeat_descriptors=["maccs", "erg"]
["Hydroxychloroquine", "Warfarin"], molfeat_descriptors=['maccs', "usr", "Roberta-Zinc480M-102M", "gin_supervised_masking"]
).equals(expected_output_df_original_chemical_ids_as_index)

def test_create_vector_table_smiles_as_index(self):
Expand All @@ -159,6 +159,6 @@ def test_create_vector_table_smiles_as_index(self):
expected_output_df_smiles_as_index = pd.read_pickle(df_file_path)
assert create_vector_table(
["Hydroxychloroquine", "Warfarin"],
molfeat_descriptors=["maccs", "erg"],
molfeat_descriptors=['maccs', "usr", "Roberta-Zinc480M-102M", "gin_supervised_masking"],
use_original_chemical_ids_for_df_index=False,
).equals(expected_output_df_smiles_as_index)

0 comments on commit 37ac692

Please sign in to comment.