Skip to content

Commit

Permalink
Added n_body count check
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikhil Shenoy committed Apr 2, 2024
1 parent ad72d1a commit 187b6fc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
4 changes: 2 additions & 2 deletions openqdc/datasets/interaction/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __len__(self):
return 9999


class NBodyDummy(DummyInteraction):
class NBodyDummyInteraction(DummyInteraction):
"""Dummy Interaction Dataset with N-body interactions
Note: we sample N for N-body from 3 to 5 randomly.
Expand All @@ -102,7 +102,7 @@ class NBodyDummy(DummyInteraction):
def setup_dummy(self):
super().setup_dummy()
data = self.data
n_body = np.random.randint(3, 5) # choose > 2 since default assumes 2
n_body = 5 # random value
n_atoms = data["n_atoms"]
data.update(
{
Expand Down
20 changes: 17 additions & 3 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import pytest

from openqdc.datasets.interaction.dummy import DummyInteraction # noqa: E402
from openqdc.datasets.interaction.dummy import ( # noqa: E402
DummyInteraction,
NBodyDummyInteraction,
)
from openqdc.datasets.potential.dummy import Dummy # noqa: E402
from openqdc.utils.atomization_energies import (
ISOLATED_ATOM_ENERGIES,
Expand All @@ -20,7 +23,12 @@ def dummy_interaction():
return DummyInteraction()


@pytest.mark.parametrize("cls", ["dummy", "dummy_interaction"])
@pytest.fixture
def n_body_dummy_interaction():
return NBodyDummyInteraction()


@pytest.mark.parametrize("cls", ["dummy", "dummy_interaction", "n_body_dummy_interaction"])
def test_basic(cls, request):
# init
ds = request.getfixturevalue(cls)
Expand All @@ -32,7 +40,7 @@ def test_basic(cls, request):
assert ds[0]


@pytest.mark.parametrize("cls", ["dummy", "dummy_interaction"])
@pytest.mark.parametrize("cls", ["dummy", "dummy_interaction", "n_body_dummy_interaction"])
@pytest.mark.parametrize(
"normalization",
[
Expand All @@ -50,6 +58,12 @@ def test_stats(cls, normalization, request):
assert stats is not None


def test_n_atoms_first(n_body_dummy_interaction):
item = n_body_dummy_interaction[0]
# shaped (1, n_body - 1)
assert item["n_atoms_first"].shape[1] == 4


def test_isolated_atom_factory():
res = IsolatedAtomEnergyFactory.get("mp2/cc-pvdz")
assert len(res) == len(ISOLATED_ATOM_ENERGIES["mp2"]["cc-pvdz"])
Expand Down

0 comments on commit 187b6fc

Please sign in to comment.