Change format #30

Merged · 4 commits · Dec 1, 2023
2 changes: 1 addition & 1 deletion src/openqdc/__init__.py
@@ -9,7 +9,7 @@

_lazy_imports_obj = {}

_lazy_imports_mod = {"datasets": "openqdc.datamodule", "utils": "openqdc.utils"}
_lazy_imports_mod = {"datasets": "openqdc.datasets", "utils": "openqdc.utils"}


def __getattr__(name):
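The fixed mapping now points the lazy loader at the real `openqdc.datasets` module. For context, here is a minimal sketch of the PEP 562 `__getattr__` pattern this dict feeds; a hypothetical standalone example, not the exact openqdc implementation:

```python
import importlib

_lazy_imports_mod = {"datasets": "openqdc.datasets", "utils": "openqdc.utils"}

def __getattr__(name):
    # Import the mapped submodule on first attribute access, e.g. `openqdc.datasets`.
    if name in _lazy_imports_mod:
        return importlib.import_module(_lazy_imports_mod[name])
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```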
18 changes: 18 additions & 0 deletions src/openqdc/datasets/ani.py
@@ -39,6 +39,12 @@ class ANI1(BaseDataset):
def root(self):
return p_join(get_local_cache(), "ani")

def __smiles_converter__(self, x):
"""util function to convert string to smiles: useful if the smiles is
encoded in a different format than its display format
"""
return "-".join(x.decode("ascii").split("-")[:-1])

@property
def preprocess_path(self):
path = p_join(self.root, "preprocessed", self.__name__)
@@ -89,6 +95,12 @@ class ANI1CCX(ANI1):
__force_methods__ = []
force_target_names = []

def __smiles_converter__(self, x):
"""util function to convert string to smiles: useful if the smiles is
encoded in a different format than its display format
"""
return x


class ANI1X(ANI1):
"""
@@ -145,3 +157,9 @@ class ANI1X(ANI1):

def convert_forces(self, x):
return super().convert_forces(x) * 0.529177249 # 0.529177249 = Bohr radius in Angstrom; corrects the dataset's unit error

def __smiles_converter__(self, x):
"""util function to convert string to smiles: useful if the smiles is
encoded in a different format than its display format
"""
return x
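A quick illustration of the ANI1 converter above, assuming raw names carry a trailing `-<conformer index>` suffix (the input value here is hypothetical):

```python
x = b"C1CC1-42"  # assumed raw name format: "<smiles>-<conformer index>"
print("-".join(x.decode("ascii").split("-")[:-1]))  # -> C1CC1
```

ANI1CCX and ANI1X override the converter to the identity because their stored names need no rewriting.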
77 changes: 40 additions & 37 deletions src/openqdc/datasets/base.py
@@ -1,4 +1,5 @@
import os
import pickle as pkl
from os.path import join as p_join
from typing import Dict, List, Optional, Union

Expand All @@ -24,7 +25,7 @@
push_remote,
set_cache_dir,
)
from openqdc.utils.molecule import atom_table
from openqdc.utils.molecule import atom_table, z_to_formula
from openqdc.utils.package_utils import requires_package
from openqdc.utils.units import get_conversion

Expand All @@ -43,7 +44,7 @@ def extract_entry(

res = dict(
name=np.array([df["name"][i]]),
subset=np.array([subset]),
subset=np.array([subset if subset is not None else z_to_formula(x)]),
energies=energies.reshape((1, -1)).astype(np.float32),
atomic_inputs=np.concatenate((xs, positions), axis=-1, dtype=np.float32),
n_atoms=np.array([x.shape[0]], dtype=np.int32),
@@ -64,8 +65,8 @@ def read_qc_archive_h5(
) -> List[Dict[str, np.ndarray]]:
data = load_hdf5_file(raw_path)
data_t = {k2: data[k1][k2][:] for k1 in data.keys() for k2 in data[k1].keys()}
n = len(data_t["molecule_id"])

n = len(data_t["molecule_id"])
samples = [extract_entry(data_t, i, subset, energy_target_names, force_target_names) for i in tqdm(range(n))]
return samples

Expand Down Expand Up @@ -96,9 +97,6 @@ def __init__(
self._set_units(energy_unit, distance_unit)
if not self.is_preprocessed():
logger.info("This dataset not available. Please open an issue on Github for the team to look into it.")
# entries = self.read_raw_entries()
# res = self.collate_list(entries)
# self.save_preprocess(res)
else:
self.read_preprocess(overwrite_local_cache=overwrite_local_cache)
self._set_isolated_atom_energies()
@@ -107,12 +105,12 @@
def numbers(self):
if hasattr(self, "_numbers"):
return self._numbers
self._numbers = np.array(list(set(self.data["atomic_inputs"][..., 0])), dtype=np.int32)
self._numbers = np.unique(self.data["atomic_inputs"][..., 0]).astype(np.int32)
return self._numbers

@property
def chemical_species(self):
return [chemical_symbols[z] for z in self.numbers]
return np.array(chemical_symbols)[self.numbers]

@property
def energy_unit(self):
@@ -211,10 +209,11 @@ def collate_list(self, list_entries):
# concatenate entries
res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) for key in list_entries[0]}

csum = np.cumsum(res.pop("n_atoms"))
csum = np.cumsum(res.get("n_atoms"))
x = np.zeros((csum.shape[0], 2), dtype=np.int32)
x[1:, 0], x[:, 1] = csum[:-1], csum
res["position_idx_range"] = x

return res

def save_preprocess(self, data_dict):
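The `csum` trick in `collate_list` above turns per-molecule atom counts into `[start, end)` row ranges over the concatenated `atomic_inputs` array. A small worked example with assumed counts:

```python
import numpy as np

n_atoms = np.array([3, 2, 4])        # atoms per molecule (assumed)
csum = np.cumsum(n_atoms)            # [3, 5, 9]
x = np.zeros((csum.shape[0], 2), dtype=np.int32)
x[1:, 0], x[:, 1] = csum[:-1], csum  # starts [0, 3, 5], ends [3, 5, 9]
print(x)  # [[0 3] [3 5] [5 9]]: row i slices molecule i's atoms
```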
@@ -228,12 +227,13 @@ def save_preprocess(self, data_dict):
push_remote(local_path, overwrite=True)

# save smiles and subset
local_path = p_join(self.preprocess_path, "props.pkl")
for key in ["name", "subset"]:
local_path = p_join(self.preprocess_path, f"{key}.npz")
uniques, inv_indices = np.unique(data_dict[key], return_inverse=True)
with open(local_path, "wb") as f:
np.savez_compressed(f, uniques=uniques, inv_indices=inv_indices)
push_remote(local_path)
data_dict[key] = np.unique(data_dict[key], return_inverse=True)

with open(local_path, "wb") as f:
pkl.dump(data_dict, f)
push_remote(local_path, overwrite=True)

def read_preprocess(self, overwrite_local_cache=False):
logger.info("Reading preprocessed data")
@@ -247,32 +247,29 @@
for key in self.data_keys:
filename = p_join(self.preprocess_path, f"{key}.mmap")
pull_locally(filename, overwrite=overwrite_local_cache)
self.data[key] = np.memmap(
filename,
mode="r",
dtype=self.data_types[key],
).reshape(self.data_shapes[key])
self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(self.data_shapes[key])

filename = p_join(self.preprocess_path, "props.pkl")
pull_locally(filename, overwrite=overwrite_local_cache)
with open(filename, "rb") as f:
tmp = pkl.load(f)
for key in ["name", "subset", "n_atoms"]:
x = tmp.pop(key)
if len(x) == 2:
self.data[key] = x[0][x[1]]
else:
self.data[key] = x

for key in self.data:
print(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")

for key in ["name", "subset"]:
filename = p_join(self.preprocess_path, f"{key}.npz")
pull_locally(filename)
self.data[key] = dict()
with open(filename, "rb") as f:
tmp = np.load(f)
for k in tmp:
self.data[key][k] = tmp[k]
print(f"Loaded {key}_{k} with shape {self.data[key][k].shape}, dtype {self.data[key][k].dtype}")

def is_preprocessed(self):
predicats = [copy_exists(p_join(self.preprocess_path, f"{key}.mmap")) for key in self.data_keys]
predicats += [copy_exists(p_join(self.preprocess_path, f"{x}.npz")) for x in ["name", "subset"]]
predicats += [copy_exists(p_join(self.preprocess_path, "props.pkl"))]
return all(predicats)

def preprocess(self):
if not self.is_preprocessed():
def preprocess(self, overwrite=False):
if overwrite or not self.is_preprocessed():
entries = self.read_raw_entries()
res = self.collate_list(entries)
self.save_preprocess(res)
@@ -305,7 +302,7 @@ def get_ase_atoms(self, idx: int, ext=True):

@requires_package("dscribe")
@requires_package("datamol")
def chemical_space(
def soap_descriptors(
self,
n_samples: Optional[Union[List[int], int]] = None,
return_idxs: bool = True,
@@ -350,7 +347,7 @@ def chemical_space(
idxs = list(range(len(self)))
elif isinstance(n_samples, int):
idxs = np.random.choice(len(self), size=n_samples, replace=False)
elif isinstance(n_samples, list):
else: # list, set, np.ndarray
idxs = n_samples
datum = {}
r_cut = soap_kwargs.pop("r_cut", 5.0)
@@ -383,7 +380,7 @@ def wrapper(idx):
entry = self.get_ase_atoms(idx, ext=False)
return soap.create(entry, centers=entry.positions)

descr = dm.parallelized(wrapper, idxs, progress=progress, scheduler="threads")
descr = dm.parallelized(wrapper, idxs, progress=progress, scheduler="threads", n_jobs=-1)
datum["soap"] = np.vstack(descr)
if return_idxs:
datum["idxs"] = idxs
@@ -392,6 +389,12 @@ def wrapper(idx):
def __len__(self):
return self.data["energies"].shape[0]

def __smiles_converter__(self, x):
"""util function to convert string to smiles: useful if the smiles is
encoded in a different format than its display format
"""
return x

def __getitem__(self, idx: int):
shift = IsolatedAtomEnergyFactory.max_charge
p_start, p_end = self.data["position_idx_range"][idx]
@@ -402,8 +405,8 @@ def __getitem__(self, idx: int):
self.convert_distance(np.array(input[:, -3:], dtype=np.float32)),
self.convert_energy(np.array(self.data["energies"][idx], dtype=np.float32)),
)
name = self.data["name"]["uniques"][self.data["name"]["inv_indices"][idx]]
subset = self.data["subset"]["uniques"][self.data["subset"]["inv_indices"][idx]]
name = self.__smiles_converter__(self.data["name"][idx])
subset = self.data["subset"][idx]

if "forces" in self.data:
forces = self.convert_forces(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))
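The new `props.pkl` format stores `name` and `subset` as the `(uniques, inv_indices)` pair returned by `np.unique(..., return_inverse=True)`, and `read_preprocess` rebuilds the full column as `uniques[inv_indices]`. A quick sanity check of that round trip, using made-up names:

```python
import numpy as np

names = np.array([b"CCO-0", b"CCO-1", b"CC-0"])
uniques, inv_indices = np.unique(names, return_inverse=True)
assert (uniques[inv_indices] == names).all()  # lossless reconstruction
```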
14 changes: 10 additions & 4 deletions src/openqdc/datasets/comp6.py
@@ -35,8 +35,8 @@ class COMP6(BaseDataset):
"pbe-d3bj/def2-tzvp",
"pbe/def2-tzvp",
"svwn/def2-tzvp",
"wb97m-d3bj/def2-tzvp",
"wb97m/def2-tzvp",
# "wb97m-d3bj/def2-tzvp",
# "wb97m/def2-tzvp",
]

energy_target_names = [
@@ -47,8 +47,8 @@
"PBE-D3M(BJ):def2-tzvp",
"PBE:def2-tzvp",
"SVWN:def2-tzvp",
"WB97M-D3(BJ):def2-tzvp",
"WB97M:def2-tzvp",
# "WB97M-D3(BJ):def2-tzvp",
# "WB97M:def2-tzvp",
]

__force_methods__ = [
@@ -59,6 +59,12 @@
"Gradient",
]

def __smiles_converter__(self, x):
"""util function to convert string to smiles: useful if the smiles is
encoded in a different format than its display format
"""
return "-".join(x.decode("ascii").split("_")[:-1])

def read_raw_entries(self):
samples = []
for subset in ["ani_md", "drugbank", "gdb7_9", "gdb10_13", "s66x8", "tripeptides"]:
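Note that the COMP6 converter (and the identical ISO17 one below) splits on underscores but rejoins with hyphens, so the separator changes as part of the conversion. A demo on an assumed sample name:

```python
x = b"s66x8_dimer_42"  # hypothetical raw name: "<subset>_<molecule>_<index>"
print("-".join(x.decode("ascii").split("_")[:-1]))  # -> s66x8-dimer
```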
6 changes: 6 additions & 0 deletions src/openqdc/datasets/iso_17.py
@@ -43,6 +43,12 @@ class ISO17(BaseDataset):
__distance_unit__ = "bohr" # bohr
__forces_unit__ = "ev/bohr"

def __smiles_converter__(self, x):
"""util function to convert string to smiles: useful if the smiles is
encoded in a different format than its display format
"""
return "-".join(x.decode("ascii").split("_")[:-1])

def read_raw_entries(self):
raw_path = p_join(self.root, "iso_17.h5")
samples = read_qc_archive_h5(raw_path, "iso_17", self.energy_target_names, self.force_target_names)
26 changes: 19 additions & 7 deletions src/openqdc/datasets/nabladft.py
@@ -4,38 +4,46 @@

import datamol as dm
import numpy as np
from tqdm import tqdm
import pandas as pd

from openqdc.datasets.base import BaseDataset
from openqdc.utils.molecule import z_to_formula
from openqdc.utils.package_utils import requires_package


def to_mol(entry) -> Dict[str, np.ndarray]:
def to_mol(entry, metadata) -> Dict[str, np.ndarray]:
Z, R, E, F = entry[:4]
C = np.zeros_like(Z)
E[0] = metadata["DFT TOTAL ENERGY"]

res = dict(
atomic_inputs=np.concatenate((Z[:, None], C[:, None], R), axis=-1).astype(np.float32),
name=np.array([""]),
name=np.array([metadata["SMILES"]]),
energies=E[:, None].astype(np.float32),
forces=F[:, :, None].astype(np.float32),
n_atoms=np.array([Z.shape[0]], dtype=np.int32),
subset=np.array(["nabla"]),
subset=np.array([z_to_formula(Z)]),
)

return res


@requires_package("nablaDFT")
def read_chunk_from_db(raw_path, start_idx, stop_idx, step_size=1000):
def read_chunk_from_db(raw_path, start_idx, stop_idx, labels, step_size=1000):
from nablaDFT.dataset import HamiltonianDatabase

print(f"Loading from {start_idx} to {stop_idx}")
db = HamiltonianDatabase(raw_path)
idxs = list(np.arange(start_idx, stop_idx))
n, s = len(idxs), step_size

samples = [to_mol(entry) for i in tqdm(range(0, n, s)) for entry in db[idxs[i : i + s]]]
cursor = db._get_connection().cursor()
data_idxs = cursor.execute("""SELECT * FROM dataset_ids WHERE id IN (""" + str(idxs)[1:-1] + ")").fetchall()
c_idxs = [tuple(x[1:]) for x in data_idxs]

samples = [
to_mol(entry, labels[c_idxs[i + j]]) for i in range(0, n, s) for j, entry in enumerate(db[idxs[i : i + s]])
]
return samples
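The id lookup above assembles the SQL `IN (...)` list by string formatting. An equivalent parameterized sketch with plain `sqlite3`, assuming the `dataset_ids` table lays out `(id, MOSES id, CONFORMER id)` as the `tuple(x[1:])` unpacking implies:

```python
import sqlite3

idxs = list(range(10))  # a chunk of database row ids (assumed)
conn = sqlite3.connect("dataset_full.db")  # assumed path
placeholders = ",".join("?" * len(idxs))
rows = conn.execute(
    f"SELECT * FROM dataset_ids WHERE id IN ({placeholders})", idxs
).fetchall()
c_idxs = [tuple(r[1:]) for r in rows]  # (MOSES id, CONFORMER id) keys into `labels`
```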


@@ -68,12 +76,16 @@ class NablaDFT(BaseDataset):
def read_raw_entries(self):
from nablaDFT.dataset import HamiltonianDatabase

label_path = p_join(self.root, "summary.csv")
df = pd.read_csv(label_path, usecols=["MOSES id", "CONFORMER id", "SMILES", "DFT TOTAL ENERGY"])
labels = df.set_index(keys=["MOSES id", "CONFORMER id"]).to_dict("index")

raw_path = p_join(self.root, "dataset_full.db")
train = HamiltonianDatabase(raw_path)
n, c = len(train), 20
step_size = int(np.ceil(n / os.cpu_count()))

fn = lambda i: read_chunk_from_db(raw_path, i * step_size, min((i + 1) * step_size, n))
fn = lambda i: read_chunk_from_db(raw_path, i * step_size, min((i + 1) * step_size, n), labels=labels)
samples = dm.parallelized(
fn, list(range(c)), n_jobs=c, progress=False, scheduler="threads"
) # don't use more than 1 job
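The `summary.csv` labels are keyed by `(MOSES id, CONFORMER id)` through `set_index(...).to_dict("index")`. A tiny illustration with clearly made-up rows:

```python
import pandas as pd

df = pd.DataFrame({
    "MOSES id": [0, 0, 1],
    "CONFORMER id": [0, 1, 0],
    "SMILES": ["CCO", "CCO", "CC"],
    "DFT TOTAL ENERGY": [-154.1, -154.0, -79.8],  # made-up values
})
labels = df.set_index(["MOSES id", "CONFORMER id"]).to_dict("index")
print(labels[(0, 1)])  # {'SMILES': 'CCO', 'DFT TOTAL ENERGY': -154.0}
```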