diff --git a/openqdc/__init__.py b/openqdc/__init__.py index 7e0a39b..7f29724 100644 --- a/openqdc/__init__.py +++ b/openqdc/__init__.py @@ -14,6 +14,7 @@ def get_project_root(): _lazy_imports_obj = { "__version__": "openqdc._version", "BaseDataset": "openqdc.datasets.base", + # POTENTIAL "ANI1": "openqdc.datasets.potential.ani", "ANI1CCX": "openqdc.datasets.potential.ani", "ANI1X": "openqdc.datasets.potential.ani", @@ -32,12 +33,23 @@ def get_project_root(): "SolvatedPeptides": "openqdc.datasets.potential.solvated_peptides", "WaterClusters": "openqdc.datasets.potential.waterclusters3_30", "TMQM": "openqdc.datasets.potential.tmqm", - "Dummy": "openqdc.datasets.potential.dummy", "PCQM_B3LYP": "openqdc.datasets.potential.pcqm", "PCQM_PM6": "openqdc.datasets.potential.pcqm", "RevMD17": "openqdc.datasets.potential.revmd17", "Transition1X": "openqdc.datasets.potential.transition1x", "MultixcQM9": "openqdc.datasets.potential.multixcqm9", + # INTERACTION + "DES5M": "openqdc.datasets.interaction.des", + "DES370K": "openqdc.datasets.interaction.des", + "DESS66": "openqdc.datasets.interaction.des", + "DESS66x8": "openqdc.datasets.interaction.des", + "L7": "openqdc.datasets.interaction.l7", + "X40": "openqdc.datasets.interaction.x40", + "Metcalf": "openqdc.datasets.interaction.metcalf", + "Splinter": "openqdc.datasets.interaction.splinter", + # DEBUG + "Dummy": "openqdc.datasets.potential.dummy", + # ALL "AVAILABLE_DATASETS": "openqdc.datasets", "AVAILABLE_POTENTIAL_DATASETS": "openqdc.datasets.potential", "AVAILABLE_INTERACTION_DATASETS": "openqdc.datasets.interaction", @@ -75,6 +87,13 @@ def __dir__(): from ._version import __version__ # noqa from .datasets import AVAILABLE_DATASETS # noqa from .datasets.base import BaseDataset # noqa + + # INTERACTION + from .datasets.interaction.des import DES5M, DES370K, DESS66, DESS66x8 # noqa + from .datasets.interaction.l7 import L7 # noqa + from .datasets.interaction.metcalf import Metcalf # noqa + from .datasets.interaction.splinter import Splinter # noqa + from .datasets.interaction.x40 import X40 # noqa from .datasets.potential.ani import ANI1, ANI1CCX, ANI1X # noqa from .datasets.potential.comp6 import COMP6 # noqa from .datasets.potential.dummy import Dummy # noqa diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index 289e5e2..fabdfc1 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -49,11 +49,15 @@ @requires_package("torch") def to_torch(x: np.ndarray): + if isinstance(x, torch.Tensor): + return x return torch.from_numpy(x) @requires_package("jax") def to_jax(x: np.ndarray): + if isinstance(x, jnp.ndarray): + return x return jnp.array(x) @@ -166,6 +170,7 @@ def _precompute_statistics(self, overwrite_local_cache: bool = False): PerAtomFormationEnergyStats, ) self.statistics.run_calculators() # run the calculators + self._compute_average_nb_atoms() @classmethod def no_init(cls): @@ -243,6 +248,14 @@ def data_keys(self): keys.remove("forces") return keys + @property + def pkl_data_keys(self): + return list(self.pkl_data_types.keys()) + + @property + def pkl_data_types(self): + return {"name": str, "subset": str, "n_atoms": np.int32} + @property def data_types(self): return { @@ -257,8 +270,8 @@ def data_shapes(self): return { "atomic_inputs": (-1, NB_ATOMIC_FEATURES), "position_idx_range": (-1, 2), - "energies": (-1, len(self.energy_target_names)), - "forces": (-1, 3, len(self.force_target_names)), + "energies": (-1, len(self.energy_methods)), + "forces": (-1, 3, len(self.force_methods)), } def _set_units(self, en, ds): @@ -332,8 +345,14 @@ def save_preprocess(self, data_dict): # save smiles and subset local_path = p_join(self.preprocess_path, "props.pkl") - for key in ["name", "subset"]: - data_dict[key] = np.unique(data_dict[key], return_inverse=True) + + # assert that (required) pkl keys are present in data_dict + assert all([key in data_dict.keys() for key in self.pkl_data_keys]) + + # store unique and inverse indices for str-based pkl keys + for key in self.pkl_data_keys: + if self.pkl_data_types[key] == str: + data_dict[key] = np.unique(data_dict[key], return_inverse=True) with open(local_path, "wb") as f: pkl.dump(data_dict, f) @@ -369,7 +388,10 @@ def read_preprocess(self, overwrite_local_cache=False): pull_locally(filename, overwrite=overwrite_local_cache) with open(filename, "rb") as f: tmp = pkl.load(f) - for key in ["name", "subset", "n_atoms"]: + all_pkl_keys = set(tmp.keys()) - set(self.data_keys) + # assert required pkl_keys are present in all_pkl_keys + assert all([key in all_pkl_keys for key in self.pkl_data_keys]) + for key in all_pkl_keys: x = tmp.pop(key) if len(x) == 2: self.data[key] = x[0][x[1]] diff --git a/openqdc/datasets/interaction/L7.py b/openqdc/datasets/interaction/L7.py deleted file mode 100644 index ba3a15f..0000000 --- a/openqdc/datasets/interaction/L7.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -from typing import Dict, List - -import numpy as np -import yaml -from loguru import logger - -from openqdc.datasets.interaction.base import BaseInteractionDataset -from openqdc.methods import InteractionMethod, InterEnergyType -from openqdc.utils.constants import ATOM_TABLE - - -class DataItemYAMLObj: - def __init__(self, name, shortname, geometry, reference_value, setup, group, tags): - self.name = name - self.shortname = shortname - self.geometry = geometry - self.reference_value = reference_value - self.setup = setup - self.group = group - self.tags = tags - - -class DataSetYAMLObj: - def __init__(self, name, references, text, method_energy, groups_by, groups, global_setup, method_geometry=None): - self.name = name - self.references = references - self.text = text - self.method_energy = method_energy - self.method_geometry = method_geometry - self.groups_by = groups_by - self.groups = groups - self.global_setup = global_setup - - -def data_item_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): - return DataItemYAMLObj(**loader.construct_mapping(node)) - - -def dataset_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): - return DataSetYAMLObj(**loader.construct_mapping(node)) - - -def get_loader(): - """Add constructors to PyYAML loader.""" - loader = yaml.SafeLoader - loader.add_constructor("!ruby/object:ProtocolDataset::DataSetItem", data_item_constructor) - loader.add_constructor("!ruby/object:ProtocolDataset::DataSetDescription", dataset_constructor) - return loader - - -class L7(BaseInteractionDataset): - """ - The L7 interaction energy dataset as described in: - - Accuracy of Quantum Chemical Methods for Large Noncovalent Complexes - Robert Sedlak, Tomasz Janowski, Michal Pitoňák, Jan Řezáč, Peter Pulay, and Pavel Hobza - Journal of Chemical Theory and Computation 2013 9 (8), 3364-3374 - DOI: 10.1021/ct400036b - - Data was downloaded and extracted from: - http://cuby4.molecular.cz/dataset_l7.html - """ - - __name__ = "L7" - __energy_unit__ = "kcal/mol" - __distance_unit__ = "ang" - __forces_unit__ = "kcal/mol/ang" - __energy_methods__ = [ - InteractionMethod.QCISDT_CBS, # "QCISD(T)/CBS", - InteractionMethod.DLPNO_CCSDT, # "DLPNO-CCSD(T)", - InteractionMethod.MP2_CBS, # "MP2/CBS", - InteractionMethod.MP2C_CBS, # "MP2C/CBS", - InteractionMethod.FIXED, # "fixed", TODO: we should remove this level of theory because unless we have a pro - InteractionMethod.DLPNO_CCSDT0, # "DLPNO-CCSD(T0)", - InteractionMethod.LNO_CCSDT, # "LNO-CCSD(T)", - InteractionMethod.FN_DMC, # "FN-DMC", - ] - - __energy_type__ = [InterEnergyType.TOTAL] * 8 - - energy_target_names = [] - - def read_raw_entries(self) -> List[Dict]: - yaml_fpath = os.path.join(self.root, "l7.yaml") - logger.info(f"Reading L7 interaction data from {self.root}") - yaml_file = open(yaml_fpath, "r") - data = [] - data_dict = yaml.load(yaml_file, Loader=get_loader()) - charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"]) - charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"]) - - for idx, item in enumerate(data_dict["items"]): - energies = [] - name = np.array([item.shortname]) - fname = item.geometry.split(":")[1] - energies.append(item.reference_value) - xyz_file = open(os.path.join(self.root, f"{fname}.xyz"), "r") - lines = list(map(lambda x: x.strip().split(), xyz_file.readlines())) - lines.pop(1) - n_atoms = np.array([int(lines[0][0])], dtype=np.int32) - n_atoms_first = np.array([int(item.setup["molecule_a"]["selection"].split("-")[1])], dtype=np.int32) - subset = np.array([item.group]) - energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())] - energies = np.array([energies], dtype=np.float32) - pos = np.array(lines[1:])[:, 1:].astype(np.float32) - elems = np.array(lines[1:])[:, 0] - atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elems]), axis=1) - natoms0 = n_atoms_first[0] - natoms1 = n_atoms[0] - natoms0 - charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1) - atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32) - - item = dict( - energies=energies, - subset=subset, - n_atoms=n_atoms, - n_atoms_first=n_atoms_first, - atomic_inputs=atomic_inputs, - name=name, - ) - data.append(item) - return data diff --git a/openqdc/datasets/interaction/X40.py b/openqdc/datasets/interaction/X40.py deleted file mode 100644 index 98a9d67..0000000 --- a/openqdc/datasets/interaction/X40.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -from typing import Dict, List - -import numpy as np -import yaml -from loguru import logger - -from openqdc.datasets.interaction.base import BaseInteractionDataset -from openqdc.datasets.interaction.L7 import get_loader -from openqdc.methods import InteractionMethod, InterEnergyType -from openqdc.utils.constants import ATOM_TABLE - - -class X40(BaseInteractionDataset): - """ - X40 interaction dataset of 40 dimer pairs as - introduced in the following paper: - - Benchmark Calculations of Noncovalent Interactions of Halogenated Molecules - Jan Řezáč, Kevin E. Riley, and Pavel Hobza - Journal of Chemical Theory and Computation 2012 8 (11), 4285-4292 - DOI: 10.1021/ct300647k - - Dataset retrieved and processed from: - http://cuby4.molecular.cz/dataset_x40.html - """ - - __name__ = "X40" - __energy_unit__ = "hartree" - __distance_unit__ = "ang" - __forces_unit__ = "hartree/ang" - __energy_methods__ = [ - InteractionMethod.CCSD_T_CBS, # "CCSD(T)/CBS", - InteractionMethod.MP2_CBS, # "MP2/CBS", - InteractionMethod.DCCSDT_HA_DZ, # "dCCSD(T)/haDZ", - InteractionMethod.DCCSDT_HA_TZ, # "dCCSD(T)/haTZ", - InteractionMethod.MP2_5_CBS_ADZ, # "MP2.5/CBS(aDZ)", - ] - __energy_type__ = [ - InterEnergyType.TOTAL, - ] * 5 - - energy_target_names = [] - - def read_raw_entries(self) -> List[Dict]: - yaml_fpath = os.path.join(self.root, "x40.yaml") - logger.info(f"Reading X40 interaction data from {self.root}") - yaml_file = open(yaml_fpath, "r") - data = [] - data_dict = yaml.load(yaml_file, Loader=get_loader()) - charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"]) - charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"]) - - for idx, item in enumerate(data_dict["items"]): - energies = [] - name = np.array([item.shortname]) - energies.append(float(item.reference_value)) - xyz_file = open(os.path.join(self.root, f"{item.shortname}.xyz"), "r") - lines = list(map(lambda x: x.strip().split(), xyz_file.readlines())) - setup = lines.pop(1) - n_atoms = np.array([int(lines[0][0])], dtype=np.int32) - n_atoms_first = setup[0].split("-")[1] - n_atoms_first = np.array([int(n_atoms_first)], dtype=np.int32) - subset = np.array([item.group]) - energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())] - energies = np.array([energies], dtype=np.float32) - pos = np.array(lines[1:])[:, 1:].astype(np.float32) - elems = np.array(lines[1:])[:, 0] - atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elems]), axis=1) - natoms0 = n_atoms_first[0] - natoms1 = n_atoms[0] - natoms0 - charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1) - atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32) - - item = dict( - energies=energies, - subset=subset, - n_atoms=n_atoms, - n_atoms_first=n_atoms_first, - atomic_inputs=atomic_inputs, - name=name, - ) - data.append(item) - return data diff --git a/openqdc/datasets/interaction/__init__.py b/openqdc/datasets/interaction/__init__.py index fa3bebd..b038802 100644 --- a/openqdc/datasets/interaction/__init__.py +++ b/openqdc/datasets/interaction/__init__.py @@ -1,12 +1,9 @@ from .base import BaseInteractionDataset # noqa -from .des5m import DES5M -from .des370k import DES370K -from .dess66 import DESS66 -from .dess66x8 import DESS66x8 -from .L7 import L7 +from .des import DES5M, DES370K, DESS66, DESS66x8 +from .l7 import L7 from .metcalf import Metcalf from .splinter import Splinter -from .X40 import X40 +from .x40 import X40 AVAILABLE_INTERACTION_DATASETS = { "des5m": DES5M, diff --git a/openqdc/datasets/interaction/_utils.py b/openqdc/datasets/interaction/_utils.py new file mode 100644 index 0000000..0d2915b --- /dev/null +++ b/openqdc/datasets/interaction/_utils.py @@ -0,0 +1,138 @@ +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import partial +from os.path import join as p_join +from typing import Dict, List, Optional + +import numpy as np +import yaml +from loguru import logger + +from openqdc.datasets.interaction.base import BaseInteractionDataset +from openqdc.methods import InterEnergyType +from openqdc.utils.constants import ATOM_TABLE + + +@dataclass +class DataSet: + description: Dict + items: List[Dict] + alternative_reference: Dict + + +@dataclass +class DataItemYAMLObj: + name: str + shortname: str + geometry: str + reference_value: float + setup: Dict + group: str + tags: str + + +@dataclass +class DataSetDescription: + name: Dict + references: str + text: str + groups_by: str + groups: List[str] + global_setup: Dict + method_energy: str + method_geometry: Optional[str] = None + + +def get_loader(): + """Add constructors to PyYAML loader.""" + + def constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode, cls): + return cls(**loader.construct_mapping(node)) + + loader = yaml.SafeLoader + + loader.add_constructor("!ruby/object:ProtocolDataset::DataSet", partial(constructor, cls=DataSet)) + loader.add_constructor("!ruby/object:ProtocolDataset::DataSetItem", partial(constructor, cls=DataItemYAMLObj)) + loader.add_constructor( + "!ruby/object:ProtocolDataset::DataSetDescription", partial(constructor, cls=DataSetDescription) + ) + return loader + + +def read_xyz_file(xyz_path): + with open(xyz_path, "r") as xyz_file: # avoid not closing the file + lines = list(map(lambda x: x.strip().split(), xyz_file.readlines())) + lines.pop(1) + n_atoms = np.array([int(lines[0][0])], dtype=np.int32) + pos = np.array(lines[1:])[:, 1:].astype(np.float32) + elems = np.array(lines[1:])[:, 0] + atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elems]), axis=1) + return n_atoms, pos, atomic_nums + + +def convert_to_record(item): + return dict( + energies=item["energies"], + subset=np.array([item["subset"]]), + n_atoms=np.array([item["natoms0"] + item["natoms1"]], dtype=np.int32), + n_atoms_ptr=np.array([item["natoms0"]], dtype=np.int32), + atomic_inputs=item["atomic_inputs"], + name=item["name"], + ) + + +def build_item(item, charge0, charge1, idx, data_dict, root, filename): + datum = { + "energies": [], + } + datum["name"] = np.array([item.shortname]) + datum["energies"].append(item.reference_value) + datum["subset"] = np.array([item.group]) + datum["energies"] += [float(val[idx]) for val in list(data_dict.alternative_reference.values())] + datum["energies"] = np.array([datum["energies"]], dtype=np.float32) + n_atoms, pos, atomic_nums = read_xyz_file(p_join(root, f"{filename}.xyz")) + datum["n_atoms"] = n_atoms + datum["pos"] = pos + datum["atomic_nums"] = atomic_nums + datum["n_atoms_ptr"] = np.array([int(item.setup["molecule_a"]["selection"].split("-")[1])], dtype=np.int32) + datum["natoms0"] = datum["n_atoms_ptr"][0] + datum["natoms1"] = datum["n_atoms"][0] - datum["natoms0"] + datum["charges"] = np.expand_dims(np.array([charge0] * datum["natoms0"] + [charge1] * datum["natoms1"]), axis=1) + datum["atomic_inputs"] = np.concatenate( + (datum["atomic_nums"], datum["charges"], datum["pos"]), axis=-1, dtype=np.float32 + ) + return datum + + +class YamlDataset(BaseInteractionDataset, ABC): + __name__ = "l7" + __energy_unit__ = "kcal/mol" + __distance_unit__ = "ang" + __forces_unit__ = "kcal/mol/ang" + energy_target_names = [] + __energy_methods__ = [] + __energy_type__ = [InterEnergyType.TOTAL] * len(__energy_methods__) + + @property + def yaml_path(self): + return os.path.join(self.root, self.__name__ + ".yaml") + + def read_raw_entries(self) -> List[Dict]: + yaml_fpath = self.yaml_path + logger.info(f"Reading {self.__name__} interaction data from {self.root}") + with open(yaml_fpath, "r") as yaml_file: + data_dict = yaml.load(yaml_file, Loader=get_loader()) + data = [] + charge0 = int(data_dict.description.global_setup["molecule_a"]["charge"]) + charge1 = int(data_dict.description.global_setup["molecule_b"]["charge"]) + + for idx, item in enumerate(data_dict.items): + tmp_item = build_item(item, charge0, charge1, idx, data_dict, self.root, self._process_name(item)) + item = convert_to_record(tmp_item) + data.append(item) + return data + + @abstractmethod + def _process_name(self, item): + raise NotImplementedError diff --git a/openqdc/datasets/interaction/base.py b/openqdc/datasets/interaction/base.py index 23527f8..2ce5481 100644 --- a/openqdc/datasets/interaction/base.py +++ b/openqdc/datasets/interaction/base.py @@ -1,52 +1,26 @@ import os -import pickle as pkl from os.path import join as p_join -from typing import Dict, List, Optional +from typing import Optional import numpy as np from ase.io.extxyz import write_extxyz -from loguru import logger from sklearn.utils import Bunch from openqdc.datasets.base import BaseDataset -from openqdc.utils.constants import MAX_CHARGE, NB_ATOMIC_FEATURES -from openqdc.utils.io import pull_locally, push_remote, to_atoms +from openqdc.utils.constants import MAX_CHARGE +from openqdc.utils.io import to_atoms class BaseInteractionDataset(BaseDataset): __energy_type__ = [] - def collate_list(self, list_entries: List[Dict]): - # concatenate entries - res = { - key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) - for key in list_entries[0] - if not isinstance(list_entries[0][key], dict) - } - - csum = np.cumsum(res.get("n_atoms")) - x = np.zeros((csum.shape[0], 2), dtype=np.int32) - x[1:, 0], x[:, 1] = csum[:-1], csum - res["position_idx_range"] = x - - return res - - @property - def data_shapes(self): - return { - "atomic_inputs": (-1, NB_ATOMIC_FEATURES), - "position_idx_range": (-1, 2), - "energies": (-1, len(self.__energy_methods__)), - "forces": (-1, 3, len(self.force_target_names)), - } - @property - def data_types(self): + def pkl_data_types(self): return { - "atomic_inputs": np.float32, - "position_idx_range": np.int32, - "energies": np.float32, - "forces": np.float32, + "name": str, + "subset": str, + "n_atoms": np.int32, + "n_atoms_ptr": np.int32, } def __getitem__(self, idx: int): @@ -61,7 +35,7 @@ def __getitem__(self, idx: int): ) name = self.__smiles_converter__(self.data["name"][idx]) subset = self.data["subset"][idx] - n_atoms_first = self.data["n_atoms_first"][idx] + n_atoms_ptr = self.data["n_atoms_ptr"][idx] forces = None if "forces" in self.data: @@ -78,7 +52,7 @@ def __getitem__(self, idx: int): name=name, subset=subset, forces=forces, - n_atoms_first=n_atoms_first, + n_atoms_ptr=n_atoms_ptr, ) if self.transform is not None: @@ -86,60 +60,10 @@ def __getitem__(self, idx: int): return bunch - def save_preprocess(self, data_dict): - # save memmaps - logger.info("Preprocessing data and saving it to cache.") - for key in self.data_keys: - local_path = p_join(self.preprocess_path, f"{key}.mmap") - out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape) - out[:] = data_dict.pop(key)[:] - out.flush() - push_remote(local_path, overwrite=True) - - # save all other keys in props.pkl - local_path = p_join(self.preprocess_path, "props.pkl") - for key in data_dict: - if key not in self.data_keys: - x = data_dict[key] - x[x == None] = -1 # noqa - data_dict[key] = np.unique(x, return_inverse=True) - - with open(local_path, "wb") as f: - pkl.dump(data_dict, f) - push_remote(local_path, overwrite=True) - - def read_preprocess(self, overwrite_local_cache=False): - logger.info("Reading preprocessed data.") - logger.info( - f"Dataset {self.__name__} with the following units:\n\ - Energy: {self.energy_unit},\n\ - Distance: {self.distance_unit},\n\ - Forces: {self.force_unit if self.__force_methods__ else 'None'}" - ) - self.data = {} - for key in self.data_keys: - filename = p_join(self.preprocess_path, f"{key}.mmap") - pull_locally(filename, overwrite=overwrite_local_cache) - self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(self.data_shapes[key]) - - filename = p_join(self.preprocess_path, "props.pkl") - pull_locally(filename, overwrite=overwrite_local_cache) - with open(filename, "rb") as f: - tmp = pkl.load(f) - for key in set(tmp.keys()) - set(self.data_keys): - x = tmp.pop(key) - if len(x) == 2: - self.data[key] = x[0][x[1]] - else: - self.data[key] = x - - for key in self.data: - logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}") - def get_ase_atoms(self, idx: int): entry = self[idx] at = to_atoms(entry["positions"], entry["atomic_numbers"]) - at.info["n_atoms"] = entry["n_atoms_first"] + at.info["n_atoms"] = entry["n_atoms_ptr"] return at def save_xyz(self, idx: int, path: Optional[str] = None): diff --git a/openqdc/datasets/interaction/des.py b/openqdc/datasets/interaction/des.py new file mode 100644 index 0000000..a292fc3 --- /dev/null +++ b/openqdc/datasets/interaction/des.py @@ -0,0 +1,223 @@ +import os +from abc import ABC, abstractmethod +from typing import Dict, List + +import numpy as np +import pandas as pd +from loguru import logger +from tqdm import tqdm + +from openqdc.datasets.interaction.base import BaseInteractionDataset +from openqdc.methods import InteractionMethod, InterEnergyType +from openqdc.utils.constants import ATOM_TABLE +from openqdc.utils.molecule import molecule_groups + + +def parse_des_df(row, energy_target_names): + smiles0, smiles1 = row["smiles0"], row["smiles1"] + charge0, charge1 = row["charge0"], row["charge1"] + natoms0, natoms1 = row["natoms0"], row["natoms1"] + pos = np.array(list(map(float, row["xyz"].split()))).reshape(-1, 3) + elements = row["elements"].split() + atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elements]), axis=1) + charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1) + atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32) + energies = np.array(row[energy_target_names].values).astype(np.float32)[None, :] + name = np.array([smiles0 + "." + smiles1]) + return { + "energies": energies, + "n_atoms": np.array([natoms0 + natoms1], dtype=np.int32), + "name": name, + "atomic_inputs": atomic_inputs, + "charges": charges, + "atomic_nums": atomic_nums, + "elements": elements, + "natoms0": natoms0, + "natoms1": natoms1, + "smiles0": smiles0, + "smiles1": smiles1, + "charge0": charge0, + "charge1": charge1, + } + + +def create_subset(smiles0, smiles1): + subsets = [] + for smiles in [smiles0, smiles1]: + found = False + for functional_group, smiles_set in molecule_groups.items(): + if smiles in smiles_set: + subsets.append(functional_group) + found = True + if not found: + logger.info(f"molecule group lookup failed for {smiles}") + return subsets + + +def convert_to_record(item): + return dict( + energies=item["energies"], + subset=np.array([item["subset"]]), + n_atoms=np.array([item["natoms0"] + item["natoms1"]], dtype=np.int32), + n_atoms_ptr=np.array([item["natoms0"]], dtype=np.int32), + atomic_inputs=item["atomic_inputs"], + name=item["name"], + ) + + +class IDES(ABC): + @abstractmethod + def _create_subsets(self, **kwargs): + raise NotImplementedError + + +class DES370K(BaseInteractionDataset, IDES): + """ + DE Shaw Research interaction energy of over 370K + small molecule dimers as described in the paper: + + Quantum chemical benchmark databases of gold-standard dimer interaction energies. + Donchev, A.G., Taube, A.G., Decolvenaere, E. et al. + Sci Data 8, 55 (2021). + https://doi.org/10.1038/s41597-021-00833-x + """ + + __name__ = "des370k_interaction" + __filename__ = "DES370K.csv" + __energy_unit__ = "kcal/mol" + __distance_unit__ = "ang" + __forces_unit__ = "kcal/mol/ang" + __energy_methods__ = [ + InteractionMethod.MP2_CC_PVDZ, + InteractionMethod.MP2_CC_PVQZ, + InteractionMethod.MP2_CC_PVTZ, + InteractionMethod.MP2_CBS, + InteractionMethod.CCSD_T_CC_PVDZ, + InteractionMethod.CCSD_T_CBS, + InteractionMethod.CCSD_T_NN, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + InteractionMethod.SAPT0_AUG_CC_PWCVXZ, + ] + + __energy_type__ = [ + InterEnergyType.TOTAL, + InterEnergyType.TOTAL, + InterEnergyType.TOTAL, + InterEnergyType.TOTAL, + InterEnergyType.TOTAL, + InterEnergyType.TOTAL, + InterEnergyType.TOTAL, + InterEnergyType.TOTAL, + InterEnergyType.ES, + InterEnergyType.EX, + InterEnergyType.EX_S2, + InterEnergyType.IND, + InterEnergyType.EX_IND, + InterEnergyType.DISP, + InterEnergyType.EX_DISP_OS, + InterEnergyType.EX_DISP_SS, + InterEnergyType.DELTA_HF, + ] + + energy_target_names = [ + "cc_MP2_all", + "qz_MP2_all", + "tz_MP2_all", + "cbs_MP2_all", + "cc_CCSD(T)_all", + "cbs_CCSD(T)_all", + "nn_CCSD(T)_all", + "sapt_all", + "sapt_es", + "sapt_ex", + "sapt_exs2", + "sapt_ind", + "sapt_exind", + "sapt_disp", + "sapt_exdisp_os", + "sapt_exdisp_ss", + "sapt_delta_HF", + ] + + @property + def csv_path(self): + return os.path.join(self.root, self.__filename__) + + def _create_subsets(self, **kwargs): + return create_subset(kwargs["smiles0"], kwargs["smiles1"]) + + def read_raw_entries(self) -> List[Dict]: + filepath = self.csv_path + logger.info(f"Reading {self.__name__} interaction data from {filepath}") + df = pd.read_csv(filepath) + data = [] + for idx, row in tqdm(df.iterrows(), total=df.shape[0]): + item = parse_des_df(row, self.energy_target_names) + item["subset"] = self._create_subsets(**item) + item = convert_to_record(item) + data.append(item) + return data + + +class DES5M(DES370K): + """ + DE Shaw Research interaction energy calculations for + over 5M small molecule dimers as described in the paper: + + Quantum chemical benchmark databases of gold-standard dimer interaction energies. + Donchev, A.G., Taube, A.G., Decolvenaere, E. et al. + Sci Data 8, 55 (2021). + https://doi.org/10.1038/s41597-021-00833-x + """ + + __name__ = "des5m_interaction" + __filename__ = "DES5M.csv" + + +class DESS66(DES370K): + """ + DE Shaw Research interaction energy + estimates of all 66 conformers from + the original S66 dataset as described + in the paper: + + Quantum chemical benchmark databases of gold-standard dimer interaction energies. + Donchev, A.G., Taube, A.G., Decolvenaere, E. et al. + Sci Data 8, 55 (2021). + https://doi.org/10.1038/s41597-021-00833-x + + Data was downloaded from Zenodo: + https://zenodo.org/records/5676284 + """ + + __name__ = "des_s66" + __filename__ = "DESS66.csv" + + +class DESS66x8(DESS66): + """ + DE Shaw Research interaction energy + estimates of all 528 conformers from + the original S66x8 dataset as described + in the paper: + + Quantum chemical benchmark databases of gold-standard dimer interaction energies. + Donchev, A.G., Taube, A.G., Decolvenaere, E. et al. + Sci Data 8, 55 (2021). + https://doi.org/10.1038/s41597-021-00833-x + + Data was downloaded from Zenodo: + + https://zenodo.org/records/5676284 + """ + + __name__ = "des_s66x8" + __filename__ = "DESS66x8.csv" diff --git a/openqdc/datasets/interaction/des370k.py b/openqdc/datasets/interaction/des370k.py deleted file mode 100644 index 250d42d..0000000 --- a/openqdc/datasets/interaction/des370k.py +++ /dev/null @@ -1,144 +0,0 @@ -import os -from typing import Dict, List - -import numpy as np -import pandas as pd -from loguru import logger -from tqdm import tqdm - -from openqdc.datasets.interaction.base import BaseInteractionDataset -from openqdc.methods import InteractionMethod, InterEnergyType -from openqdc.utils.constants import ATOM_TABLE -from openqdc.utils.io import get_local_cache -from openqdc.utils.molecule import molecule_groups - - -class DES370K(BaseInteractionDataset): - """ - DE Shaw Research interaction energy of over 370K - small molecule dimers as described in the paper: - - Quantum chemical benchmark databases of gold-standard dimer interaction energies. - Donchev, A.G., Taube, A.G., Decolvenaere, E. et al. - Sci Data 8, 55 (2021). - https://doi.org/10.1038/s41597-021-00833-x - """ - - __name__ = "des370k_interaction" - __energy_unit__ = "kcal/mol" - __distance_unit__ = "ang" - __forces_unit__ = "kcal/mol/ang" - __energy_methods__ = [ - InteractionMethod.MP2_CC_PVDZ, - InteractionMethod.MP2_CC_PVQZ, - InteractionMethod.MP2_CC_PVTZ, - InteractionMethod.MP2_CBS, - InteractionMethod.CCSD_T_CC_PVDZ, - InteractionMethod.CCSD_T_CBS, - InteractionMethod.CCSD_T_NN, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - ] - - __energy_type__ = [ - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.ES, - InterEnergyType.EX, - InterEnergyType.EX_S2, - InterEnergyType.IND, - InterEnergyType.EX_IND, - InterEnergyType.DISP, - InterEnergyType.EX_DISP_OS, - InterEnergyType.EX_DISP_SS, - InterEnergyType.DELTA_HF, - ] - - energy_target_names = [ - "cc_MP2_all", - "qz_MP2_all", - "tz_MP2_all", - "cbs_MP2_all", - "cc_CCSD(T)_all", - "cbs_CCSD(T)_all", - "nn_CCSD(T)_all", - "sapt_all", - "sapt_es", - "sapt_ex", - "sapt_exs2", - "sapt_ind", - "sapt_exind", - "sapt_disp", - "sapt_exdisp_os", - "sapt_exdisp_ss", - "sapt_delta_HF", - ] - - _filename = "DES370K.csv" - _name = "des370k_interaction" - - @classmethod - def _root(cls): - return os.path.join(get_local_cache(), cls._name) - - @classmethod - def _read_raw_entries(cls) -> List[Dict]: - filepath = os.path.join(cls._root(), cls._filename) - logger.info(f"Reading {cls._name} interaction data from {filepath}") - df = pd.read_csv(filepath) - data = [] - for idx, row in tqdm(df.iterrows(), total=df.shape[0]): - smiles0, smiles1 = row["smiles0"], row["smiles1"] - charge0, charge1 = row["charge0"], row["charge1"] - natoms0, natoms1 = row["natoms0"], row["natoms1"] - pos = np.array(list(map(float, row["xyz"].split()))).reshape(-1, 3) - - elements = row["elements"].split() - - atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elements]), axis=1) - - charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1) - - atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32) - - energies = np.array(row[cls.energy_target_names].values).astype(np.float32)[None, :] - - name = np.array([smiles0 + "." + smiles1]) - - subsets = [] - for smiles in [smiles0, smiles1]: - found = False - for functional_group, smiles_set in molecule_groups.items(): - if smiles in smiles_set: - subsets.append(functional_group) - found = True - if not found: - logger.info(f"molecule group lookup failed for {smiles}") - - item = dict( - energies=energies, - subset=np.array([subsets]), - n_atoms=np.array([natoms0 + natoms1], dtype=np.int32), - n_atoms_first=np.array([natoms0], dtype=np.int32), - atomic_inputs=atomic_inputs, - name=name, - ) - data.append(item) - return data - - def read_raw_entries(self) -> List[Dict]: - return DES370K._read_raw_entries() diff --git a/openqdc/datasets/interaction/des5m.py b/openqdc/datasets/interaction/des5m.py deleted file mode 100644 index 979909c..0000000 --- a/openqdc/datasets/interaction/des5m.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Dict, List - -from openqdc.datasets.interaction.des370k import DES370K -from openqdc.methods import InteractionMethod, InterEnergyType - - -class DES5M(DES370K): - """ - DE Shaw Research interaction energy calculations for - over 5M small molecule dimers as described in the paper: - - Quantum chemical benchmark databases of gold-standard dimer interaction energies. - Donchev, A.G., Taube, A.G., Decolvenaere, E. et al. - Sci Data 8, 55 (2021). - https://doi.org/10.1038/s41597-021-00833-x - """ - - __name__ = "des5m_interaction" - __energy_methods__ = [ - InteractionMethod.MP2_CC_PVQZ, - InteractionMethod.MP2_CC_PVTZ, - InteractionMethod.MP2_CBS, - InteractionMethod.CCSD_T_NN, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - ] - - __energy_type__ = [ - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.ES, - InterEnergyType.EX, - InterEnergyType.EX_S2, - InterEnergyType.IND, - InterEnergyType.EX_IND, - InterEnergyType.DISP, - InterEnergyType.EX_DISP_OS, - InterEnergyType.EX_DISP_SS, - InterEnergyType.DELTA_HF, - ] - - energy_target_names = [ - "qz_MP2_all", - "tz_MP2_all", - "cbs_MP2_all", - "nn_CCSD(T)_all", - "sapt_all", - "sapt_es", - "sapt_ex", - "sapt_exs2", - "sapt_ind", - "sapt_exind", - "sapt_disp", - "sapt_exdisp_os", - "sapt_exdisp_ss", - "sapt_delta_HF", - ] - - _filename = "DES5M.csv" - _name = "des5m_interaction" - - __energy_unit__ = "kcal/mol" - __distance_unit__ = "ang" - __forces_unit__ = "kcal/mol/ang" - - def read_raw_entries(self) -> List[Dict]: - return DES5M._read_raw_entries() diff --git a/openqdc/datasets/interaction/dess66.py b/openqdc/datasets/interaction/dess66.py deleted file mode 100644 index c10811b..0000000 --- a/openqdc/datasets/interaction/dess66.py +++ /dev/null @@ -1,128 +0,0 @@ -import os -from typing import Dict, List - -import numpy as np -import pandas as pd -from loguru import logger -from tqdm import tqdm - -from openqdc.datasets.interaction.base import BaseInteractionDataset -from openqdc.methods import InteractionMethod, InterEnergyType -from openqdc.utils.constants import ATOM_TABLE - - -class DESS66(BaseInteractionDataset): - """ - DE Shaw Research interaction energy - estimates of all 66 conformers from - the original S66 dataset as described - in the paper: - - Quantum chemical benchmark databases of gold-standard dimer interaction energies. - Donchev, A.G., Taube, A.G., Decolvenaere, E. et al. - Sci Data 8, 55 (2021). - https://doi.org/10.1038/s41597-021-00833-x - - Data was downloaded from Zenodo: - https://zenodo.org/records/5676284 - """ - - __name__ = "des_s66" - __energy_unit__ = "kcal/mol" - __distance_unit__ = "ang" - __forces_unit__ = "kcal/mol/ang" - __energy_methods__ = [ - InteractionMethod.MP2_CC_PVDZ, - InteractionMethod.MP2_CC_PVQZ, - InteractionMethod.MP2_CC_PVTZ, - InteractionMethod.MP2_CBS, - InteractionMethod.CCSD_T_CC_PVDZ, - InteractionMethod.CCSD_T_CBS, - InteractionMethod.CCSD_T_NN, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - ] - - __energy_type__ = [ - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.ES, - InterEnergyType.EX, - InterEnergyType.EX_S2, - InterEnergyType.IND, - InterEnergyType.EX_IND, - InterEnergyType.DISP, - InterEnergyType.EX_DISP_OS, - InterEnergyType.EX_DISP_SS, - InterEnergyType.DELTA_HF, - ] - - energy_target_names = [ - "cc_MP2_all", - "qz_MP2_all", - "tz_MP2_all", - "cbs_MP2_all", - "cc_CCSD(T)_all", - "cbs_CCSD(T)_all", - "nn_CCSD(T)_all", - "sapt_all", - "sapt_es", - "sapt_ex", - "sapt_exs2", - "sapt_ind", - "sapt_exind", - "sapt_disp", - "sapt_exdisp_os", - "sapt_exdisp_ss", - "sapt_delta_HF", - ] - - def read_raw_entries(self) -> List[Dict]: - self.filepath = os.path.join(self.root, "DESS66.csv") - logger.info(f"Reading DESS66 interaction data from {self.filepath}") - df = pd.read_csv(self.filepath) - data = [] - for idx, row in tqdm(df.iterrows(), total=df.shape[0]): - smiles0, smiles1 = row["smiles0"], row["smiles1"] - charge0, charge1 = row["charge0"], row["charge1"] - natoms0, natoms1 = row["natoms0"], row["natoms1"] - pos = np.array(list(map(float, row["xyz"].split()))).reshape(-1, 3) - - elements = row["elements"].split() - - atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elements]), axis=1) - - charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1) - - atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32) - - energies = np.array(row[self.energy_target_names].values).astype(np.float32)[None, :] - - name = np.array([smiles0 + "." + smiles1]) - - subset = row["system_name"] - - item = dict( - energies=energies, - subset=np.array([subset]), - n_atoms=np.array([natoms0 + natoms1], dtype=np.int32), - n_atoms_first=np.array([natoms0], dtype=np.int32), - atomic_inputs=atomic_inputs, - name=name, - ) - data.append(item) - return data diff --git a/openqdc/datasets/interaction/dess66x8.py b/openqdc/datasets/interaction/dess66x8.py deleted file mode 100644 index 709620a..0000000 --- a/openqdc/datasets/interaction/dess66x8.py +++ /dev/null @@ -1,129 +0,0 @@ -import os -from typing import Dict, List - -import numpy as np -import pandas as pd -from loguru import logger -from tqdm import tqdm - -from openqdc.datasets.interaction.base import BaseInteractionDataset -from openqdc.methods import InteractionMethod, InterEnergyType -from openqdc.utils.constants import ATOM_TABLE - - -class DESS66x8(BaseInteractionDataset): - """ - DE Shaw Research interaction energy - estimates of all 528 conformers from - the original S66x8 dataset as described - in the paper: - - Quantum chemical benchmark databases of gold-standard dimer interaction energies. - Donchev, A.G., Taube, A.G., Decolvenaere, E. et al. - Sci Data 8, 55 (2021). - https://doi.org/10.1038/s41597-021-00833-x - - Data was downloaded from Zenodo: - - https://zenodo.org/records/5676284 - """ - - __name__ = "des_s66x8" - __energy_unit__ = "kcal/mol" - __distance_unit__ = "ang" - __forces_unit__ = "kcal/mol/ang" - __energy_methods__ = [ - InteractionMethod.MP2_CC_PVDZ, - InteractionMethod.MP2_CC_PVQZ, - InteractionMethod.MP2_CC_PVTZ, - InteractionMethod.MP2_CBS, - InteractionMethod.CCSD_T_CC_PVDZ, - InteractionMethod.CCSD_T_CBS, - InteractionMethod.CCSD_T_NN, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - InteractionMethod.SAPT0_AUG_CC_PWCVXZ, - ] - - __energy_type__ = [ - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.TOTAL, - InterEnergyType.ES, - InterEnergyType.EX, - InterEnergyType.EX_S2, - InterEnergyType.IND, - InterEnergyType.EX_IND, - InterEnergyType.DISP, - InterEnergyType.EX_DISP_OS, - InterEnergyType.EX_DISP_SS, - InterEnergyType.DELTA_HF, - ] - - energy_target_names = [ - "cc_MP2_all", - "qz_MP2_all", - "tz_MP2_all", - "cbs_MP2_all", - "cc_CCSD(T)_all", - "cbs_CCSD(T)_all", - "nn_CCSD(T)_all", - "sapt_all", - "sapt_es", - "sapt_ex", - "sapt_exs2", - "sapt_ind", - "sapt_exind", - "sapt_disp", - "sapt_exdisp_os", - "sapt_exdisp_ss", - "sapt_delta_HF", - ] - - def read_raw_entries(self) -> List[Dict]: - self.filepath = os.path.join(self.root, "DESS66x8.csv") - logger.info(f"Reading DESS66x8 interaction data from {self.filepath}") - df = pd.read_csv(self.filepath) - data = [] - for idx, row in tqdm(df.iterrows(), total=df.shape[0]): - smiles0, smiles1 = row["smiles0"], row["smiles1"] - charge0, charge1 = row["charge0"], row["charge1"] - natoms0, natoms1 = row["natoms0"], row["natoms1"] - pos = np.array(list(map(float, row["xyz"].split()))).reshape(-1, 3) - - elements = row["elements"].split() - - atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elements]), axis=1) - - charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1) - - atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32) - - energies = np.array(row[self.energy_target_names].values).astype(np.float32)[None, :] - - name = np.array([smiles0 + "." + smiles1]) - - subset = row["system_name"] - - item = dict( - energies=energies, - subset=np.array([subset]), - n_atoms=np.array([natoms0 + natoms1], dtype=np.int32), - n_atoms_first=np.array([natoms0], dtype=np.int32), - atomic_inputs=atomic_inputs, - name=name, - ) - data.append(item) - return data diff --git a/openqdc/datasets/interaction/dummy.py b/openqdc/datasets/interaction/dummy.py new file mode 100644 index 0000000..7f19154 --- /dev/null +++ b/openqdc/datasets/interaction/dummy.py @@ -0,0 +1,71 @@ +import numpy as np + +from openqdc.datasets.interaction.base import BaseInteractionDataset +from openqdc.methods import InteractionMethod + + +class DummyInteraction(BaseInteractionDataset): + """ + Dummy Interaction Dataset for Testing + """ + + __name__ = "dummy_interaction" + __energy_methods__ = [InteractionMethod.SAPT0_AUG_CC_PVDDZ, InteractionMethod.CCSD_T_CC_PVDZ] + __force_mask__ = [False, False] + __energy_unit__ = "kcal/mol" + __distance_unit__ = "ang" + __forces_unit__ = "kcal/mol/ang" + + energy_target_names = [f"energy{i}" for i in range(len(__energy_methods__))] + + __isolated_atom_energies__ = [] + __average_n_atoms__ = None + + def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None: + self.setup_dummy() + return super()._post_init(overwrite_local_cache, energy_unit, distance_unit) + + def setup_dummy(self): + n_atoms = np.array([np.random.randint(10, 30) for _ in range(len(self))]) + n_atoms_ptr = np.array([np.random.randint(1, 10) for _ in range(len(self))]) + position_idx_range = np.concatenate([[0], np.cumsum(n_atoms)]).repeat(2)[1:-1].reshape(-1, 2) + atomic_inputs = np.concatenate( + [ + np.concatenate( + [ + # z, c, x, y, z + np.random.randint(1, 100, size=(size, 1)), + np.random.randint(-1, 2, size=(size, 1)), + np.random.randn(size, 3), + ], + axis=1, + ) + for size in n_atoms + ], + axis=0, + ) # (sum(n_atoms), 5) + name = [f"dummy_{i}" for i in range(len(self))] + subset = ["dummy" for i in range(len(self))] + energies = np.random.rand(len(self), len(self.energy_methods)) + self.data = dict( + n_atoms=n_atoms, + position_idx_range=position_idx_range, + name=name, + atomic_inputs=atomic_inputs, + subset=subset, + energies=energies, + n_atoms_ptr=n_atoms_ptr, + ) + self.__average_nb_atoms__ = self.data["n_atoms"].mean() + + def read_preprocess(self, overwrite_local_cache=False): + return + + def is_preprocessed(self): + return True + + def read_raw_entries(self): + pass + + def __len__(self): + return 9999 diff --git a/openqdc/datasets/interaction/l7.py b/openqdc/datasets/interaction/l7.py new file mode 100644 index 0000000..22e3141 --- /dev/null +++ b/openqdc/datasets/interaction/l7.py @@ -0,0 +1,32 @@ +from openqdc.methods import InteractionMethod + +from ._utils import YamlDataset + + +class L7(YamlDataset): + """ + The L7 interaction energy dataset as described in: + + Accuracy of Quantum Chemical Methods for Large Noncovalent Complexes + Robert Sedlak, Tomasz Janowski, Michal Pitoňák, Jan Řezáč, Peter Pulay, and Pavel Hobza + Journal of Chemical Theory and Computation 2013 9 (8), 3364-3374 + DOI: 10.1021/ct400036b + + Data was downloaded and extracted from: + http://cuby4.molecular.cz/dataset_l7.html + """ + + __name__ = "l7" + __energy_methods__ = [ + InteractionMethod.QCISDT_CBS, # "QCISD(T)/CBS", + InteractionMethod.DLPNO_CCSDT, # "DLPNO-CCSD(T)", + InteractionMethod.MP2_CBS, # "MP2/CBS", + InteractionMethod.MP2C_CBS, # "MP2C/CBS", + InteractionMethod.FIXED, # "fixed", TODO: we should remove this level of theory because unless we have a pro + InteractionMethod.DLPNO_CCSDT0, # "DLPNO-CCSD(T0)", + InteractionMethod.LNO_CCSDT, # "LNO-CCSD(T)", + InteractionMethod.FN_DMC, # "FN-DMC", + ] + + def _process_name(self, item): + return item.geometry.split(":")[1] diff --git a/openqdc/datasets/interaction/metcalf.py b/openqdc/datasets/interaction/metcalf.py index 819d5dc..60298c4 100644 --- a/openqdc/datasets/interaction/metcalf.py +++ b/openqdc/datasets/interaction/metcalf.py @@ -1,12 +1,86 @@ import os +from glob import glob +from io import StringIO +from os.path import join as p_join from typing import Dict, List import numpy as np +from loguru import logger +from tqdm import tqdm from openqdc.datasets.interaction.base import BaseInteractionDataset from openqdc.methods import InteractionMethod, InterEnergyType +from openqdc.raws.config_factory import decompress_tar_gz from openqdc.utils.constants import ATOM_TABLE +EXPECTED_TAR_FILES = { + "train": [ + "TRAINING-2073-ssi-neutral.tar.gz", + "TRAINING-2610-donors-perturbed.tar.gz", + "TRAINING-4795-acceptors-perturbed.tar.gz", + ], + "val": ["VALIDATION-125-donors.tar.gz", "VALIDATION-254-acceptors.tar.gz"], + "test": [ + "TEST-Acc--3-methylbutan-2-one_Don--NMe-acetamide-PLDB.tar.gz", + "TEST-Acc--Cyclohexanone_Don--NMe-acetamide-PLDB.tar.gz", + "TEST-Acc--Isoquinolone_NMe-acetamide.tar.gz", + "TEST-Acc--NMe-acetamide_Don--Aniline-CSD.tar.gz", + "TEST-Acc--NMe-acetamide_Don--Aniline-PLDB.tar.gz", + "TEST-Acc--NMe-acetamide_Don--N-isopropylacetamide-PLDB.tar.gz", + "TEST-Acc--NMe-acetamide_Don--N-phenylbenzamide-PLDB.tar.gz", + "TEST-Acc--NMe-acetamide_Don--Naphthalene-1H-PLDB.tar.gz", + "TEST-Acc--NMe-acetamide_Don--Uracil-PLDB.tar.gz", + "TEST-Acc--Tetrahydro-2H-pyran-2-one_NMe-acetamide-PLDB.tar.gz", + "TEST-NMe-acetamide_Don--Benzimidazole-PLDB.tar.gz", + ], +} + + +def extract_raw_tar_gz(folder): + logger.info(f"Extracting all tar.gz files in {folder}") + for subset in EXPECTED_TAR_FILES: + for tar_file in EXPECTED_TAR_FILES[subset]: + tar_file_path = p_join(folder, tar_file) + try: + decompress_tar_gz(tar_file_path) + except FileNotFoundError as e: + raise FileNotFoundError(f"File {tar_file_path} not found") from e + + +def content_to_xyz(content, subset): + try: + num_atoms = np.array([int(content.split("\n")[0])]) + tmp = content.split("\n")[1].split(",") + name = tmp[0] + e = tmp[1:-1] + except Exception as e: + logger.warning(f"Encountered exception in {content} : {e}") + return None + + s = StringIO(content) + d = np.loadtxt(s, skiprows=2, dtype="str") + z, positions = d[:, 0], d[:, 1:].astype(np.float32) + z = np.array([ATOM_TABLE.GetAtomicNumber(s) for s in z]) + xs = np.stack((z, np.zeros_like(z)), axis=-1) + + item = dict( + n_atoms=num_atoms, + subset=np.array([subset]), + energies=e, + atomic_inputs=np.concatenate((xs, positions), axis=-1, dtype=np.float32), + name=np.array([name]), + n_atoms_ptr=np.array([-1]), + ) + + return item + + +def read_xyz(fname, subset): + with open(fname, "r") as f: + contents = f.read().split("\n\n") + res = [content_to_xyz(content, subset) for content in tqdm(contents)] + return res + class Metcalf(BaseInteractionDataset): """ @@ -53,36 +127,9 @@ class Metcalf(BaseInteractionDataset): ] def read_raw_entries(self) -> List[Dict]: + # extract in folders + extract_raw_tar_gz(self.root) data = [] - for dirname in os.listdir(self.root): - xyz_dir = os.path.join(self.root, dirname) - if not os.path.isdir(xyz_dir): - continue - subset = np.array([dirname.split("-")[0].lower()]) # training, validation, or test - for filename in os.listdir(xyz_dir): - if not filename.endswith(".xyz"): - continue - lines = list(map(lambda x: x.strip(), open(os.path.join(xyz_dir, filename), "r").readlines())) - line_two = lines[1].split(",") - energies = np.array([line_two[1:6]], dtype=np.float32) - num_atoms = np.array([int(lines[0])]) - - elem_xyz = np.array([x.split() for x in lines[2:]]) - elements = elem_xyz[:, 0] - xyz = elem_xyz[:, 1:].astype(np.float32) - atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elements]), axis=1) - charges = np.expand_dims(np.array([0] * num_atoms[0]), axis=1) - - atomic_inputs = np.concatenate((atomic_nums, charges, xyz), axis=-1, dtype=np.float32) - - item = dict( - n_atoms=num_atoms, - subset=subset, - energies=energies, - positions=xyz, - atomic_inputs=atomic_inputs, - name=np.array([""]), - n_atoms_first=np.array([-1]), - ) - data.append(item) + for filename in glob(self.root + f"{os.sep}*.xyz"): + data.append(read_xyz(filename, self.__name__)) return data diff --git a/openqdc/datasets/interaction/splinter.py b/openqdc/datasets/interaction/splinter.py index 07db952..a793624 100644 --- a/openqdc/datasets/interaction/splinter.py +++ b/openqdc/datasets/interaction/splinter.py @@ -134,15 +134,15 @@ def read_raw_entries(self) -> List[Dict]: index, _, ) = metadata[0].split("_") - r, theta_P, tau_P, theta_L, tau_L, tau_PL = [None] * 6 + r, theta_P, tau_P, theta_L, tau_L, tau_PL = [np.nan] * 6 energies = np.array([list(map(float, metadata[4:-1]))]).astype(np.float32) - n_atoms_first = np.array([int(metadata[-1])], dtype=np.int32) + n_atoms_ptr = np.array([int(metadata[-1])], dtype=np.int32) total_charge, charge0, charge1 = list(map(int, metadata[1:4])) lines = list(map(lambda x: x.split(), lines[2:])) pos = np.array(lines)[:, 1:].astype(np.float32) elems = np.array(lines)[:, 0] atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elems]), axis=1) - natoms0 = n_atoms_first[0] + natoms0 = n_atoms_ptr[0] natoms1 = n_atoms[0] - natoms0 charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1) atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32) @@ -152,7 +152,7 @@ def read_raw_entries(self) -> List[Dict]: energies=energies, subset=subset, n_atoms=n_atoms, - n_atoms_first=n_atoms_first, + n_atoms_ptr=n_atoms_ptr, atomic_inputs=atomic_inputs, protein_monomer_name=np.array([protein_monomer_name]), protein_interaction_site_type=np.array([protein_interaction_site_type]), diff --git a/openqdc/datasets/interaction/x40.py b/openqdc/datasets/interaction/x40.py new file mode 100644 index 0000000..1b5148c --- /dev/null +++ b/openqdc/datasets/interaction/x40.py @@ -0,0 +1,29 @@ +from openqdc.datasets.interaction._utils import YamlDataset +from openqdc.methods import InteractionMethod + + +class X40(YamlDataset): + """ + X40 interaction dataset of 40 dimer pairs as + introduced in the following paper: + + Benchmark Calculations of Noncovalent Interactions of Halogenated Molecules + Jan Řezáč, Kevin E. Riley, and Pavel Hobza + Journal of Chemical Theory and Computation 2012 8 (11), 4285-4292 + DOI: 10.1021/ct300647k + + Dataset retrieved and processed from: + http://cuby4.molecular.cz/dataset_x40.html + """ + + __name__ = "x40" + __energy_methods__ = [ + InteractionMethod.CCSD_T_CBS, # "CCSD(T)/CBS", + InteractionMethod.MP2_CBS, # "MP2/CBS", + InteractionMethod.DCCSDT_HA_DZ, # "dCCSD(T)/haDZ", + InteractionMethod.DCCSDT_HA_TZ, # "dCCSD(T)/haTZ", + InteractionMethod.MP2_5_CBS_ADZ, # "MP2.5/CBS(aDZ)", + ] + + def _process_name(self, item): + return item.shortname diff --git a/openqdc/datasets/potential/dummy.py b/openqdc/datasets/potential/dummy.py index 1c7a61c..b485d40 100644 --- a/openqdc/datasets/potential/dummy.py +++ b/openqdc/datasets/potential/dummy.py @@ -14,7 +14,7 @@ class Dummy(BaseDataset): """ __name__ = "dummy" - __energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6, PotentialMethod.GFN2_XTB] + __energy_methods__ = [PotentialMethod.GFN2_XTB, PotentialMethod.WB97X_D_DEF2_SVP, PotentialMethod.GFN2_XTB] __force_mask__ = [False, True, True] __energy_unit__ = "kcal/mol" __distance_unit__ = "ang" @@ -31,7 +31,7 @@ def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None: return super()._post_init(overwrite_local_cache, energy_unit, distance_unit) def setup_dummy(self): - n_atoms = np.array([np.random.randint(1, 100) for _ in range(len(self))]) + n_atoms = np.array([np.random.randint(2, 100) for _ in range(len(self))]) position_idx_range = np.concatenate([[0], np.cumsum(n_atoms)]).repeat(2)[1:-1].reshape(-1, 2) atomic_inputs = np.concatenate( [ diff --git a/openqdc/datasets/statistics.py b/openqdc/datasets/statistics.py index e4fe9e5..2122271 100644 --- a/openqdc/datasets/statistics.py +++ b/openqdc/datasets/statistics.py @@ -21,7 +21,8 @@ def to_dict(self): def transform(self, func): for k, v in self.to_dict().items(): - setattr(self, k, func(v)) + if v is not None: + setattr(self, k, func(v)) @dataclass diff --git a/openqdc/raws/config_factory.py b/openqdc/raws/config_factory.py index d39e0bb..26e0f2c 100644 --- a/openqdc/raws/config_factory.py +++ b/openqdc/raws/config_factory.py @@ -204,6 +204,14 @@ class DataConfigFactory: links={"rdkit_folder.tar.gz": "https://dataverse.harvard.edu/api/access/datafile/4327252"}, ) + l7 = dict( + dataset_name="l7", + links={ + "l7.yaml": "http://cuby4.molecular.cz/download_datasets/l7.yaml", + "geometries.tar.gz": "http://cuby4.molecular.cz/download_geometries/L7.tar", + }, + ) + molecule3d = dict( dataset_name="molecule3d", links={"molecule3d.zip": "https://drive.google.com/uc?id=1C_KRf8mX-gxny7kL9ACNCEV4ceu_fUGy"}, @@ -239,22 +247,36 @@ class DataConfigFactory: links={"spice-2.0.0.hdf5": "https://zenodo.org/records/10835749/files/SPICE-2.0.0.hdf5?download=1"}, ) - dess = dict( - dataset_name="dess5m", + splinter = dict( + dataset_name="splinter", links={ - "DESS5M.zip": "https://zenodo.org/record/5706002/files/DESS5M.zip", - "DESS370.zip": "https://zenodo.org/record/5676266/files/DES370K.zip", + "dimerpairs.0.tar.gz": "https://figshare.com/ndownloader/files/39449167", + "dimerpairs.1.tar.gz": "https://figshare.com/ndownloader/files/40271983", + "dimerpairs.2.tar.gz": "https://figshare.com/ndownloader/files/40271989", + "dimerpairs.3.tar.gz": "https://figshare.com/ndownloader/files/40272001", + "dimerpairs.4.tar.gz": "https://figshare.com/ndownloader/files/40272022", + "dimerpairs.5.tar.gz": "https://figshare.com/ndownloader/files/40552931", + "dimerpairs.6.tar.gz": "https://figshare.com/ndownloader/files/40272040", + "dimerpairs.7.tar.gz": "https://figshare.com/ndownloader/files/40272052", + "dimerpairs.8.tar.gz": "https://figshare.com/ndownloader/files/40272061", + "dimerpairs.9.tar.gz": "https://figshare.com/ndownloader/files/40272064", + "dimerpairs_nonstandard.tar.gz": "https://figshare.com/ndownloader/files/40272067", + "lig_interaction_sites.sdf": "https://figshare.com/ndownloader/files/40272070", + "lig_monomers.sdf": "https://figshare.com/ndownloader/files/40272073", + "prot_interaction_sites.sdf": "https://figshare.com/ndownloader/files/40272076", + "prot_monomers.sdf": "https://figshare.com/ndownloader/files/40272079", + "merge_monomers.py": "https://figshare.com/ndownloader/files/41807682", }, ) - des370k_interaction = dict( + des370k = dict( dataset_name="des370k_interaction", links={ "DES370K.zip": "https://zenodo.org/record/5676266/files/DES370K.zip", }, ) - des5m_interaction = dict( + des5m = dict( dataset_name="des5m_interaction", links={ "DES5M.zip": "https://zenodo.org/records/5706002/files/DESS5M.zip?download=1", @@ -269,6 +291,11 @@ class DataConfigFactory: }, ) + metcalf = dict( + dataset_name="metcalf", + links={"model-data.tar.gz": "https://zenodo.org/records/10934211/files/model-data.tar?download=1"}, + ) + misato = dict( dataset_name="misato", links={ @@ -314,12 +341,12 @@ class DataConfigFactory: links={"Transition1x.h5": "https://figshare.com/ndownloader/files/36035789"}, ) - des_s66 = dict( + dess66 = dict( dataset_name="des_s66", links={"DESS66.zip": "https://zenodo.org/records/5676284/files/DESS66.zip?download=1"}, ) - des_s66x8 = dict( + dess66x8 = dict( dataset_name="des_s66x8", links={"DESS66x8.zip": "https://zenodo.org/records/5676284/files/DESS66x8.zip?download=1"}, ) @@ -328,6 +355,14 @@ class DataConfigFactory: links={"revmd17.zip": "https://figshare.com/ndownloader/articles/12672038/versions/3"}, ) + x40 = dict( + dataset_name="x40", + links={ + "x40.yaml": "http://cuby4.molecular.cz/download_datasets/x40.yaml", + "geometries.tar.gz": "http://cuby4.molecular.cz/download_geometries/X40.tar", + }, + ) + available_datasets = [k for k in locals().keys() if not k.startswith("__")] def __init__(self): diff --git a/openqdc/utils/preprocess.py b/openqdc/utils/preprocess.py index a7dd9c7..0fee22b 100644 --- a/openqdc/utils/preprocess.py +++ b/openqdc/utils/preprocess.py @@ -7,7 +7,7 @@ from openqdc import AVAILABLE_DATASETS options = list(AVAILABLE_DATASETS.values()) -options_map = {d.__name__: d for d in options} +options_map = {d.__name__.lower(): d for d in options} @click.command() diff --git a/tests/test_dummy.py b/tests/test_dummy.py index 0bd51af..e38a6dc 100644 --- a/tests/test_dummy.py +++ b/tests/test_dummy.py @@ -5,13 +5,20 @@ import numpy as np import pytest +from openqdc.datasets.interaction.dummy import DummyInteraction # noqa: E402 from openqdc.datasets.potential.dummy import Dummy # noqa: E402 from openqdc.utils.io import get_local_cache from openqdc.utils.package_utils import has_package + # start by removing any cached data -cache_dir = get_local_cache() -os.system(f"rm -rf {cache_dir}/dummy") +@pytest.fixture(autouse=True) +def clean_before_run(): + # start by removing any cached data + cache_dir = get_local_cache() + os.system(f"rm -rf {cache_dir}/dummy") + os.system(f"rm -rf {cache_dir}/dummy_interaction") + yield if has_package("torch"): @@ -28,29 +35,30 @@ @pytest.fixture -def ds(): +def dummy(): return Dummy() -def test_dummy(ds): - assert len(ds) > 10 - assert ds[100] +@pytest.fixture +def dummy_interaction(): + return DummyInteraction() -# def test_is_at_factory(): -# res = IsolatedAtomEnergyFactory.get("mp2/cc-pvdz") -# assert len(res) == len(ISOLATED_ATOM_ENERGIES["mp2"]["cc-pvdz"]) -# res = IsolatedAtomEnergyFactory.get("PM6") -# assert len(res) == len(ISOLATED_ATOM_ENERGIES["pm6"]) -# assert isinstance(res[("H", 0)], float) +@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"]) +def test_dummy(ds, request): + ds = request.getfixturevalue(ds) + assert ds is not None + assert len(ds) == 9999 + assert ds[100] +@pytest.mark.parametrize("interaction_ds", [False, True]) @pytest.mark.parametrize("format", ["numpy", "torch", "jax"]) -def test_array_format(format): +def test_dummy_array_format(interaction_ds, format): if not has_package(format): pytest.skip(f"{format} is not installed, skipping test") - ds = Dummy(array_format=format) + ds = DummyInteraction(array_format=format) if interaction_ds else Dummy(array_format=format) keys = [ "positions", @@ -59,22 +67,26 @@ def test_array_format(format): "energies", "forces", "e0", - "formation_energies", - "per_atom_formation_energies", ] + if not interaction_ds: + # additional keys returned from the potential dataset + keys.extend(["formation_energies", "per_atom_formation_energies"]) data = ds[0] for key in keys: + if data[key] is None: + continue assert isinstance(data[key], format_to_type[format]) -def test_transform(): +@pytest.mark.parametrize("interaction_ds", [False, True]) +def test_transform(interaction_ds): def custom_fn(bunch): # create new name bunch.new_key = bunch.name + bunch.subset return bunch - ds = Dummy(transform=custom_fn) + ds = DummyInteraction(transform=custom_fn) if interaction_ds else Dummy(transform=custom_fn) data = ds[0] @@ -82,14 +94,18 @@ def custom_fn(bunch): assert data["new_key"] == data["name"] + data["subset"] -def test_get_statistics(ds): +@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"]) +def test_get_statistics(ds, request): + ds = request.getfixturevalue(ds) stats = ds.get_statistics() keys = ["ForcesCalculatorStats", "FormationEnergyStats", "PerAtomFormationEnergyStats", "TotalEnergyStats"] assert all(k in stats for k in keys) -def test_energy_statistics_shapes(ds): +@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"]) +def test_energy_statistics_shapes(ds, request): + ds = request.getfixturevalue(ds) stats = ds.get_statistics() num_methods = len(ds.energy_methods) @@ -107,7 +123,9 @@ def test_energy_statistics_shapes(ds): assert total_energy_stats["std"].shape == (1, num_methods) -def test_force_statistics_shapes(ds): +@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"]) +def test_force_statistics_shapes(ds, request): + ds = request.getfixturevalue(ds) stats = ds.get_statistics() num_force_methods = len(ds.force_methods) @@ -115,21 +133,25 @@ def test_force_statistics_shapes(ds): keys = ["mean", "std", "component_mean", "component_std", "component_rms"] assert all(k in forces_stats for k in keys) - assert forces_stats["mean"].shape == (1, num_force_methods) - assert forces_stats["std"].shape == (1, num_force_methods) - assert forces_stats["component_mean"].shape == (3, num_force_methods) - assert forces_stats["component_std"].shape == (3, num_force_methods) - assert forces_stats["component_rms"].shape == (3, num_force_methods) + if len(ds.force_methods) > 0: + assert forces_stats["mean"].shape == (1, num_force_methods) + assert forces_stats["std"].shape == (1, num_force_methods) + assert forces_stats["component_mean"].shape == (3, num_force_methods) + assert forces_stats["component_std"].shape == (3, num_force_methods) + assert forces_stats["component_rms"].shape == (3, num_force_methods) +@pytest.mark.parametrize("interaction_ds", [False, True]) @pytest.mark.parametrize("format", ["numpy", "torch", "jax"]) -def test_stats_array_format(format): +def test_stats_array_format(interaction_ds, format): if not has_package(format): pytest.skip(f"{format} is not installed, skipping test") - ds = Dummy(array_format=format) + ds = DummyInteraction(array_format=format) if interaction_ds else Dummy(array_format=format) stats = ds.get_statistics() for key in stats.keys(): for k, v in stats[key].items(): + if v is None: + continue assert isinstance(v, format_to_type[format])