Skip to content

Commit

Permalink
Updated docs for geom, molecule3d, orbnet_denali, qmugs
Browse files Browse the repository at this point in the history
  • Loading branch information
shenoynikhil committed Oct 5, 2023
1 parent 7045e6e commit 99a3506
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 25 deletions.
22 changes: 12 additions & 10 deletions src/openqdc/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from os.path import join as p_join
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import torch
from loguru import logger
from sklearn.utils import Bunch
Expand All @@ -18,7 +20,13 @@
from openqdc.utils.molecule import atom_table


def extract_entry(df, i, subset, energy_target_names, force_target_names=None):
def extract_entry(
df: pd.DataFrame,
i: int,
subset: str,
energy_target_names: List[str],
force_target_names: Optional[List[str]] = None,
) -> Dict[str, np.ndarray]:
x = np.array([atom_table.GetAtomicNumber(s) for s in df["symbols"][i]])
xs = np.stack((x, np.zeros_like(x)), axis=-1)
positions = df["geometry"][i].reshape((-1, 3))
Expand All @@ -42,18 +50,12 @@ def extract_entry(df, i, subset, energy_target_names, force_target_names=None):
return res


def read_qc_archive_h5(raw_path, subset, energy_target_names, force_target_names):
def read_qc_archive_h5(
raw_path: str, subset: str, energy_target_names: List[str], force_target_names: List[str]
) -> List[Dict[str, np.ndarray]]:
data = load_hdf5_file(raw_path)
data_t = {k2: data[k1][k2][:] for k1 in data.keys() for k2 in data[k1].keys()}
n = len(data_t["molecule_id"])
# print(f"Reading {n} entries from {raw_path}")
# for k in data_t:
# print(f"Loaded {k} with shape {data_t[k].shape}, dtype {data_t[k].dtype}")
# if "Energy" in k:
# print(np.isnan(data_t[k]).mean(), f"{data_t[k][0]}")

# print('\n'*3)
# exit()

samples = [extract_entry(data_t, i, subset, energy_target_names, force_target_names) for i in tqdm(range(n))]
return samples
Expand Down
30 changes: 25 additions & 5 deletions src/openqdc/datasets/geom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os.path import join as p_join
from typing import Dict

import datamol as dm
import numpy as np
Expand All @@ -9,7 +10,7 @@
from openqdc.utils.molecule import get_atomic_number_and_charge


def read_mol(mol_id, mol_dict, base_path, partition):
def read_mol(mol_id: str, mol_dict, base_path: str, partition: str) -> Dict[str, np.ndarray]:
"""Read molecule from pickle file and return dict with conformers and energies
Parameters
Expand All @@ -20,15 +21,18 @@ def read_mol(mol_id, mol_dict, base_path, partition):
Dictionary containing the pickle_path and smiles of the molecule
base_path: str
Path to the folder containing the pickle files
partition: str
Name of the dataset partition, one of ['qm9', 'drugs']
Returns
-------
res: dict
Dictionary containing the following keys:
- atomic_inputs: flatten np.ndarray of shape (M, 4) containing the atomic numbers and positions
- smiles: np.ndarray of shape (N,) containing the smiles of the molecule
- energies: np.ndarray of shape (N,1) containing the energies of the conformers
- n_atoms: np.ndarray of shape (N,) containing the number of atoms in each conformer
- atomic_inputs: flatten np.ndarray of shape (M, 5) containing the atomic numbers, charges and positions
- smiles: np.ndarray of shape (N,) containing the smiles of the molecule
- energies: np.ndarray of shape (N,1) containing the energies of the conformers
- n_atoms: np.ndarray of shape (N,) containing the number of atoms in each conformer
- subset: np.ndarray of shape (N,) containing the name of the dataset partition
"""

try:
Expand Down Expand Up @@ -56,6 +60,22 @@ def read_mol(mol_id, mol_dict, base_path, partition):


class GEOM(BaseDataset):
"""
The Geometric Ensemble Of Molecules (GEOM) dataset contains 37 million conformers for 133,000 molecules
from QM9, and 317,000 molecules with experimental data related to biophysics, physiology,
and physical chemistry. The dataset is generated using the GFN2-xTB semi-empirical method.
Usage:
```python
from openqdc.datasets import GEOM
dataset = GEOM()
```
References:
- https://www.nature.com/articles/s41597-022-01288-4
- https://github.com/learningmatter-mit/geom
"""

__name__ = "geom"
__energy_methods__ = ["gfn2_xtb"]

Expand Down
41 changes: 39 additions & 2 deletions src/openqdc/datasets/molecule3d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from glob import glob
from os.path import join as p_join
from typing import Dict, List

import datamol as dm
import numpy as np
Expand All @@ -12,7 +13,26 @@
from openqdc.utils.molecule import get_atomic_number_and_charge


def read_mol(mol, energy):
def read_mol(mol: Chem.rdchem.Mol, energy: float) -> Dict[str, np.ndarray]:
"""Read molecule (Chem.rdchem.Mol) and energy (float) and return dict with conformers and energies
Parameters
----------
mol: Chem.rdchem.Mol
RDKit molecule
energy: float
Energy of the molecule
Returns
-------
res: dict
Dictionary containing the following keys:
- name: np.ndarray of shape (N,) containing the smiles of the molecule
- atomic_inputs: flatten np.ndarray of shape (M, 5) containing the atomic numbers, charges and positions
- energies: np.ndarray of shape (1,) containing the energy of the conformer
- n_atoms: np.ndarray of shape (1) containing the number of atoms in the conformer
- subset: np.ndarray of shape (1) containing "molecule3d"
"""
smiles = dm.to_smiles(mol, explicit_hs=False)
# subset = dm.to_smiles(dm.to_scaffold_murcko(mol, make_generic=True), explicit_hs=False)
x = get_atomic_number_and_charge(mol)
Expand All @@ -29,7 +49,8 @@ def read_mol(mol, energy):
return res


def _read_sdf(sdf_path, properties_path):
def _read_sdf(sdf_path: str, properties_path: str) -> List[Dict[str, np.ndarray]]:
"""Reads the sdf path and properties file."""
properties = pd.read_csv(properties_path, dtype={"cid": str})
properties.drop_duplicates(subset="cid", inplace=True, keep="first")
xys = properties[["cid", "scf energy"]]
Expand All @@ -45,6 +66,22 @@ def _read_sdf(sdf_path, properties_path):


class Molecule3D(BaseDataset):
"""
Molecule3D dataset consists of 3,899,647 molecules with ground state geometries and energies
calculated at B3LYP/6-31G* level of theory. The molecules are extracted from the
PubChem database and cleaned by removing invalid molecule files.
Usage:
```python
from openqdc.datasets import Molecule3D
dataset = Molecule3D()
```
References:
- https://arxiv.org/abs/2110.01717
- https://github.com/divelab/MoleculeX
"""

__name__ = "molecule3d"
__energy_methods__ = ["b3lyp_6-31g*"]

Expand Down
27 changes: 19 additions & 8 deletions src/openqdc/datasets/orbnet_denali.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os.path import join as p_join
from typing import Dict, List

import datamol as dm
import numpy as np
Expand All @@ -9,7 +10,7 @@
from openqdc.utils.molecule import atom_table


def read_mol(mol_id, conf_dict, base_path, energy_target_names):
def read_mol(mol_id, conf_dict, base_path, energy_target_names: List[str]) -> Dict[str, np.ndarray]:
res = []
for conf_id, conf_label in conf_dict.items():
try:
Expand All @@ -34,6 +35,23 @@ def read_mol(mol_id, conf_dict, base_path, energy_target_names):


class OrbnetDenali(BaseDataset):
"""
Orbnet Denali is a collection of 2.3 million conformers from 212,905 unique molecules. It performs
DFT (ωB97X-D3/def2-TZVP) calculations on molecules and geometries consisting of organic molecules
and chemistries, with protonation and tautomeric states, non-covalent interactions, common salts,
and counterions, spanning the most common elements in bio and organic chemistry.
Usage:
```python
from openqdc.datasets import OrbnetDenali
dataset = OrbnetDenali()
```
References:
- https://arxiv.org/pdf/2107.00299.pdf
- https://figshare.com/articles/dataset/OrbNet_Denali_Training_Data/14883867
"""

__name__ = "orbnet_denali"
__energy_methods__ = ["wb97x-d3_tz", "gfn1_xtb"]

Expand All @@ -53,13 +71,6 @@ def read_raw_entries(self):
for mol_id, group in df.groupby("mol_id")
}

# print(df.head())
# tmp = df.to_dict('index')
# for i, k in enumerate(tmp):
# print(k, tmp[k])
# if i > 10:
# break
# exit()
fn = lambda x: read_mol(x[0], x[1], self.root, self.energy_target_names)
res = dm.parallelized(fn, list(labels.items()), scheduler="threads", n_jobs=-1, progress=True)
samples = sum(res, [])
Expand Down
16 changes: 16 additions & 0 deletions src/openqdc/datasets/qmugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ def read_mol(mol_dir):


class QMugs(BaseDataset):
"""
The QMugs dataset contains 2 million conformers for 665k biologically and pharmacologically relevant molecules
extracted from the ChEMBL database. The atomic and molecular properties are calculated using both,
semi-empirical methods (GFN2-xTB) and DFT method (ωB97X-D/def2-SVP).
Usage:
```python
from openqdc.datasets import QMugs
dataset = QMugs()
```
References:
- https://www.nature.com/articles/s41597-022-01390-7#ethics
- https://www.research-collection.ethz.ch/handle/20.500.11850/482129
"""

__name__ = "qmugs"
__energy_methods__ = ["gfn2_xtb", "b3lyp/6-31g*"]

Expand Down

0 comments on commit 99a3506

Please sign in to comment.