diff --git a/src/openqdc/datasets/base.py b/src/openqdc/datasets/base.py index 392144d..96e0f0c 100644 --- a/src/openqdc/datasets/base.py +++ b/src/openqdc/datasets/base.py @@ -1,7 +1,9 @@ import os from os.path import join as p_join +from typing import Dict, List, Optional import numpy as np +import pandas as pd import torch from loguru import logger from sklearn.utils import Bunch @@ -18,7 +20,13 @@ from openqdc.utils.molecule import atom_table -def extract_entry(df, i, subset, energy_target_names, force_target_names=None): +def extract_entry( + df: pd.DataFrame, + i: int, + subset: str, + energy_target_names: List[str], + force_target_names: Optional[List[str]] = None, +) -> Dict[str, np.ndarray]: x = np.array([atom_table.GetAtomicNumber(s) for s in df["symbols"][i]]) xs = np.stack((x, np.zeros_like(x)), axis=-1) positions = df["geometry"][i].reshape((-1, 3)) @@ -42,18 +50,12 @@ def extract_entry(df, i, subset, energy_target_names, force_target_names=None): return res -def read_qc_archive_h5(raw_path, subset, energy_target_names, force_target_names): +def read_qc_archive_h5( + raw_path: str, subset: str, energy_target_names: List[str], force_target_names: List[str] +) -> List[Dict[str, np.ndarray]]: data = load_hdf5_file(raw_path) data_t = {k2: data[k1][k2][:] for k1 in data.keys() for k2 in data[k1].keys()} n = len(data_t["molecule_id"]) - # print(f"Reading {n} entries from {raw_path}") - # for k in data_t: - # print(f"Loaded {k} with shape {data_t[k].shape}, dtype {data_t[k].dtype}") - # if "Energy" in k: - # print(np.isnan(data_t[k]).mean(), f"{data_t[k][0]}") - - # print('\n'*3) - # exit() samples = [extract_entry(data_t, i, subset, energy_target_names, force_target_names) for i in tqdm(range(n))] return samples diff --git a/src/openqdc/datasets/geom.py b/src/openqdc/datasets/geom.py index 6af826e..eebcc66 100644 --- a/src/openqdc/datasets/geom.py +++ b/src/openqdc/datasets/geom.py @@ -1,4 +1,5 @@ from os.path import join as p_join +from typing import Dict import datamol as dm import numpy as np @@ -9,7 +10,7 @@ from openqdc.utils.molecule import get_atomic_number_and_charge -def read_mol(mol_id, mol_dict, base_path, partition): +def read_mol(mol_id: str, mol_dict, base_path: str, partition: str) -> Dict[str, np.ndarray]: """Read molecule from pickle file and return dict with conformers and energies Parameters @@ -20,15 +21,18 @@ def read_mol(mol_id, mol_dict, base_path, partition): Dictionary containing the pickle_path and smiles of the molecule base_path: str Path to the folder containing the pickle files + partition: str + Name of the dataset partition, one of ['qm9', 'drugs'] Returns ------- res: dict Dictionary containing the following keys: - - atomic_inputs: flatten np.ndarray of shape (M, 4) containing the atomic numbers and positions - - smiles: np.ndarray of shape (N,) containing the smiles of the molecule - - energies: np.ndarray of shape (N,1) containing the energies of the conformers - - n_atoms: np.ndarray of shape (N,) containing the number of atoms in each conformer + - atomic_inputs: flatten np.ndarray of shape (M, 5) containing the atomic numbers, charges and positions + - smiles: np.ndarray of shape (N,) containing the smiles of the molecule + - energies: np.ndarray of shape (N,1) containing the energies of the conformers + - n_atoms: np.ndarray of shape (N,) containing the number of atoms in each conformer + - subset: np.ndarray of shape (N,) containing the name of the dataset partition """ try: @@ -56,6 +60,22 @@ def read_mol(mol_id, mol_dict, base_path, partition): class GEOM(BaseDataset): + """ + The Geometric Ensemble Of Molecules (GEOM) dataset contains 37 million conformers for 133,000 molecules + from QM9, and 317,000 molecules with experimental data related to biophysics, physiology, + and physical chemistry. The dataset is generated using the GFN2-xTB semi-empirical method. + + Usage: + ```python + from openqdc.datasets import GEOM + dataset = GEOM() + ``` + + References: + - https://www.nature.com/articles/s41597-022-01288-4 + - https://github.com/learningmatter-mit/geom + """ + __name__ = "geom" __energy_methods__ = ["gfn2_xtb"] diff --git a/src/openqdc/datasets/molecule3d.py b/src/openqdc/datasets/molecule3d.py index 0d59400..e5870ca 100644 --- a/src/openqdc/datasets/molecule3d.py +++ b/src/openqdc/datasets/molecule3d.py @@ -1,5 +1,6 @@ from glob import glob from os.path import join as p_join +from typing import Dict, List import datamol as dm import numpy as np @@ -12,7 +13,26 @@ from openqdc.utils.molecule import get_atomic_number_and_charge -def read_mol(mol, energy): +def read_mol(mol: Chem.rdchem.Mol, energy: float) -> Dict[str, np.ndarray]: + """Read molecule (Chem.rdchem.Mol) and energy (float) and return dict with conformers and energies + + Parameters + ---------- + mol: Chem.rdchem.Mol + RDKit molecule + energy: float + Energy of the molecule + + Returns + ------- + res: dict + Dictionary containing the following keys: + - name: np.ndarray of shape (N,) containing the smiles of the molecule + - atomic_inputs: flatten np.ndarray of shape (M, 5) containing the atomic numbers, charges and positions + - energies: np.ndarray of shape (1,) containing the energy of the conformer + - n_atoms: np.ndarray of shape (1) containing the number of atoms in the conformer + - subset: np.ndarray of shape (1) containing "molecule3d" + """ smiles = dm.to_smiles(mol, explicit_hs=False) # subset = dm.to_smiles(dm.to_scaffold_murcko(mol, make_generic=True), explicit_hs=False) x = get_atomic_number_and_charge(mol) @@ -29,7 +49,8 @@ def read_mol(mol, energy): return res -def _read_sdf(sdf_path, properties_path): +def _read_sdf(sdf_path: str, properties_path: str) -> List[Dict[str, np.ndarray]]: + """Reads the sdf path and properties file.""" properties = pd.read_csv(properties_path, dtype={"cid": str}) properties.drop_duplicates(subset="cid", inplace=True, keep="first") xys = properties[["cid", "scf energy"]] @@ -45,6 +66,22 @@ def _read_sdf(sdf_path, properties_path): class Molecule3D(BaseDataset): + """ + Molecule3D dataset consists of 3,899,647 molecules with ground state geometries and energies + calculated at B3LYP/6-31G* level of theory. The molecules are extracted from the + PubChem database and cleaned by removing invalid molecule files. + + Usage: + ```python + from openqdc.datasets import Molecule3D + dataset = Molecule3D() + ``` + + References: + - https://arxiv.org/abs/2110.01717 + - https://github.com/divelab/MoleculeX + """ + __name__ = "molecule3d" __energy_methods__ = ["b3lyp_6-31g*"] diff --git a/src/openqdc/datasets/orbnet_denali.py b/src/openqdc/datasets/orbnet_denali.py index 452cce1..2d8b093 100644 --- a/src/openqdc/datasets/orbnet_denali.py +++ b/src/openqdc/datasets/orbnet_denali.py @@ -1,4 +1,5 @@ from os.path import join as p_join +from typing import Dict, List import datamol as dm import numpy as np @@ -9,7 +10,7 @@ from openqdc.utils.molecule import atom_table -def read_mol(mol_id, conf_dict, base_path, energy_target_names): +def read_mol(mol_id, conf_dict, base_path, energy_target_names: List[str]) -> Dict[str, np.ndarray]: res = [] for conf_id, conf_label in conf_dict.items(): try: @@ -34,6 +35,23 @@ def read_mol(mol_id, conf_dict, base_path, energy_target_names): class OrbnetDenali(BaseDataset): + """ + Orbnet Denali is a collection of 2.3 million conformers from 212,905 unique molecules. It performs + DFT (ωB97X-D3/def2-TZVP) calculations on molecules and geometries consisting of organic molecules + and chemistries, with protonation and tautomeric states, non-covalent interactions, common salts, + and counterions, spanning the most common elements in bio and organic chemistry. + + Usage: + ```python + from openqdc.datasets import OrbnetDenali + dataset = OrbnetDenali() + ``` + + References: + - https://arxiv.org/pdf/2107.00299.pdf + - https://figshare.com/articles/dataset/OrbNet_Denali_Training_Data/14883867 + """ + __name__ = "orbnet_denali" __energy_methods__ = ["wb97x-d3_tz", "gfn1_xtb"] @@ -53,13 +71,6 @@ def read_raw_entries(self): for mol_id, group in df.groupby("mol_id") } - # print(df.head()) - # tmp = df.to_dict('index') - # for i, k in enumerate(tmp): - # print(k, tmp[k]) - # if i > 10: - # break - # exit() fn = lambda x: read_mol(x[0], x[1], self.root, self.energy_target_names) res = dm.parallelized(fn, list(labels.items()), scheduler="threads", n_jobs=-1, progress=True) samples = sum(res, []) diff --git a/src/openqdc/datasets/qmugs.py b/src/openqdc/datasets/qmugs.py index b528f42..d15d83b 100644 --- a/src/openqdc/datasets/qmugs.py +++ b/src/openqdc/datasets/qmugs.py @@ -36,6 +36,22 @@ def read_mol(mol_dir): class QMugs(BaseDataset): + """ + The QMugs dataset contains 2 million conformers for 665k biologically and pharmacologically relevant molecules + extracted from the ChEMBL database. The atomic and molecular properties are calculated using both, + semi-empirical methods (GFN2-xTB) and DFT method (ωB97X-D/def2-SVP). + + Usage: + ```python + from openqdc.datasets import QMugs + dataset = QMugs() + ``` + + References: + - https://www.nature.com/articles/s41597-022-01390-7#ethics + - https://www.research-collection.ethz.ch/handle/20.500.11850/482129 + """ + __name__ = "qmugs" __energy_methods__ = ["gfn2_xtb", "b3lyp/6-31g*"]