Per Batch Padded Dataset #281
@@ -22,6 +22,66 @@

log = logging.getLogger(__name__)


class DatasetConfig(BaseModel, extra="forbid"):
    name: str


class CachedDataset(DatasetConfig, extra="forbid"):
    """Dataset which pads everything (atoms, neighbors)
    to the largest system in the dataset.
    The NL is computed on the fly during the first epoch and stored to disk using
    tf.data's cache.
    Most performant option for datasets with samples of very similar size.

    Parameters
    ----------
    shuffle_buffer_size : int
        Size of the buffer that is shuffled by tf.data.
        Larger values require more RAM.
    """

    name: Literal["cached"] = "cached"
    shuffle_buffer_size: PositiveInt = 1000


class OTFDataset(DatasetConfig, extra="forbid"):
    """Dataset which pads everything (atoms, neighbors)
    to the largest system in the dataset.
    The NL is computed on the fly and fed into a tf.data generator.
    Mostly for internal purposes.

    Parameters
    ----------
    shuffle_buffer_size : int
        Size of the buffer that is shuffled by tf.data.
        Larger values require more RAM.
    """

    name: Literal["otf"] = "otf"
    shuffle_buffer_size: PositiveInt = 1000


class PBPDataset(DatasetConfig, extra="forbid"):
    """Dataset which pads everything (atoms, neighbors)
    to the next largest power of two.
    This limits the compute wasted due to padding at the (negligible)
    cost of some recompilations.
    The NL is computed on the fly in parallel for `num_workers` batches.
    Does not use tf.data.
Comment on lines +65 to +70:

Reviewer: I assume `PBP` is an abbreviation?
Author: The class name stands for PerBatchPadded. I guess I can just write it out. Same for Materials Project.
Reviewer: Actually, I don't see "PB" written anywhere.
    Most performant option for datasets with significantly differently sized systems
    (e.g. MP, SPICE).

    Parameters
    ----------
    num_workers : int
        Number of batches to be processed in parallel.
    """

    name: Literal["pbp"] = "pbp"
    num_workers: PositiveInt = 10


class DataConfig(BaseModel, extra="forbid"):
    """
    Configuration for data loading, preprocessing and training.
@@ -59,7 +119,10 @@ class DataConfig(BaseModel, extra="forbid"):

     directory: str
     experiment: str
-    ds_type: Literal["cached", "otf"] = "cached"
+    dataset: Union[CachedDataset, OTFDataset, PBPDataset] = Field(
+        CachedDataset(name="cached"), discriminator="name"
+    )

     data_path: Optional[str] = None
     train_data_path: Optional[str] = None
     val_data_path: Optional[str] = None
@@ -69,7 +132,6 @@ class DataConfig(BaseModel, extra="forbid"):
     n_valid: PositiveInt = 100
     batch_size: PositiveInt = 32
     valid_batch_size: PositiveInt = 100
-    shuffle_buffer_size: PositiveInt = 1000
     additional_properties_info: dict[str, str] = {}

     shift_method: str = "per_element_regression_shift"
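For context, a minimal, self-contained sketch of how this discriminated union resolves, assuming pydantic v2 semantics. Only the dataset-related fields are reproduced here; the rest of DataConfig is omitted.

from typing import Literal, Union

from pydantic import BaseModel, Field, PositiveInt


class DatasetConfig(BaseModel, extra="forbid"):
    name: str


class CachedDataset(DatasetConfig, extra="forbid"):
    name: Literal["cached"] = "cached"
    shuffle_buffer_size: PositiveInt = 1000


class PBPDataset(DatasetConfig, extra="forbid"):
    name: Literal["pbp"] = "pbp"
    num_workers: PositiveInt = 10


class DataConfig(BaseModel, extra="forbid"):
    dataset: Union[CachedDataset, PBPDataset] = Field(
        CachedDataset(name="cached"), discriminator="name"
    )


# The "name" key in the input dict selects the branch of the union:
cfg = DataConfig.model_validate({"dataset": {"name": "pbp", "num_workers": 4}})
assert isinstance(cfg.dataset, PBPDataset)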
@@ -1,6 +1,7 @@
 import logging
 import uuid
 from collections import deque
+from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
 from random import shuffle
 from typing import Dict, Iterator

@@ -11,7 +12,12 @@
 import tensorflow as tf

 from apax.data.preprocessing import compute_nl, prefetch_to_single_device
-from apax.utils.convert import atoms_to_inputs, atoms_to_labels, unit_dict
+from apax.utils.convert import (
+    atoms_to_inputs,
+    atoms_to_labels,
+    transpose_dict_of_lists,
+    unit_dict,
+)

 log = logging.getLogger(__name__)
@@ -69,24 +75,26 @@ def find_largest_system(inputs, r_max) -> tuple[int]:


 class InMemoryDataset:
     """Baseclass for all datasets which store data in memory."""

     def __init__(
         self,
         atoms_list,
         cutoff,
         bs,
         n_epochs,
-        buffer_size=1000,
         n_jit_steps=1,
         pos_unit: str = "Ang",
         energy_unit: str = "eV",
         pre_shuffle=False,
+        shuffle_buffer_size=1000,
         ignore_labels=False,
         cache_path=".",
     ) -> None:
         self.n_epochs = n_epochs
         self.cutoff = cutoff
         self.n_jit_steps = n_jit_steps
-        self.buffer_size = buffer_size
+        self.buffer_size = shuffle_buffer_size
         self.n_data = len(atoms_list)
         self.batch_size = self.validate_batch_size(bs)
         self.pos_unit = pos_unit
@@ -224,6 +232,13 @@ def cleanup(self):


 class CachedInMemoryDataset(InMemoryDataset):
+    """Dataset which pads everything (atoms, neighbors)
+    to the largest system in the dataset.
+    The NL is computed on the fly during the first epoch and stored to disk using
+    tf.data's cache.
+    Most performant option for datasets with samples of very similar size.
+    """
+
     def __iter__(self):
         while self.count < self.n_data or len(self.buffer) > 0:
             yield self.buffer.popleft()
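As background on the tf.data cache mentioned in this docstring, a standalone sketch (not apax code): the first pass over the dataset executes the mapped computation and writes each element to the cache file; later epochs are served from the cache instead of recomputing.

import tensorflow as tf

ds = tf.data.Dataset.range(5).map(lambda x: x * 2)
ds = ds.cache("/tmp/example_cache")  # first epoch computes and stores elements
for epoch in range(2):  # the second pass reads from the cache file
    _ = list(ds.as_numpy_iterator())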
@@ -283,6 +298,12 @@ def cleanup(self):


 class OTFInMemoryDataset(InMemoryDataset):
+    """Dataset which pads everything (atoms, neighbors)
+    to the largest system in the dataset.
+    The NL is computed on the fly and fed into a tf.data generator.
+    Mostly for internal purposes.
+    """
+
     def __iter__(self):
         outer_count = 0
         max_iter = self.n_data * self.n_epochs
@@ -332,7 +353,199 @@ def batch(self, sharding=None) -> Iterator[jax.Array]:
         return ds


def next_power_of_two(x):
    return 1 << (int(x) - 1).bit_length()
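A quick sanity check on the bit trick above, duplicated here for illustration: (x - 1).bit_length() is the number of bits needed to represent x - 1, so shifting 1 left by that amount yields the smallest power of two greater than or equal to x.

def next_power_of_two(x):
    return 1 << (int(x) - 1).bit_length()

assert next_power_of_two(1) == 1    # (0).bit_length() == 0, so 1 << 0
assert next_power_of_two(33) == 64  # (32).bit_length() == 6, so 1 << 6
assert next_power_of_two(64) == 64  # exact powers of two map to themselves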
class BatchProcessor:
    def __init__(self, cutoff, forces=True, stress=False) -> None:
        self.cutoff = cutoff
        self.forces = forces
        self.stress = stress

    def __call__(self, samples: list[dict]):
        inputs = {
            "numbers": [],
            "n_atoms": [],
            "positions": [],
            "box": [],
            "idx": [],
            "offsets": [],
        }

        labels = {
            "energy": [],
        }

        if self.forces:
            labels["forces"] = []
        if self.stress:
            labels["stress"] = []

        for sample in samples:
            inp, lab = sample

            inputs["numbers"].append(inp["numbers"])
            inputs["n_atoms"].append(inp["n_atoms"])
            inputs["positions"].append(inp["positions"])
            inputs["box"].append(inp["box"])
            idx, offsets = compute_nl(inp["positions"], inp["box"], self.cutoff)
            inputs["idx"].append(idx)
            inputs["offsets"].append(offsets)

            labels["energy"].append(lab["energy"])
            if self.forces:
                labels["forces"].append(lab["forces"])
            if self.stress:
                labels["stress"].append(lab["stress"])

        max_atoms = np.max(inputs["n_atoms"])
        max_nbrs = np.max([idx.shape[1] for idx in inputs["idx"]])

        max_atoms = next_power_of_two(max_atoms)
        max_nbrs = next_power_of_two(max_nbrs)

        for i in range(len(inputs["n_atoms"])):
            inputs["idx"][i], inputs["offsets"][i] = pad_nl(
                inputs["idx"][i], inputs["offsets"][i], max_nbrs
            )

            zeros_to_add = max_atoms - inputs["numbers"][i].shape[0]
            inputs["positions"][i] = np.pad(
                inputs["positions"][i], ((0, zeros_to_add), (0, 0)), "constant"
            )
            inputs["numbers"][i] = np.pad(
                inputs["numbers"][i], (0, zeros_to_add), "constant"
            ).astype(np.int16)

            if "forces" in labels:
                labels["forces"][i] = np.pad(
                    labels["forces"][i], ((0, zeros_to_add), (0, 0)), "constant"
                )

        inputs = {k: np.array(v) for k, v in inputs.items()}
        labels = {k: np.array(v) for k, v in labels.items()}
        return inputs, labels
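To make the padding arithmetic concrete: within one batch, every system is padded to the same power-of-two atom count, so recompilation only happens when a batch produces a padded size not seen before. A minimal numpy sketch of just that shape logic (the neighbor-list handling via compute_nl and pad_nl is left out; those are existing apax helpers not reproduced here):

import numpy as np

def next_power_of_two(x):
    return 1 << (int(x) - 1).bit_length()

# Two systems with 5 and 12 atoms; the whole batch is padded to 16 atoms.
positions = [np.zeros((5, 3)), np.zeros((12, 3))]
max_atoms = next_power_of_two(max(p.shape[0] for p in positions))  # -> 16

padded = [
    np.pad(p, ((0, max_atoms - p.shape[0]), (0, 0)), "constant") for p in positions
]
batch = np.stack(padded)  # shape (2, 16, 3); all arrays in a batch share one size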
class PerBatchPaddedDataset(InMemoryDataset):
    """Dataset which pads everything (atoms, neighbors)
    to the next largest power of two.
    This limits the compute wasted due to padding at the (negligible)
    cost of some recompilations.
    The NL is computed on the fly in parallel for `num_workers` batches.
    Does not use tf.data.

    Most performant option for datasets with significantly differently sized systems
    (e.g. MP, SPICE).
    """

    def __init__(
        self,
        atoms_list,
        cutoff,
        bs,
        n_epochs,
        n_jit_steps=1,
        buffer_size=20,
        num_workers=10,
Reviewer: Does it make sense to set a different default value here?
Author: Probably. I'll update the default.
        pos_unit: str = "Ang",
        energy_unit: str = "eV",
        pre_shuffle=False,
    ) -> None:
        self.cutoff = cutoff

        if n_jit_steps > 1:
            raise "PerBatchPaddedDataset is not yet compatible with multi step jit"
Reviewer: I'm not sure, but in general it's better to raise a concrete exception value here, e.g. a NotImplementedError.
Author: Yes, not doing so was not intended. Thanks for pointing it out.
        self.n_jit_steps = n_jit_steps
        self.n_epochs = n_epochs
        self.n_data = len(atoms_list)
        self.batch_size = self.validate_batch_size(bs)
        self.pos_unit = pos_unit

        self.buffer_size = buffer_size
        self.batch_size = bs

        self.sample_atoms = atoms_list[0]
        self.inputs = atoms_to_inputs(atoms_list, pos_unit)

        self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
        label_keys = self.labels.keys()

        self.data = list(
            zip(
                transpose_dict_of_lists(self.inputs), transpose_dict_of_lists(self.labels)
            )
        )

        forces = "forces" in label_keys
        stress = "stress" in label_keys
        self.prepare_batch = BatchProcessor(cutoff, forces, stress)

        self.count = 0
        self.max_count = self.n_epochs * self.steps_per_epoch()
        self.buffer = deque()
        self.num_workers = num_workers
        self.process_pool = ProcessPoolExecutor(self.num_workers)
    def enqueue(self, num_batches):
        start = self.count * self.batch_size

        dataset_chunks = [
            self.data[start + self.batch_size * i : start + self.batch_size * (i + 1)]
            for i in range(0, num_batches)
        ]
        for batch in self.process_pool.map(self.prepare_batch, dataset_chunks):
            self.buffer.append(batch)

        self.count += num_batches
    def __iter__(self):
        for n in range(self.n_epochs):
            self.count = 0
            self.buffer = deque()

            if self.should_shuffle:
                shuffle(self.data)

            self.enqueue(min(self.buffer_size, self.n_data // self.batch_size))

            for i in range(self.steps_per_epoch()):
                batch = self.buffer.popleft()
                yield batch

                current_buffer_len = len(self.buffer)
                space = self.buffer_size - current_buffer_len

                if space >= self.num_workers:
                    more_data = min(space, self.steps_per_epoch() - self.count)
                    more_data = max(more_data, 0)
                    if more_data > 0:
                        self.enqueue(more_data)
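The loop above is a bounded producer/consumer queue: the buffer is filled up to buffer_size batches before the epoch, one batch is consumed per step, and refills are dispatched only once at least num_workers slots are free, so the process pool always receives reasonably sized chunks of work. A minimal sketch of that refill policy, with batches replaced by integers and illustrative sizes:

from collections import deque

# Stand-ins for self.buffer_size, self.num_workers, and steps_per_epoch()
buffer_size, num_workers, steps_per_epoch = 20, 10, 100
buffer, count = deque(), 0

def enqueue(n):
    # Stands in for process_pool.map(self.prepare_batch, ...): produce n batches
    global count
    buffer.extend(range(count, count + n))
    count += n

enqueue(min(buffer_size, steps_per_epoch))  # fill the buffer before the epoch
for step in range(steps_per_epoch):
    batch = buffer.popleft()  # consume one batch per training step
    space = buffer_size - len(buffer)
    if space >= num_workers:  # only refill in chunks worth dispatching
        enqueue(max(min(space, steps_per_epoch - count), 0))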
    def shuffle_and_batch(self, sharding):
        self.should_shuffle = True

        ds = prefetch_to_single_device(
            iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1
        )

        return ds

    def batch(self, sharding) -> Iterator[jax.Array]:
        self.should_shuffle = False
        ds = prefetch_to_single_device(
            iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1
        )
        return ds

    def make_signature(self) -> None:
        pass


dataset_dict = {
    "cached": CachedInMemoryDataset,
    "otf": OTFInMemoryDataset,
    "pbp": PerBatchPaddedDataset,
}
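A hypothetical sketch of how this registry could be wired to the new config field. The initialize_dataset helper and the exact config attributes (config.data, config.model.r_max, config.n_epochs) are assumptions for illustration, not code from this PR:

def initialize_dataset(config, atoms_list):
    # Hypothetical helper: look up the dataset class via the discriminator field
    dataset_config = config.data.dataset
    cls = dataset_dict[dataset_config.name]
    # Forward the variant-specific options, e.g. num_workers for "pbp"
    kwargs = dataset_config.model_dump(exclude={"name"})
    return cls(
        atoms_list,
        cutoff=config.model.r_max,  # assumed attribute names
        bs=config.data.batch_size,
        n_epochs=config.n_epochs,
        **kwargs,
    )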
Comment on the "name" discriminator field:

Reviewer: I think I'd prefer `type` over `name`, but I see why you don't want to use that as an attribute. Maybe `method` or `processing`?
Author: I used `name` to be consistent with the other uses of discriminated unions in our config. I think `method` would be fine, but we should only use one keyword for this purpose.
Author: I'll change it to `processing`.
Reviewer: My reasoning against `name` was that `dataset:name` reminds me of `train`, `test`, or `SPICE` etc.
Author: Ah yeah, I can see how this might be confusing. Thanks, I changed it 👍