Per Batch Padded Dataset #281

Merged
merged 15 commits into from
Jun 7, 2024
Changes from 8 commits
3 changes: 3 additions & 0 deletions apax/__init__.py
@@ -1,4 +1,5 @@
import os
import warnings

import jax

@@ -8,3 +9,5 @@
from apax.utils.helpers import setup_ase

setup_ase()

warnings.filterwarnings("ignore", message=".*os.fork()*")
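# the filter above presumably silences JAX's os.fork() warning that the multiprocessing-based loader added in this PR would otherwise trigger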
7 changes: 4 additions & 3 deletions apax/cli/templates/train_config_full.yaml
@@ -16,8 +16,11 @@ data:
#train_data_path: <PATH>
#val_data_path: <PATH>
#test_data_path: <PATH>
dataset:
name: cached
Contributor

I think I'd prefer type over name. But I see why you don't want to use that as an attribute. Maybe method or processing?

Contributor Author

I used name to be consistent with the other uses of discriminated unions in our config.
I think method would be fine, but we should only use one keyword for this purpose.

Contributor Author

I'll change it to processing

Contributor

My reasoning against name was that dataset:name reminds me of train, test, or SPICE, etc.

Contributor Author

Ah yeah, I can see how this might be confusing. Thanks, I changed it 👍

shuffle_buffer_size: 1000

additional_properties_info: {}
ds_type: cached

n_train: 1000
n_valid: 100
@@ -31,8 +34,6 @@ data:
scale_method: "per_element_force_rms_scale"
scale_options: {}

shuffle_buffer_size: 1000

pos_unit: Ang
energy_unit: eV

66 changes: 64 additions & 2 deletions apax/config/train_config.py
@@ -22,6 +22,66 @@
log = logging.getLogger(__name__)


class DatasetConfig(BaseModel, extra="forbid"):
name: str


class CachedDataset(DatasetConfig, extra="forbid"):
"""Dataset which pads everything (atoms, neighbors)
to the largest system in the dataset.
The NL is computed on the fly during the first epoch and stored to disk using
tf.data's cache.
Most performant option for datasets with samples of very similar size.

Parameters
----------
shuffle_buffer_size : int
| Size of the buffer that is shuffled by tf.data.
| Larger values require more RAM.
"""

name: Literal["cached"] = "cached"
shuffle_buffer_size: PositiveInt = 1000


class OTFDataset(DatasetConfig, extra="forbid"):
"""Dataset which pads everything (atoms, neighbors)
to the largest system in the dataset.
The NL is computed on the fly and fed into a tf.data generator.
Mostly for internal purposes.

Parameters
----------
shuffle_buffer_size : int
| Size of the buffer that is shuffled by tf.data.
| Larger values require more RAM.
"""

name: Literal["otf"] = "otf"
shuffle_buffer_size: PositiveInt = 1000


class PBPDatset(DatasetConfig, extra="forbid"):
"""Dataset which pads everything (atoms, neighbors)
to the next largest power of two.
This limits the compute wasted due to padding at the (negligible)
cost of some recompilations.
The NL is computed on the fly in parallel for `num_workers` batches at a time.
Does not use tf.data.
Comment on lines +65 to +70
Contributor

I assume PB stands for parallel batches. Maybe mention that once somewhere in the docstring, so that it is clear. I would also not write MP but Materials Project.

Contributor Author

The class name stands for PerBatchPadded. I guess I can just write it out. Same for Materials Project.

Contributor Author

Actually, I don't see "PB" written anywhere.


Most performant option for datasets with significantly differently sized systems
(e.g. MP, SPICE).

Parameters
----------
num_workers : int
| Number of batches to be processed in parallel.
"""

name: Literal["pbp"] = "pbp"
num_workers: PositiveInt = 10


class DataConfig(BaseModel, extra="forbid"):
"""
Configuration for data loading, preprocessing and training.
@@ -59,7 +119,10 @@ class DataConfig(BaseModel, extra="forbid"):

directory: str
experiment: str
ds_type: Literal["cached", "otf"] = "cached"
dataset: Union[CachedDataset, OTFDataset, PBPDatset] = Field(
CachedDataset(name="cached"), discriminator="name"
)

data_path: Optional[str] = None
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
@@ -69,7 +132,6 @@ class DataConfig(BaseModel, extra="forbid"):
n_valid: PositiveInt = 100
batch_size: PositiveInt = 32
valid_batch_size: PositiveInt = 100
shuffle_buffer_size: PositiveInt = 1000
additional_properties_info: dict[str, str] = {}

shift_method: str = "per_element_regression_shift"
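For reference, a minimal standalone sketch of how the name discriminator selects the concrete dataset model during config parsing. These are toy Pydantic v2 models mirroring the classes in the diff above, not the actual apax config:

from typing import Literal, Union

from pydantic import BaseModel, Field, PositiveInt


class CachedDataset(BaseModel, extra="forbid"):
    name: Literal["cached"] = "cached"
    shuffle_buffer_size: PositiveInt = 1000


class PBPDatset(BaseModel, extra="forbid"):
    name: Literal["pbp"] = "pbp"
    num_workers: PositiveInt = 10


class DataConfig(BaseModel, extra="forbid"):
    # the "name" field decides which dataset model is instantiated
    dataset: Union[CachedDataset, PBPDatset] = Field(
        CachedDataset(name="cached"), discriminator="name"
    )


cfg = DataConfig.model_validate({"dataset": {"name": "pbp", "num_workers": 4}})
assert isinstance(cfg.dataset, PBPDatset)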
219 changes: 216 additions & 3 deletions apax/data/input_pipeline.py
@@ -1,6 +1,7 @@
import logging
import uuid
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from random import shuffle
from typing import Dict, Iterator
@@ -11,7 +12,12 @@
import tensorflow as tf

from apax.data.preprocessing import compute_nl, prefetch_to_single_device
from apax.utils.convert import atoms_to_inputs, atoms_to_labels, unit_dict
from apax.utils.convert import (
atoms_to_inputs,
atoms_to_labels,
transpose_dict_of_lists,
unit_dict,
)

log = logging.getLogger(__name__)

@@ -69,24 +75,26 @@ def find_largest_system(inputs, r_max) -> tuple[int]:


class InMemoryDataset:
"""Baseclass for all datasets which store data in memory."""

def __init__(
self,
atoms_list,
cutoff,
bs,
n_epochs,
buffer_size=1000,
n_jit_steps=1,
pos_unit: str = "Ang",
energy_unit: str = "eV",
pre_shuffle=False,
shuffle_buffer_size=1000,
ignore_labels=False,
cache_path=".",
) -> None:
self.n_epochs = n_epochs
self.cutoff = cutoff
self.n_jit_steps = n_jit_steps
self.buffer_size = buffer_size
self.buffer_size = shuffle_buffer_size
self.n_data = len(atoms_list)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit
@@ -224,6 +232,13 @@ def cleanup(self):


class CachedInMemoryDataset(InMemoryDataset):
"""Dataset which pads everything (atoms, neighbors)
to the largest system in the dataset.
The NL is computed on the fly during the first epoch and stored to disk using
tf.data's cache.
Most performant option for datasets with samples of very similar size.
"""

def __iter__(self):
while self.count < self.n_data or len(self.buffer) > 0:
yield self.buffer.popleft()
@@ -283,6 +298,12 @@ def cleanup(self):


class OTFInMemoryDataset(InMemoryDataset):
"""Dataset which pads everything (atoms, neighbors)
to the largest system in the dataset.
The NL is computed on the fly and fed into a tf.data generator.
Mostly for internal purposes.
"""

def __iter__(self):
outer_count = 0
max_iter = self.n_data * self.n_epochs
@@ -332,7 +353,199 @@ def batch(self, sharding=None) -> Iterator[jax.Array]:
return ds


def next_power_of_two(x):
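# smallest power of two that is >= x, e.g. next_power_of_two(5) -> 8, next_power_of_two(8) -> 8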
return 1 << (int(x) - 1).bit_length()


class BatchProcessor:
def __init__(self, cutoff, forces=True, stress=False) -> None:
self.cutoff = cutoff
self.forces = forces
self.stress = stress

def __call__(self, samples: list[dict]):
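# collate a list of (inputs, labels) samples into one batch: compute each neighbor list, then pad to a common power-of-two size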
inputs = {
"numbers": [],
"n_atoms": [],
"positions": [],
"box": [],
"idx": [],
"offsets": [],
}

labels = {
"energy": [],
}

if self.forces:
labels["forces"] = []
if self.stress:
labels["stress"] = []

for sample in samples:
inp, lab = sample

inputs["numbers"].append(inp["numbers"])
inputs["n_atoms"].append(inp["n_atoms"])
inputs["positions"].append(inp["positions"])
inputs["box"].append(inp["box"])
idx, offsets = compute_nl(inp["positions"], inp["box"], self.cutoff)
inputs["idx"].append(idx)
inputs["offsets"].append(offsets)

labels["energy"].append(lab["energy"])
if self.forces:
labels["forces"].append(lab["forces"])
if self.stress:
labels["stress"].append(lab["stress"])

max_atoms = np.max(inputs["n_atoms"])
max_nbrs = np.max([idx.shape[1] for idx in inputs["idx"]])

max_atoms = next_power_of_two(max_atoms)
max_nbrs = next_power_of_two(max_nbrs)
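# e.g. a batch whose largest system has 23 atoms is padded to 32 atoms, so only a few distinct shapes ever reach jit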

for i in range(len(inputs["n_atoms"])):
inputs["idx"][i], inputs["offsets"][i] = pad_nl(
inputs["idx"][i], inputs["offsets"][i], max_nbrs
)

zeros_to_add = max_atoms - inputs["numbers"][i].shape[0]
inputs["positions"][i] = np.pad(
inputs["positions"][i], ((0, zeros_to_add), (0, 0)), "constant"
)
inputs["numbers"][i] = np.pad(
inputs["numbers"][i], (0, zeros_to_add), "constant"
).astype(np.int16)

if "forces" in labels:
labels["forces"][i] = np.pad(
labels["forces"][i], ((0, zeros_to_add), (0, 0)), "constant"
)

inputs = {k: np.array(v) for k, v in inputs.items()}
labels = {k: np.array(v) for k, v in labels.items()}
return inputs, labels


class PerBatchPaddedDataset(InMemoryDataset):
"""Dataset which pads everything (atoms, neighbors)
to the next largest power of two.
This limits the compute wasted due to padding at the (negligible)
cost of some recompilations.
The NL is computed on the fly in parallel for `num_workers` batches at a time.
Does not use tf.data.

Most performant option for datasets with significantly differently sized systems
(e.g. MP, SPICE).
"""

def __init__(
self,
atoms_list,
cutoff,
bs,
n_epochs,
n_jit_steps=1,
buffer_size=20,
num_workers=10,
Contributor

Does it make sense to set num_workers to None and get the default from the number of available cores?

Contributor Author

Probably. I'll update the default.
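A possible sketch of that default, assuming the attribute may be passed as None (the helper name here is made up for illustration):

import os


def resolve_num_workers(num_workers=None):
    # hypothetical helper: fall back to the number of available cores
    return num_workers if num_workers is not None else os.cpu_count()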

pos_unit: str = "Ang",
energy_unit: str = "eV",
pre_shuffle=False,
) -> None:
self.cutoff = cutoff

if n_jit_steps > 1:
raise "PerBatchPaddedDataset is not yet compatible with multi step jit"
Contributor

I'm not sure, but in general it's better to raise a concrete exception, like raise TypeError(msg...) here, isn't it?

Contributor Author

Yes, not doing so was not intended. Thanks for pointing it out.
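For illustration, one way the guard could read (NotImplementedError is picked here only as an example of a concrete exception type):

if n_jit_steps > 1:
    raise NotImplementedError(
        "PerBatchPaddedDataset is not yet compatible with multi step jit"
    )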


self.n_jit_steps = n_jit_steps
self.n_epochs = n_epochs
self.n_data = len(atoms_list)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit

self.buffer_size = buffer_size
self.batch_size = bs

self.sample_atoms = atoms_list[0]
self.inputs = atoms_to_inputs(atoms_list, pos_unit)

self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
label_keys = self.labels.keys()

self.data = list(
zip(
transpose_dict_of_lists(self.inputs), transpose_dict_of_lists(self.labels)
)
)

forces = "forces" in label_keys
stress = "stress" in label_keys
self.prepare_batch = BatchProcessor(cutoff, forces, stress)

self.count = 0
self.max_count = self.n_epochs * self.steps_per_epoch()
self.buffer = deque()
self.num_workers = num_workers
self.process_pool = ProcessPoolExecutor(self.num_workers)

def enqueue(self, num_batches):
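# slice the next num_batches batches out of self.data and compute their neighbor lists in parallel worker processes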
start = self.count * self.batch_size

dataset_chunks = [
self.data[start + self.batch_size * i : start + self.batch_size * (i + 1)]
for i in range(0, num_batches)
]
for batch in self.process_pool.map(self.prepare_batch, dataset_chunks):
self.buffer.append(batch)

self.count += num_batches

def __iter__(self):
for n in range(self.n_epochs):
self.count = 0
self.buffer = deque()

if self.should_shuffle:
shuffle(self.data)

self.enqueue(min(self.buffer_size, self.n_data // self.batch_size))

for i in range(self.steps_per_epoch()):
batch = self.buffer.popleft()
yield batch

current_buffer_len = len(self.buffer)
space = self.buffer_size - current_buffer_len
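# refill the buffer once there is room for at least num_workers more batches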

if space >= self.num_workers:
more_data = min(space, self.steps_per_epoch() - self.count)
more_data = max(more_data, 0)
if more_data > 0:
self.enqueue(more_data)

def shuffle_and_batch(self, sharding):
self.should_shuffle = True

ds = prefetch_to_single_device(
iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1
)

return ds

def batch(self, sharding) -> Iterator[jax.Array]:
self.should_shuffle = False
ds = prefetch_to_single_device(
iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds

def make_signature(self) -> None:
pass


dataset_dict = {
"cached": CachedInMemoryDataset,
"otf": OTFInMemoryDataset,
"pbp": PerBatchPaddedDataset,
}
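As a toy NumPy illustration of the per-batch padding strategy used above (not the apax API):

import numpy as np


def next_power_of_two(x):
    return 1 << (int(x) - 1).bit_length()


# three systems of different sizes collected into one batch
positions = [np.zeros((n, 3)) for n in (17, 23, 9)]
max_atoms = next_power_of_two(max(p.shape[0] for p in positions))  # -> 32

# pad every system to the batch-wide power-of-two atom count
padded = np.stack(
    [np.pad(p, ((0, max_atoms - p.shape[0]), (0, 0)), "constant") for p in positions]
)
print(padded.shape)  # (3, 32, 3)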