Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch/jax Dataloader support #55

Merged
merged 17 commits into from
Apr 5, 2024
65 changes: 54 additions & 11 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import deepcopy
from itertools import compress
from os.path import join as p_join
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -39,10 +39,16 @@
set_cache_dir,
)
from openqdc.utils.molecule import atom_table, z_to_formula
from openqdc.utils.package_utils import requires_package
from openqdc.utils.package_utils import has_package, requires_package
from openqdc.utils.regressor import Regressor
from openqdc.utils.units import get_conversion

if has_package("torch"):
import torch

if has_package("jax"):
import jax.numpy as jnp


def _extract_entry(
df: pd.DataFrame,
Expand Down Expand Up @@ -110,9 +116,11 @@ def __init__(
self,
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
array_format: str = "numpy",
overwrite_local_cache: bool = False,
cache_dir: Optional[str] = None,
recompute_statistics: bool = False,
transform: Optional[Callable] = None,
regressor_kwargs={
"solver_type": "linear",
"sub_sample": None,
Expand All @@ -127,12 +135,16 @@ def __init__(
Energy unit to convert dataset to. Supported units: ["kcal/mol", "kj/mol", "hartree", "ev"]
distance_unit
Distance unit to convert dataset to. Supported units: ["ang", "nm", "bohr"]
array_format
Format to return arrays in. Supported formats: ["numpy", "torch", "jax"]
overwrite_local_cache
Whether to overwrite the locally cached dataset.
cache_dir
Cache directory location. Defaults to "~/.cache/openqdc"
recompute_statistics
Whether to recompute the statistics of the dataset.
transform, optional
transformation to apply to the __getitem__ calls
regressor_kwargs
Dictionary of keyword arguments to pass to the regressor.
Default: {"solver_type": "linear", "sub_sample": None, "stride": 1}
Expand All @@ -142,22 +154,25 @@ def __init__(
self.data = None
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
if not self.is_preprocessed():
raise DatasetNotAvailableError(self.__name__)
else:
self.read_preprocess(overwrite_local_cache=overwrite_local_cache)
self._post_init(overwrite_local_cache, energy_unit, distance_unit)
self._post_init(overwrite_local_cache, energy_unit, distance_unit, array_format)

def _post_init(
self,
overwrite_local_cache: bool = False,
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
array_format: Optional[str] = None,
) -> None:
self._set_units(None, None)
self._set_isolated_atom_energies()
self._precompute_statistics(overwrite_local_cache=overwrite_local_cache)
self._set_units(energy_unit, distance_unit)
self._set_array_format(array_format)
self._convert_data()
self._set_isolated_atom_energies()

Expand Down Expand Up @@ -393,6 +408,10 @@ def force_mask(self):
self.__class__.__force_mask__ = [False] * len(self.energy_methods)
return self.__class__.__force_mask__

def _set_array_format(self, format: str):
assert format in ["numpy", "torch", "jax"], f"Format {format} not supported."
self.array_format = format

def _set_units(self, en, ds):
old_en, old_ds = self.energy_unit, self.distance_unit
en = en if en is not None else old_en
Expand Down Expand Up @@ -732,31 +751,55 @@ def __smiles_converter__(self, x):
"""
return x

@requires_package("torch")
def _convert_to_torch(self, x: np.ndarray):
return torch.from_numpy(x)

@requires_package("jax")
def _convert_to_jax(self, x: np.ndarray):
return jnp.array(x)

def _convert_array(self, x: np.ndarray):
if self.array_format == "torch":
return self._convert_to_torch(x)
elif self.array_format == "jax":
return self._convert_to_jax(x)
return x

def __getitem__(self, idx: int):
shift = IsolatedAtomEnergyFactory.max_charge
p_start, p_end = self.data["position_idx_range"][idx]
input = self.data["atomic_inputs"][p_start:p_end]
z, c, positions, energies = (
np.array(input[:, 0], dtype=np.int32),
np.array(input[:, 1], dtype=np.int32),
np.array(input[:, -3:], dtype=np.float32),
np.array(self.data["energies"][idx], dtype=np.float32),
self._convert_array(np.array(input[:, 0], dtype=np.int32)),
self._convert_array(np.array(input[:, 1], dtype=np.int32)),
self._convert_array(np.array(input[:, -3:], dtype=np.float32)),
self._convert_array(np.array(self.data["energies"][idx], dtype=np.float32)),
)
name = self.__smiles_converter__(self.data["name"][idx])
subset = self.data["subset"][idx]

if "forces" in self.data:
forces = np.array(self.data["forces"][p_start:p_end], dtype=np.float32)
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))
else:
forces = None
return Bunch(

e0 = self._convert_array(self.__isolated_atom_energies__[..., z, c + shift].T)
linear_e0 = self._convert_array(self.new_e0s[..., z, c + shift].T) if hasattr(self, "new_e0s") else None

bunch = Bunch(
positions=positions,
atomic_numbers=z,
charges=c,
e0=self.__isolated_atom_energies__[..., z, c + shift].T,
linear_e0=self.new_e0s[..., z, c + shift].T if hasattr(self, "new_e0s") else None,
e0=e0,
linear_e0=linear_e0,
energies=energies,
name=name,
subset=subset,
forces=forces,
)

if self.transform is not None:
bunch = self.transform(bunch)

return bunch
2 changes: 1 addition & 1 deletion openqdc/datasets/interaction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def save_preprocess(self, data_dict):
for key in data_dict:
if key not in self.data_keys:
x = data_dict[key]
x[x == None] = -1
x[x == None] = -1 # noqa
data_dict[key] = np.unique(x, return_inverse=True)

with open(local_path, "wb") as f:
Expand Down
20 changes: 8 additions & 12 deletions openqdc/datasets/potential/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,15 @@ def _stats(self):
},
}

def __init__(
self,
energy_unit=None,
distance_unit=None,
cache_dir=None,
) -> None:
try:
super().__init__(energy_unit=energy_unit, distance_unit=distance_unit, cache_dir=cache_dir)

except: # noqa
pass
self._set_isolated_atom_energies()
def _post_init(self, overwrite_local_cache, energy_unit, distance_unit, array_format) -> None:
self.setup_dummy()
return super()._post_init(overwrite_local_cache, energy_unit, distance_unit, array_format)

def read_preprocess(self, overwrite_local_cache=False):
return

def _precompute_statistics(self, overwrite_local_cache=False):
return

def setup_dummy(self):
n_atoms = np.array([np.random.randint(1, 100) for _ in range(len(self))])
Expand Down
44 changes: 44 additions & 0 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
"""Path hack to make tests work."""

import numpy as np
import pytest

from openqdc.datasets.potential.dummy import Dummy # noqa: E402
from openqdc.utils.atomization_energies import (
ISOLATED_ATOM_ENERGIES,
IsolatedAtomEnergyFactory,
)
from openqdc.utils.package_utils import has_package

if has_package("torch"):
import torch

if has_package("jax"):
import jax

format_to_type = {
"numpy": np.ndarray,
"torch": torch.Tensor if torch else None,
"jax": jax.numpy.ndarray if jax else None,
}


def test_dummy():
Expand All @@ -19,3 +35,31 @@ def test_is_at_factory():
res = IsolatedAtomEnergyFactory.get("PM6")
assert len(res) == len(ISOLATED_ATOM_ENERGIES["pm6"])
assert isinstance(res[("H", 0)], float)


@pytest.mark.parametrize("format", ["numpy", "torch", "jax"])
def test_array_format(format):
if not has_package(format):
pytest.skip(f"{format} is not installed, skipping test")

ds = Dummy(array_format=format)

keys = ["positions", "atomic_numbers", "charges", "energies", "forces"]

data = ds[0]
for key in keys:
assert isinstance(data[key], format_to_type[format])


def test_transform():
def custom_fn(bunch):
# create new name
bunch.new_key = bunch.name + bunch.subset
return bunch

ds = Dummy(transform=custom_fn)

data = ds[0]

assert "new_key" in data
assert data["new_key"] == data["name"] + data["subset"]
Loading