Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Per Batch Padded Dataset #281

Merged
merged 15 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions apax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings

import jax

Expand All @@ -8,3 +9,5 @@
from apax.utils.helpers import setup_ase

setup_ase()

warnings.filterwarnings("ignore", message=".*os.fork()*")
7 changes: 4 additions & 3 deletions apax/cli/templates/train_config_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ data:
#train_data_path: <PATH>
#val_data_path: <PATH>
#test_data_path: <PATH>
dataset:
processing: cached
shuffle_buffer_size: 1000

additional_properties_info: {}
ds_type: cached

n_train: 1000
n_valid: 100
Expand All @@ -31,8 +34,6 @@ data:
scale_method: "per_element_force_rms_scale"
scale_options: {}

shuffle_buffer_size: 1000

pos_unit: Ang
energy_unit: eV

Expand Down
66 changes: 64 additions & 2 deletions apax/config/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,66 @@
log = logging.getLogger(__name__)


class DatasetConfig(BaseModel, extra="forbid"):
    """Base config shared by all dataset input pipelines.

    `processing` is the discriminator field: each concrete subclass pins it
    to a `Literal` tag (e.g. "cached", "otf", "pbp") so pydantic can select
    the right variant when parsing the `dataset` union in `DataConfig`.
    """

    # Discriminator tag; overridden with a Literal default in each subclass.
    processing: str


class CachedDataset(DatasetConfig, extra="forbid"):
    """Input pipeline that pads all samples (atoms and neighbors) up to the
    largest system present in the dataset.

    Neighbor lists are built on the fly during the first epoch and persisted
    to disk via tf.data's cache, so later epochs read precomputed data.
    This is the fastest choice when all samples are of roughly equal size.

    Parameters
    ----------
    shuffle_buffer_size : int
        | Size of the buffer that is shuffled by tf.data.
        | Larger values require more RAM.
    """

    processing: Literal["cached"] = "cached"
    shuffle_buffer_size: PositiveInt = 1000


class OTFDataset(DatasetConfig, extra="forbid"):
    """Input pipeline that pads all samples (atoms and neighbors) up to the
    largest system present in the dataset.

    Neighbor lists are computed on the fly and streamed through a tf.data
    generator; nothing is cached. Intended mainly for internal use.

    Parameters
    ----------
    shuffle_buffer_size : int
        | Size of the buffer that is shuffled by tf.data.
        | Larger values require more RAM.
    """

    processing: Literal["otf"] = "otf"
    shuffle_buffer_size: PositiveInt = 1000


class PBPDatset(DatasetConfig, extra="forbid"):
    """Per-Batch Padded ("PBP") dataset.

    Pads each batch (atoms, neighbors) to the next largest power of two,
    rather than to the largest system in the whole dataset.
    This limits the compute wasted due to padding at the (negligible)
    cost of some recompilations.
    The NL is computed on-the-fly in parallel for `num_workers` batches.
    Does not use tf.data.

    Most performant option for datasets with significantly differently sized
    systems (e.g. the Materials Project, SPICE).

    Parameters
    ----------
    num_workers : int
        | Number of batches to be processed in parallel.
    """

    processing: Literal["pbp"] = "pbp"
    num_workers: PositiveInt = 10


class DataConfig(BaseModel, extra="forbid"):
"""
Configuration for data loading, preprocessing and training.
Expand Down Expand Up @@ -59,7 +119,10 @@ class DataConfig(BaseModel, extra="forbid"):

directory: str
experiment: str
ds_type: Literal["cached", "otf"] = "cached"
dataset: Union[CachedDataset, OTFDataset, PBPDatset] = Field(
CachedDataset(processing="cached"), discriminator="processing"
)

data_path: Optional[str] = None
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
Expand All @@ -69,7 +132,6 @@ class DataConfig(BaseModel, extra="forbid"):
n_valid: PositiveInt = 100
batch_size: PositiveInt = 32
valid_batch_size: PositiveInt = 100
shuffle_buffer_size: PositiveInt = 1000
additional_properties_info: dict[str, str] = {}

shift_method: str = "per_element_regression_shift"
Expand Down
Loading
Loading