Skip to content

Commit

Permalink
update default data shapes for interaction datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
mcneela committed Mar 7, 2024
1 parent f046ea9 commit 3c84ee9
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/openqdc/datasets/interaction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
set_cache_dir,
)
from openqdc.datasets.potential.base import BaseDataset
from openqdc.utils.constants import (
NB_ATOMIC_FEATURES
)

from loguru import logger

Expand Down Expand Up @@ -43,3 +46,12 @@ def collate_list(self, list_entries: List[Dict]):
res["position_idx_range"] = x

return res

@property
def data_shapes(self):
return {
"atomic_inputs": (-1, NB_ATOMIC_FEATURES),
"position_idx_range": (-1, 2),
"energies": (-1, len(self.__energy_methods__)),
"forces": (-1, 3, len(self.force_target_names)),
}

0 comments on commit 3c84ee9

Please sign in to comment.