Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch/jax Dataloader support #55

Merged
merged 17 commits into from
Apr 5, 2024
60 changes: 50 additions & 10 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import partial
from itertools import compress
from os.path import join as p_join
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import numpy as np
from ase.io.extxyz import write_extxyz
Expand Down Expand Up @@ -37,10 +37,29 @@
push_remote,
set_cache_dir,
)
from openqdc.utils.package_utils import requires_package
from openqdc.utils.package_utils import has_package, requires_package
from openqdc.utils.regressor import Regressor # noqa
from openqdc.utils.units import get_conversion

if has_package("torch"):
import torch

if has_package("jax"):
import jax.numpy as jnp


@requires_package("torch")
def to_torch(x: np.ndarray):
    """
    Convert a NumPy array into a torch tensor.

    The conversion goes through ``torch.from_numpy``, which wraps the
    existing buffer of ``x`` rather than copying it.

    Parameters
    ----------
    x
        NumPy array to convert.

    Returns
    -------
    The tensor produced by ``torch.from_numpy(x)``.
    """
    tensor = torch.from_numpy(x)
    return tensor


@requires_package("jax")
def to_jax(x: np.ndarray):
    """
    Convert a NumPy array into a JAX array.

    Parameters
    ----------
    x
        NumPy array to convert.

    Returns
    -------
    The array produced by ``jnp.array(x)``.
    """
    converted = jnp.array(x)
    return converted


# Maps an array-format name to the function that converts a NumPy array into
# that format; "numpy" is the identity because the data is already NumPy.
_CONVERT_DICT = {"torch": to_torch, "jax": to_jax, "numpy": lambda x: x}


class BaseDataset(DatasetPropertyMixIn):
"""
Expand All @@ -65,10 +84,12 @@ def __init__(
self,
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
array_format: str = "numpy",
energy_type: str = "formation",
overwrite_local_cache: bool = False,
cache_dir: Optional[str] = None,
recompute_statistics: bool = False,
transform: Optional[Callable] = None,
regressor_kwargs={
"solver_type": "linear",
"sub_sample": None,
Expand All @@ -83,6 +104,8 @@ def __init__(
Energy unit to convert dataset to. Supported units: ["kcal/mol", "kj/mol", "hartree", "ev"]
distance_unit
Distance unit to convert dataset to. Supported units: ["ang", "nm", "bohr"]
array_format
Format to return arrays in. Supported formats: ["numpy", "torch", "jax"]
energy_type
Type of isolated atom energy to use for the dataset. Default: "formation"
Supported types: ["formation", "regression", "null"]
Expand All @@ -92,6 +115,8 @@ def __init__(
Cache directory location. Defaults to "~/.cache/openqdc"
recompute_statistics
Whether to recompute the statistics of the dataset.
transform, optional
transformation to apply to the __getitem__ calls
regressor_kwargs
Dictionary of keyword arguments to pass to the regressor.
Default: {"solver_type": "linear", "sub_sample": None, "stride": 1}
Expand All @@ -101,12 +126,14 @@ def __init__(
self.data = None
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
self.energy_type = energy_type
self.refit_e0s = recompute_statistics or overwrite_local_cache
if not self.is_preprocessed():
raise DatasetNotAvailableError(self.__name__)
else:
self.read_preprocess(overwrite_local_cache=overwrite_local_cache)
self.set_array_format(array_format)
self._post_init(overwrite_local_cache, energy_unit, distance_unit)

def _post_init(
Expand Down Expand Up @@ -270,6 +297,10 @@ def set_distance_unit(self, value: str):
self.__distance_unit__ = value
self.__class__.__fn_distance__ = get_conversion(old_unit, value)

def set_array_format(self, format: str):
    """
    Select the format in which arrays are returned.

    Parameters
    ----------
    format
        Format to return arrays in. Supported formats: ["numpy", "torch", "jax"]

    Raises
    ------
    ValueError
        If ``format`` is not one of the supported formats.
    """
    supported_formats = ("numpy", "torch", "jax")
    # Raise explicitly instead of using ``assert``: assertions are removed
    # under ``python -O``, which would let an unsupported format be stored
    # and only fail later, when the first array is converted.
    if format not in supported_formats:
        raise ValueError(f"Format {format} not supported.")
    self.array_format = format

def read_raw_entries(self):
raise NotImplementedError

Expand Down Expand Up @@ -536,24 +567,28 @@ def __smiles_converter__(self, x):
"""
return x

def _convert_array(self, x: np.ndarray, dtype=None):
    """
    Convert ``x`` into the format selected by ``self.array_format``.

    Parameters
    ----------
    x
        NumPy array to convert.
    dtype, optional
        NumPy dtype to cast ``x`` to before converting. Some callers (for
        example the interaction dataset ``__getitem__``) pass
        ``dtype=np.float32`` as a keyword, so it has to be accepted here.
        When omitted, ``x`` is converted as it is.

    Returns
    -------
    ``x`` passed through the converter registered in ``_CONVERT_DICT`` for
    ``self.array_format``.

    Raises
    ------
    ValueError
        If no converter is registered for ``self.array_format``.
    """
    if dtype is not None:
        x = np.array(x, dtype=dtype)
    converter = _CONVERT_DICT.get(self.array_format)
    # ``dict.get`` returns None for an unknown key; fail with a clear message
    # rather than "'NoneType' object is not callable".
    if converter is None:
        raise ValueError(f"Format {self.array_format} not supported.")
    return converter(x)

def __getitem__(self, idx: int):
shift = MAX_CHARGE
p_start, p_end = self.data["position_idx_range"][idx]
input = self.data["atomic_inputs"][p_start:p_end]
z, c, positions, energies = (
np.array(input[:, 0], dtype=np.int32),
np.array(input[:, 1], dtype=np.int32),
np.array(input[:, -3:], dtype=np.float32),
np.array(self.data["energies"][idx], dtype=np.float32),
self._convert_array(np.array(input[:, 0], dtype=np.int32)),
self._convert_array(np.array(input[:, 1], dtype=np.int32)),
self._convert_array(np.array(input[:, -3:], dtype=np.float32)),
self._convert_array(np.array(self.data["energies"][idx], dtype=np.float32)),
)
name = self.__smiles_converter__(self.data["name"][idx])
subset = self.data["subset"][idx]
e0s = self.__isolated_atom_energies__[..., z, c + shift].T
formation_energies = (energies - e0s.sum(axis=0)).astype(np.float32)
e0s = self._convert_array(self.__isolated_atom_energies__[..., z, c + shift].T)
formation_energies = energies - e0s.sum(axis=0)
forces = None
if "forces" in self.data:
forces = np.array(self.data["forces"][p_start:p_end], dtype=np.float32)
return Bunch(
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))

bunch = Bunch(
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
positions=positions,
atomic_numbers=z,
charges=c,
Expand All @@ -565,3 +600,8 @@ def __getitem__(self, idx: int):
subset=subset,
forces=forces,
)

if self.transform is not None:
bunch = self.transform(bunch)

return bunch
22 changes: 15 additions & 7 deletions openqdc/datasets/interaction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,39 @@ def __getitem__(self, idx: int):
p_start, p_end = self.data["position_idx_range"][idx]
input = self.data["atomic_inputs"][p_start:p_end]
z, c, positions, energies = (
np.array(input[:, 0], dtype=np.int32),
np.array(input[:, 1], dtype=np.int32),
np.array(input[:, -3:], dtype=np.float32),
np.array(self.data["energies"][idx], dtype=np.float32),
self._convert_array(np.array(input[:, 0], dtype=np.int32)),
self._convert_array(np.array(input[:, 1], dtype=np.int32)),
self._convert_array(np.array(input[:, -3:], dtype=np.float32)),
self._convert_array(np.array(self.data["energies"][idx], dtype=np.float32)),
)
name = self.__smiles_converter__(self.data["name"][idx])
subset = self.data["subset"][idx]
n_atoms_first = self.data["n_atoms_first"][idx]

if "forces" in self.data:
forces = np.array(self.data["forces"][p_start:p_end], dtype=np.float32)
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end]), dtype=np.float32)
else:
forces = None
return Bunch(

e0 = self._convert_array(self.__isolated_atom_energies__[..., z, c + shift].T, dtype=np.float32)

bunch = Bunch(
positions=positions,
atomic_numbers=z,
charges=c,
e0=self.__isolated_atom_energies__[..., z, c + shift].T,
e0=e0,
energies=energies,
name=name,
subset=subset,
forces=forces,
n_atoms_first=n_atoms_first,
)

if self.transform is not None:
bunch = self.transform(bunch)

return bunch

def save_preprocess(self, data_dict):
# save memmaps
logger.info("Preprocessing data and saving it to cache.")
Expand Down
6 changes: 5 additions & 1 deletion openqdc/datasets/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Callable, List, Optional

import datamol as dm
import numpy as np
Expand All @@ -25,7 +25,9 @@ def __init__(
energy_type: Optional[str] = "regression",
energy_unit: Optional[str] = "hartree",
distance_unit: Optional[str] = "ang",
array_format: Optional[str] = "numpy",
level_of_theory: Optional[QmMethod] = None,
transform: Optional[Callable] = None,
regressor_kwargs={
"solver_type": "linear",
"sub_sample": None,
Expand All @@ -49,7 +51,9 @@ def __init__(
self.__distance_unit__ = distance_unit
self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory]
self.regressor_kwargs = regressor_kwargs
self.transform = transform
self._read_and_preprocess()
self.set_array_format(array_format)
self._post_init(True, energy_unit, distance_unit)

def __str__(self):
Expand Down
17 changes: 5 additions & 12 deletions openqdc/datasets/potential/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,9 @@ def _stats(self):
},
}

def __init__(
self,
energy_unit=None,
distance_unit=None,
cache_dir=None,
) -> None:
try:
super().__init__(energy_unit=energy_unit, distance_unit=distance_unit, cache_dir=cache_dir)

except: # noqa
pass
self._set_isolated_atom_energies()
def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None:
    """
    Build the dummy data, then run the parent post-initialisation.

    ``setup_dummy`` is called first and the three arguments are then
    forwarded unchanged, in the same order, to ``super()._post_init``.
    """
    self.setup_dummy()
    outcome = super()._post_init(overwrite_local_cache, energy_unit, distance_unit)
    return outcome

def setup_dummy(self):
n_atoms = np.array([np.random.randint(1, 100) for _ in range(len(self))])
Expand Down Expand Up @@ -89,6 +79,9 @@ def setup_dummy(self):
)
self.__average_nb_atoms__ = self.data["n_atoms"].mean()

def read_preprocess(self, overwrite_local_cache=False):
    """
    Do nothing and return ``None``.

    Parameters
    ----------
    overwrite_local_cache
        Accepted but ignored.
    """
    return None

def is_preprocessed(self):
    """Report the dataset as preprocessed; this always returns ``True``."""
    return True

Expand Down
7 changes: 7 additions & 0 deletions openqdc/datasets/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class StatisticManager:
_results = {}

def __init__(self, dataset, recompute: bool = False, *statistic_calculators: "AbstractStatsCalculator"):
self.reset_state()
self._statistic_calculators = [
statistic_calculators.from_openqdc_dataset(dataset, recompute)
for statistic_calculators in statistic_calculators
Expand All @@ -82,6 +83,12 @@ def state(self) -> dict:
"""
return self._state

def reset_state(self):
    """
    Replace the state dictionary with a new, empty one.
    """
    self._state = dict()

def get_state(self, key: Optional[str] = None):
"""
key : str, default = None
Expand Down
53 changes: 53 additions & 0 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
"""Path hack to make tests work."""

import numpy as np
import pytest

from openqdc.datasets.potential.dummy import Dummy # noqa: E402
from openqdc.utils.package_utils import has_package

if has_package("torch"):
import torch

if has_package("jax"):
import jax

# Expected array type for each array_format value; the torch/jax entries are
# None when the corresponding package is not installed.
format_to_type = {
"numpy": np.ndarray,
"torch": torch.Tensor if has_package("torch") else None,
"jax": jax.numpy.ndarray if has_package("jax") else None,
}


def test_dummy():
Expand All @@ -15,3 +31,40 @@ def test_dummy():
# res = IsolatedAtomEnergyFactory.get("PM6")
# assert len(res) == len(ISOLATED_ATOM_ENERGIES["pm6"])
# assert isinstance(res[("H", 0)], float)


@pytest.mark.parametrize("format", ["numpy", "torch", "jax"])
def test_array_format(format):
    """Each array field of a Dummy item has the type matching ``array_format``."""
    if not has_package(format):
        pytest.skip(f"{format} is not installed, skipping test")

    sample = Dummy(array_format=format)[0]

    array_fields = (
        "positions",
        "atomic_numbers",
        "charges",
        "energies",
        "forces",
        "e0",
        "formation_energies",
        "per_atom_formation_energies",
    )
    for field in array_fields:
        assert isinstance(sample[field], format_to_type[format])


def test_transform():
    """A user-supplied ``transform`` is applied to the item returned by indexing."""

    def add_new_key(bunch):
        # derive an extra field from two existing ones
        bunch.new_key = bunch.name + bunch.subset
        return bunch

    item = Dummy(transform=add_new_key)[0]

    assert "new_key" in item
    assert item["new_key"] == item["name"] + item["subset"]
37 changes: 37 additions & 0 deletions tests/test_filedataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
from io import StringIO

import numpy as np
import pytest

from openqdc.datasets.io import XYZDataset
from openqdc.methods.enums import PotentialMethod
from openqdc.utils.package_utils import has_package

if has_package("torch"):
import torch

if has_package("jax"):
import jax

# Expected array type for each array_format value; the torch/jax entries are
# None when the corresponding package is not installed.
format_to_type = {
"numpy": np.ndarray,
"torch": torch.Tensor if has_package("torch") else None,
"jax": jax.numpy.ndarray if has_package("jax") else None,
}


@pytest.fixture
Expand All @@ -27,3 +41,26 @@ def test_xyz_dataset(xyz_filelike):
assert len(ds.numbers) == 3
assert ds[1].energies == -20.0
assert set(ds.chemical_species) == {"H", "O", "C"}


@pytest.mark.parametrize("format", ["numpy", "torch", "jax"])
def test_array_format(xyz_filelike, format):
    """Each array field of an XYZDataset item has the type matching ``array_format``."""
    if not has_package(format):
        pytest.skip(f"{format} is not installed, skipping test")

    sample = XYZDataset(path=[xyz_filelike], array_format=format)[0]

    array_fields = (
        "positions",
        "atomic_numbers",
        "charges",
        "energies",
        "forces",
        "e0",
        "formation_energies",
        "per_atom_formation_energies",
    )
    for field in array_fields:
        assert isinstance(getattr(sample, field), format_to_type[format])
Loading