From e2beaf6219486f7d46ea8349dc048e38e766344a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 3 Jun 2024 12:37:07 +0200 Subject: [PATCH 01/14] outline of PerBatchPaddedDataset --- apax/data/input_pipeline.py | 93 +++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 9d75b4fd..8939eb6e 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -332,6 +332,99 @@ def batch(self, sharding=None) -> Iterator[jax.Array]: return ds + + +class PerBatchPaddedDataset(InMemoryDataset): + def __init__( + self, + atoms_list, + cutoff, + bs, + n_epochs, + buffer_size=1000, + n_jit_steps=1, + pos_unit: str = "Ang", + energy_unit: str = "eV", + pre_shuffle=False, + ignore_labels=False, + cache_path=".", + ) -> None: + self.n_epochs = n_epochs + self.cutoff = cutoff + self.n_jit_steps = n_jit_steps + self.buffer_size = buffer_size + self.n_data = len(atoms_list) + self.batch_size = self.validate_batch_size(bs) + self.pos_unit = pos_unit + + if pre_shuffle: + shuffle(atoms_list) + self.sample_atoms = atoms_list[0] + self.inputs = atoms_to_inputs(atoms_list, pos_unit) + + # max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff) + # self.max_atoms = max_atoms + # self.max_nbrs = max_nbrs + if atoms_list[0].calc and not ignore_labels: + self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit) + else: + self.labels = None + + self.count = 0 + self.buffer = deque() + self.file = Path(cache_path) / str(uuid.uuid4()) + + self.enqueue(min(self.buffer_size, self.n_data)) + + + def prepare_data(self, i): + inputs = {k: v[i] for k, v in self.inputs.items()} + idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) + inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) + + zeros_to_add = self.max_atoms - inputs["numbers"].shape[0] + inputs["positions"] = np.pad( + inputs["positions"], ((0, zeros_to_add), (0, 0)), "constant" + ) + inputs["numbers"] = np.pad( + inputs["numbers"], (0, zeros_to_add), "constant" + ).astype(np.int16) + + if not self.labels: + return inputs + + labels = {k: v[i] for k, v in self.labels.items()} + if "forces" in labels: + labels["forces"] = np.pad( + labels["forces"], ((0, zeros_to_add), (0, 0)), "constant" + ) + inputs = {k: tf.constant(v) for k, v in inputs.items()} + labels = {k: tf.constant(v) for k, v in labels.items()} + return (inputs, labels) + + def enqueue(self, num_elements): + for _ in range(num_elements): + data = self.prepare_data(self.count) + self.buffer.append(data) + self.count += 1 + + def make_signature(self) -> None: + pass + + def __iter__(self): + raise NotImplementedError + + def shuffle_and_batch(self): + raise NotImplementedError + + def batch(self) -> Iterator[jax.Array]: + raise NotImplementedError + + + + + + dataset_dict = { "cached": CachedInMemoryDataset, "otf": OTFInMemoryDataset, From 80ca14df305c2aaf5e8380906d2c043db3b60b69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 19:47:26 +0200 Subject: [PATCH 02/14] added multiprocessing BatchProcessor class --- apax/config/train_config.py | 2 +- apax/data/input_pipeline.py | 192 +++++++++++++++++++++++++++--------- apax/utils/convert.py | 11 +++ 3 files changed, 156 insertions(+), 49 deletions(-) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index d58b61b2..63deda53 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ 
-59,7 +59,7 @@ class DataConfig(BaseModel, extra="forbid"): directory: str experiment: str - ds_type: Literal["cached", "otf"] = "cached" + ds_type: Literal["cached", "otf", "pbp"] = "cached" data_path: Optional[str] = None train_data_path: Optional[str] = None val_data_path: Optional[str] = None diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 8939eb6e..49da4c20 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -1,6 +1,7 @@ import logging import uuid from collections import deque +from concurrent.futures import ProcessPoolExecutor from pathlib import Path from random import shuffle from typing import Dict, Iterator @@ -11,7 +12,12 @@ import tensorflow as tf from apax.data.preprocessing import compute_nl, prefetch_to_single_device -from apax.utils.convert import atoms_to_inputs, atoms_to_labels, unit_dict +from apax.utils.convert import ( + atoms_to_inputs, + atoms_to_labels, + transpose_dict_of_lists, + unit_dict, +) log = logging.getLogger(__name__) @@ -332,6 +338,79 @@ def batch(self, sharding=None) -> Iterator[jax.Array]: return ds +def next_power_of_two(x): + return 1 << (int(x) - 1).bit_length() + + +class BatchProcessor: + def __init__(self, cutoff, forces=True, stress=False) -> None: + self.cutoff = cutoff + self.forces = forces + self.stress = stress + + def __call__(self, samples: list[dict]): + inputs = { + "numbers": [], + "n_atoms": [], + "positions": [], + "box": [], + "idx": [], + "offsets": [], + } + + labels = { + "energy": [], + } + + if self.forces: + labels["forces"] = [] + if self.stress: + labels["stress"] = [] + + for sample in samples: + inp, lab = sample + + inputs["numbers"].append(inp["numbers"]) + inputs["n_atoms"].append(inp["n_atoms"]) + inputs["positions"].append(inp["positions"]) + inputs["box"].append(inp["box"]) + idx, offsets = compute_nl(inp["positions"], inp["box"], self.cutoff) + inputs["idx"].append(idx) + inputs["offsets"].append(offsets) + + labels["energy"].append(lab["energy"]) + if self.forces: + labels["forces"].append(lab["forces"]) + if self.stress: + labels["stress"].append(lab["stress"]) + + max_atoms = np.max(inputs["n_atoms"]) + max_nbrs = np.max([idx.shape[1] for idx in inputs["idx"]]) + + max_atoms = next_power_of_two(max_atoms) + max_nbrs = next_power_of_two(max_nbrs) + + for i in range(len(inputs["n_atoms"])): + inputs["idx"][i], inputs["offsets"][i] = pad_nl( + inputs["idx"][i], inputs["offsets"][i], max_nbrs + ) + + zeros_to_add = max_atoms - inputs["numbers"][i].shape[0] + inputs["positions"][i] = np.pad( + inputs["positions"][i], ((0, zeros_to_add), (0, 0)), "constant" + ) + inputs["numbers"][i] = np.pad( + inputs["numbers"][i], (0, zeros_to_add), "constant" + ).astype(np.int16) + + if "forces" in labels: + labels["forces"][i] = np.pad( + labels["forces"][i], ((0, zeros_to_add), (0, 0)), "constant" + ) + + inputs = {k: np.array(v) for k, v in inputs.items()} + labels = {k: np.array(v) for k, v in labels.items()} + return inputs, labels class PerBatchPaddedDataset(InMemoryDataset): @@ -341,7 +420,7 @@ def __init__( cutoff, bs, n_epochs, - buffer_size=1000, + buffer_size=20, n_jit_steps=1, pos_unit: str = "Ang", energy_unit: str = "eV", @@ -349,83 +428,100 @@ def __init__( ignore_labels=False, cache_path=".", ) -> None: - self.n_epochs = n_epochs self.cutoff = cutoff + self.n_jit_steps = n_jit_steps - self.buffer_size = buffer_size + self.n_epochs = n_epochs self.n_data = len(atoms_list) self.batch_size = self.validate_batch_size(bs) self.pos_unit = pos_unit - if pre_shuffle: - 
shuffle(atoms_list) + self.buffer_size = buffer_size + self.batch_size = bs + self.sample_atoms = atoms_list[0] self.inputs = atoms_to_inputs(atoms_list, pos_unit) - # max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff) - # self.max_atoms = max_atoms - # self.max_nbrs = max_nbrs if atoms_list[0].calc and not ignore_labels: self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit) else: self.labels = None + label_keys = self.labels.keys() + forces = False + stress = False + if "forces" in label_keys: + forces = True + if "stress" in label_keys: + stress = True + self.prepare_batch = BatchProcessor(cutoff, forces, stress) + + self.data = list( + zip( + transpose_dict_of_lists(self.inputs), transpose_dict_of_lists(self.labels) + ) + ) + self.count = 0 + self.max_count = self.n_epochs * self.steps_per_epoch() self.buffer = deque() - self.file = Path(cache_path) / str(uuid.uuid4()) + self.n_workers = 10 + self.process_pool = ProcessPoolExecutor(self.n_workers) - self.enqueue(min(self.buffer_size, self.n_data)) - - - def prepare_data(self, i): - inputs = {k: v[i] for k, v in self.inputs.items()} - idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) - inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) - - zeros_to_add = self.max_atoms - inputs["numbers"].shape[0] - inputs["positions"] = np.pad( - inputs["positions"], ((0, zeros_to_add), (0, 0)), "constant" - ) - inputs["numbers"] = np.pad( - inputs["numbers"], (0, zeros_to_add), "constant" - ).astype(np.int16) + def enqueue(self, num_batches): + start = self.count * self.batch_size - if not self.labels: - return inputs + dataset_chunks = [ + self.data[start + self.batch_size * i : start + self.batch_size * (i + 1)] + for i in range(0, num_batches) + ] + for batch in self.process_pool.map(self.prepare_batch, dataset_chunks): + self.buffer.append(batch) - labels = {k: v[i] for k, v in self.labels.items()} - if "forces" in labels: - labels["forces"] = np.pad( - labels["forces"], ((0, zeros_to_add), (0, 0)), "constant" - ) - inputs = {k: tf.constant(v) for k, v in inputs.items()} - labels = {k: tf.constant(v) for k, v in labels.items()} - return (inputs, labels) + self.count += num_batches - def enqueue(self, num_elements): - for _ in range(num_elements): - data = self.prepare_data(self.count) - self.buffer.append(data) - self.count += 1 + def __iter__(self): + for n in range(self.n_epochs): + self.count = 0 + if self.should_shuffle: + shuffle(self.data) + self.buffer = deque() + self.enqueue(min(self.buffer_size, self.n_data // self.batch_size)) - def make_signature(self) -> None: - pass + for i in range(self.steps_per_epoch()): + batch = self.buffer.popleft() + yield batch - def __iter__(self): - raise NotImplementedError + current_buffer_len = len(self.buffer) + space = self.buffer_size - current_buffer_len - def shuffle_and_batch(self): - raise NotImplementedError + if space >= self.n_workers: + more_data = min(space, self.steps_per_epoch() - self.count) + more_data = max(more_data, 0) + if more_data > 0: + self.enqueue(more_data) - def batch(self) -> Iterator[jax.Array]: - raise NotImplementedError + def shuffle_and_batch(self, sharding): + self.should_shuffle = True + ds = prefetch_to_single_device( + iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1 + ) + return ds + def batch(self, sharding) -> Iterator[jax.Array]: + ds = prefetch_to_single_device( + iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1 + ) + return ds + def make_signature(self) -> None: + pass 
dataset_dict = { "cached": CachedInMemoryDataset, "otf": OTFInMemoryDataset, + "pbp": PerBatchPaddedDataset, } diff --git a/apax/utils/convert.py b/apax/utils/convert.py index a7d4cc45..af6e5335 100644 --- a/apax/utils/convert.py +++ b/apax/utils/convert.py @@ -153,3 +153,14 @@ def atoms_to_labels( labels = prune_dict(labels) return labels + + +def transpose_dict_of_lists(dict_of_lists: dict): + list_of_dicts = [] + keys = list(dict_of_lists.keys()) + + for i in range(len(dict_of_lists[keys[0]])): + data = {k: dict_of_lists[k][i] for k in keys} + list_of_dicts.append(data) + + return list_of_dicts From 1c534c671062fae253872a8f522b36c5d5c068b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 20:48:24 +0200 Subject: [PATCH 03/14] updated config file --- apax/cli/templates/train_config_full.yaml | 8 +++-- apax/config/train_config.py | 26 +++++++++++++-- apax/data/input_pipeline.py | 40 +++++++++++------------ apax/train/run.py | 10 +++--- 4 files changed, 54 insertions(+), 30 deletions(-) diff --git a/apax/cli/templates/train_config_full.yaml b/apax/cli/templates/train_config_full.yaml index 16b8db61..027e5c68 100644 --- a/apax/cli/templates/train_config_full.yaml +++ b/apax/cli/templates/train_config_full.yaml @@ -16,8 +16,12 @@ data: #train_data_path: #val_data_path: #test_data_path: + dataset: + name: cached + shuffle_buffer_size: 1000 + cache_path: "." + additional_properties_info: {} - ds_type: cached n_train: 1000 n_valid: 100 @@ -31,8 +35,6 @@ data: scale_method: "per_element_force_rms_scale" scale_options: {} - shuffle_buffer_size: 1000 - pos_unit: Ang energy_unit: eV diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 63deda53..1d97bdff 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -22,6 +22,26 @@ log = logging.getLogger(__name__) +class DatasetConfig(BaseModel, extra="forbid"): + name: str + + +class CachedDataset(DatasetConfig, extra="forbid"): + name: Literal["cached"] = "cached" + shuffle_buffer_size: PositiveInt = 1000 + cache_path: str = "." + + +class OTFDataset(DatasetConfig, extra="forbid"): + name: Literal["otf"] = "otf" + shuffle_buffer_size: PositiveInt = 1000 + + +class PBPDatset(DatasetConfig, extra="forbid"): + name: Literal["pbp"] = "pbp" + num_workers: PositiveInt = 10 + + class DataConfig(BaseModel, extra="forbid"): """ Configuration for data loading, preprocessing and training. 
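The config classes added in the hunk above are wired into `DataConfig` in the following hunk as a pydantic tagged union: each dataset class carries a `name` literal, and `Field(..., discriminator="name")` selects the matching class when the YAML is parsed. A minimal, self-contained sketch of that pattern (assuming pydantic v2; the classes below are illustrative stand-ins and are not imported from apax):

```python
from typing import Literal, Union

from pydantic import BaseModel, Field, PositiveInt


class CachedDataset(BaseModel, extra="forbid"):
    name: Literal["cached"] = "cached"
    shuffle_buffer_size: PositiveInt = 1000


class PBPDataset(BaseModel, extra="forbid"):
    name: Literal["pbp"] = "pbp"
    num_workers: PositiveInt = 10


class DataConfig(BaseModel, extra="forbid"):
    # The "name" literal acts as the discriminator that picks the branch.
    dataset: Union[CachedDataset, PBPDataset] = Field(
        CachedDataset(name="cached"), discriminator="name"
    )


# A YAML block like `dataset: {name: pbp, num_workers: 4}` parses to PBPDataset.
cfg = DataConfig.model_validate({"dataset": {"name": "pbp", "num_workers": 4}})
assert isinstance(cfg.dataset, PBPDataset)
assert cfg.dataset.num_workers == 4
```

With `extra="forbid"`, unknown keys inside the `dataset` block fail validation instead of being silently ignored.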
@@ -59,7 +79,10 @@ class DataConfig(BaseModel, extra="forbid"): directory: str experiment: str - ds_type: Literal["cached", "otf", "pbp"] = "cached" + dataset: Union[CachedDataset, OTFDataset, PBPDatset] = Field( + CachedDataset(name="cached"), discriminator="name" + ) + data_path: Optional[str] = None train_data_path: Optional[str] = None val_data_path: Optional[str] = None @@ -69,7 +92,6 @@ class DataConfig(BaseModel, extra="forbid"): n_valid: PositiveInt = 100 batch_size: PositiveInt = 32 valid_batch_size: PositiveInt = 100 - shuffle_buffer_size: PositiveInt = 1000 additional_properties_info: dict[str, str] = {} shift_method: str = "per_element_regression_shift" diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 49da4c20..35bcbe73 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -81,18 +81,18 @@ def __init__( cutoff, bs, n_epochs, - buffer_size=1000, n_jit_steps=1, pos_unit: str = "Ang", energy_unit: str = "eV", pre_shuffle=False, + shuffle_buffer_size=1000, ignore_labels=False, cache_path=".", ) -> None: self.n_epochs = n_epochs self.cutoff = cutoff self.n_jit_steps = n_jit_steps - self.buffer_size = buffer_size + self.buffer_size = shuffle_buffer_size self.n_data = len(atoms_list) self.batch_size = self.validate_batch_size(bs) self.pos_unit = pos_unit @@ -420,16 +420,18 @@ def __init__( cutoff, bs, n_epochs, - buffer_size=20, n_jit_steps=1, + buffer_size=20, + num_workers=10, pos_unit: str = "Ang", energy_unit: str = "eV", pre_shuffle=False, - ignore_labels=False, - cache_path=".", ) -> None: self.cutoff = cutoff + if n_jit_steps > 1: + raise "PerBatchPaddedDataset is not yet compatible with multi step jit" + self.n_jit_steps = n_jit_steps self.n_epochs = n_epochs self.n_data = len(atoms_list) @@ -442,19 +444,8 @@ def __init__( self.sample_atoms = atoms_list[0] self.inputs = atoms_to_inputs(atoms_list, pos_unit) - if atoms_list[0].calc and not ignore_labels: - self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit) - else: - self.labels = None - + self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit) label_keys = self.labels.keys() - forces = False - stress = False - if "forces" in label_keys: - forces = True - if "stress" in label_keys: - stress = True - self.prepare_batch = BatchProcessor(cutoff, forces, stress) self.data = list( zip( @@ -462,11 +453,15 @@ def __init__( ) ) + forces = "forces" in label_keys + stress = "stress" in label_keys + self.prepare_batch = BatchProcessor(cutoff, forces, stress) + self.count = 0 self.max_count = self.n_epochs * self.steps_per_epoch() self.buffer = deque() - self.n_workers = 10 - self.process_pool = ProcessPoolExecutor(self.n_workers) + self.num_workers = num_workers + self.process_pool = ProcessPoolExecutor(self.num_workers) def enqueue(self, num_batches): start = self.count * self.batch_size @@ -483,9 +478,11 @@ def enqueue(self, num_batches): def __iter__(self): for n in range(self.n_epochs): self.count = 0 + self.buffer = deque() + if self.should_shuffle: shuffle(self.data) - self.buffer = deque() + self.enqueue(min(self.buffer_size, self.n_data // self.batch_size)) for i in range(self.steps_per_epoch()): @@ -495,7 +492,7 @@ def __iter__(self): current_buffer_len = len(self.buffer) space = self.buffer_size - current_buffer_len - if space >= self.n_workers: + if space >= self.num_workers: more_data = min(space, self.steps_per_epoch() - self.count) more_data = max(more_data, 0) if more_data > 0: @@ -511,6 +508,7 @@ def shuffle_and_batch(self, sharding): return ds def 
batch(self, sharding) -> Iterator[jax.Array]: + self.should_shuffle = False ds = prefetch_to_single_device( iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1 ) diff --git a/apax/train/run.py b/apax/train/run.py index fec30fa7..4670b07f 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -99,19 +99,21 @@ def initialize_datasets(config: Config): train_raw_ds, val_raw_ds = load_data_files(config.data) - Dataset = dataset_dict[config.data.ds_type] + Dataset = dataset_dict[config.data.dataset.name] + + dataset_kwargs = dict(config.data.dataset) + dataset_kwargs.pop("name") train_ds = Dataset( train_raw_ds, config.model.r_max, config.data.batch_size, config.n_epochs, - config.data.shuffle_buffer_size, config.n_jitted_steps, pos_unit=config.data.pos_unit, energy_unit=config.data.energy_unit, pre_shuffle=True, - cache_path=config.data.model_version_path, + **dataset_kwargs, ) val_ds = Dataset( val_raw_ds, @@ -120,7 +122,7 @@ def initialize_datasets(config: Config): config.n_epochs, pos_unit=config.data.pos_unit, energy_unit=config.data.energy_unit, - cache_path=config.data.model_version_path, + **dataset_kwargs, ) ds_stats = compute_scale_shift_parameters( train_ds.inputs, From 7cc567e9026665f4cfbcab8e211612ec168d6208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 21:44:14 +0200 Subject: [PATCH 04/14] updated test configs --- tests/integration_tests/bal/config.yaml | 1 - tests/integration_tests/md/config.yaml | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/integration_tests/bal/config.yaml b/tests/integration_tests/bal/config.yaml index 60a9c0b9..98c058c1 100644 --- a/tests/integration_tests/bal/config.yaml +++ b/tests/integration_tests/bal/config.yaml @@ -9,7 +9,6 @@ data: n_valid: 2 batch_size: 2 valid_batch_size: 2 - shuffle_buffer_size: 4 model: nn: [32,32] diff --git a/tests/integration_tests/md/config.yaml b/tests/integration_tests/md/config.yaml index 2371a2fe..feef8938 100644 --- a/tests/integration_tests/md/config.yaml +++ b/tests/integration_tests/md/config.yaml @@ -8,7 +8,6 @@ data: n_valid: 2 batch_size: 2 valid_batch_size: 2 - shuffle_buffer_size: 4 model: nn: [32,32] From b6e09d6e89a6421015e9b871855bb41fa703d112 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 21:57:48 +0200 Subject: [PATCH 05/14] added doc strings to datastes --- apax/config/train_config.py | 41 +++++++++++++++++++++++++++++++++++++ apax/data/input_pipeline.py | 26 +++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 1d97bdff..2181071f 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -27,17 +27,58 @@ class DatasetConfig(BaseModel, extra="forbid"): class CachedDataset(DatasetConfig, extra="forbid"): + """Dataset which pads everything (atoms, neighbors) + to the largest system in the dataset. + The NL is computed on the fly during the first epoch and stored to disk using + tf.data's cache. + Most performant option for datasets with samples of very similar size. + + Parameters + ---------- + shuffle_buffer_size : int + | Size of the buffer that is shuffled by tf.data. + | Larger values require more RAM. + """ + name: Literal["cached"] = "cached" shuffle_buffer_size: PositiveInt = 1000 cache_path: str = "." class OTFDataset(DatasetConfig, extra="forbid"): + """Dataset which pads everything (atoms, neighbors) + to the largest system in the dataset. 
+ The NL is computed on the fly and fed into a tf.data generator. + Mostly for internal purposes. + + Parameters + ---------- + shuffle_buffer_size : int + | Size of the buffer that is shuffled by tf.data. + | Larger values require more RAM. + """ + name: Literal["otf"] = "otf" shuffle_buffer_size: PositiveInt = 1000 class PBPDatset(DatasetConfig, extra="forbid"): + """Dataset which pads everything (atoms, neighbors) + to the next larges power of two. + This limits the compute wasted due to padding at the (negligible) + cost of some recompilations. + The NL is computed on-the-fly in parallel for `num_workers` of batches. + Does not use tf.data. + + Most performant option for datasets with significantly differently sized systems + (e.g. MP, SPICE). + + Parameters + ---------- + num_workers : int + | Number of batches to be processed in parallel. + """ + name: Literal["pbp"] = "pbp" num_workers: PositiveInt = 10 diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 35bcbe73..0162521a 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -75,6 +75,8 @@ def find_largest_system(inputs, r_max) -> tuple[int]: class InMemoryDataset: + """Baseclass for all datasets which store data in memory.""" + def __init__( self, atoms_list, @@ -230,6 +232,13 @@ def cleanup(self): class CachedInMemoryDataset(InMemoryDataset): + """Dataset which pads everything (atoms, neighbors) + to the largest system in the dataset. + The NL is computed on the fly during the first epoch and stored to disk using + tf.data's cache. + Most performant option for datasets with samples of very similar size. + """ + def __iter__(self): while self.count < self.n_data or len(self.buffer) > 0: yield self.buffer.popleft() @@ -289,6 +298,12 @@ def cleanup(self): class OTFInMemoryDataset(InMemoryDataset): + """Dataset which pads everything (atoms, neighbors) + to the largest system in the dataset. + The NL is computed on the fly and fed into a tf.data generator. + Mostly for internal purposes. + """ + def __iter__(self): outer_count = 0 max_iter = self.n_data * self.n_epochs @@ -414,6 +429,17 @@ def __call__(self, samples: list[dict]): class PerBatchPaddedDataset(InMemoryDataset): + """Dataset which pads everything (atoms, neighbors) + to the next larges power of two. + This limits the compute wasted due to padding at the (negligible) + cost of some recompilations. + The NL is computed on-the-fly in parallel for `num_workers` of batches. + Does not use tf.data. + + Most performant option for datasets with significantly differently sized systems + (e.g. MP, SPICE). 
+ """ + def __init__( self, atoms_list, From bfb1152f4407b083171a2a0f52c369790f440dcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 22:05:45 +0200 Subject: [PATCH 06/14] added warning filter for os.fork() --- apax/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/apax/__init__.py b/apax/__init__.py index de4b9e6e..e3e81556 100644 --- a/apax/__init__.py +++ b/apax/__init__.py @@ -1,4 +1,5 @@ import os +import warnings import jax @@ -8,3 +9,5 @@ from apax.utils.helpers import setup_ase setup_ase() + +warnings.filterwarnings("ignore", message=".*os.fork()*") From 94db42220a153fb283301f0922a87441ce3e78da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 6 Jun 2024 09:38:02 +0200 Subject: [PATCH 07/14] removed cahce dir from input config --- apax/config/train_config.py | 1 - apax/train/run.py | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 2181071f..16084fce 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -42,7 +42,6 @@ class CachedDataset(DatasetConfig, extra="forbid"): name: Literal["cached"] = "cached" shuffle_buffer_size: PositiveInt = 1000 - cache_path: str = "." class OTFDataset(DatasetConfig, extra="forbid"): diff --git a/apax/train/run.py b/apax/train/run.py index 4670b07f..9067e943 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -102,7 +102,10 @@ def initialize_datasets(config: Config): Dataset = dataset_dict[config.data.dataset.name] dataset_kwargs = dict(config.data.dataset) - dataset_kwargs.pop("name") + name = dataset_kwargs.pop("name") + + if name == "cached": + dataset_kwargs["cache_path"] = config.data.model_version_path train_ds = Dataset( train_raw_ds, From 49054190342da2b7f411f15d1aa72ff82440ffa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 6 Jun 2024 14:27:54 +0200 Subject: [PATCH 08/14] removed cache_path from full config --- apax/cli/templates/train_config_full.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/apax/cli/templates/train_config_full.yaml b/apax/cli/templates/train_config_full.yaml index 027e5c68..8bf52430 100644 --- a/apax/cli/templates/train_config_full.yaml +++ b/apax/cli/templates/train_config_full.yaml @@ -19,7 +19,6 @@ data: dataset: name: cached shuffle_buffer_size: 1000 - cache_path: "." additional_properties_info: {} From c4b7bb29c50045bb88856d083468be55686d08f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 7 Jun 2024 10:07:35 +0200 Subject: [PATCH 09/14] made num workers default to num cpus in system --- apax/data/input_pipeline.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 0162521a..c9dc5789 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -2,9 +2,10 @@ import uuid from collections import deque from concurrent.futures import ProcessPoolExecutor +import multiprocessing from pathlib import Path from random import shuffle -from typing import Dict, Iterator +from typing import Dict, Iterator, Optional import jax import jax.numpy as jnp @@ -437,7 +438,7 @@ class PerBatchPaddedDataset(InMemoryDataset): Does not use tf.data. Most performant option for datasets with significantly differently sized systems - (e.g. MP, SPICE). + (e.g. MaterialsProject, SPICE). 
""" def __init__( @@ -447,8 +448,7 @@ def __init__( bs, n_epochs, n_jit_steps=1, - buffer_size=20, - num_workers=10, + num_workers: Optional[int]=None, pos_unit: str = "Ang", energy_unit: str = "eV", pre_shuffle=False, @@ -464,7 +464,11 @@ def __init__( self.batch_size = self.validate_batch_size(bs) self.pos_unit = pos_unit - self.buffer_size = buffer_size + if num_workers: + self.num_workers = num_workers + else: + self.num_workers = multiprocessing.cpu_count() + self.buffer_size = num_workers * 2 self.batch_size = bs self.sample_atoms = atoms_list[0] @@ -486,7 +490,7 @@ def __init__( self.count = 0 self.max_count = self.n_epochs * self.steps_per_epoch() self.buffer = deque() - self.num_workers = num_workers + self.process_pool = ProcessPoolExecutor(self.num_workers) def enqueue(self, num_batches): From 24f60d73280656fcd6889c9787434b50ab49c114 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 7 Jun 2024 10:11:35 +0200 Subject: [PATCH 10/14] set default n workers to num cpus --- apax/data/input_pipeline.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index c9dc5789..70aaca39 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -1,8 +1,8 @@ import logging +import multiprocessing import uuid from collections import deque from concurrent.futures import ProcessPoolExecutor -import multiprocessing from pathlib import Path from random import shuffle from typing import Dict, Iterator, Optional @@ -448,7 +448,7 @@ def __init__( bs, n_epochs, n_jit_steps=1, - num_workers: Optional[int]=None, + num_workers: Optional[int] = None, pos_unit: str = "Ang", energy_unit: str = "eV", pre_shuffle=False, @@ -456,7 +456,9 @@ def __init__( self.cutoff = cutoff if n_jit_steps > 1: - raise "PerBatchPaddedDataset is not yet compatible with multi step jit" + raise NotImplementedError( + "PerBatchPaddedDataset is not yet compatible with multi step jit" + ) self.n_jit_steps = n_jit_steps self.n_epochs = n_epochs @@ -490,7 +492,7 @@ def __init__( self.count = 0 self.max_count = self.n_epochs * self.steps_per_epoch() self.buffer = deque() - + self.process_pool = ProcessPoolExecutor(self.num_workers) def enqueue(self, num_batches): From a89e7287837b1eec68e6a3851ae1877b1ff3e9cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 7 Jun 2024 10:14:07 +0200 Subject: [PATCH 11/14] updated name key to processing --- apax/cli/templates/train_config_full.yaml | 2 +- apax/config/train_config.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/apax/cli/templates/train_config_full.yaml b/apax/cli/templates/train_config_full.yaml index 8bf52430..abc8c4e1 100644 --- a/apax/cli/templates/train_config_full.yaml +++ b/apax/cli/templates/train_config_full.yaml @@ -17,7 +17,7 @@ data: #val_data_path: #test_data_path: dataset: - name: cached + processing: cached shuffle_buffer_size: 1000 additional_properties_info: {} diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 16084fce..c94eb233 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -23,7 +23,7 @@ class DatasetConfig(BaseModel, extra="forbid"): - name: str + processing: str class CachedDataset(DatasetConfig, extra="forbid"): @@ -40,7 +40,7 @@ class CachedDataset(DatasetConfig, extra="forbid"): | Larger values require more RAM. 
""" - name: Literal["cached"] = "cached" + processing: Literal["cached"] = "cached" shuffle_buffer_size: PositiveInt = 1000 @@ -57,7 +57,7 @@ class OTFDataset(DatasetConfig, extra="forbid"): | Larger values require more RAM. """ - name: Literal["otf"] = "otf" + processing: Literal["otf"] = "otf" shuffle_buffer_size: PositiveInt = 1000 @@ -78,7 +78,7 @@ class PBPDatset(DatasetConfig, extra="forbid"): | Number of batches to be processed in parallel. """ - name: Literal["pbp"] = "pbp" + processing: Literal["pbp"] = "pbp" num_workers: PositiveInt = 10 @@ -120,7 +120,7 @@ class DataConfig(BaseModel, extra="forbid"): directory: str experiment: str dataset: Union[CachedDataset, OTFDataset, PBPDatset] = Field( - CachedDataset(name="cached"), discriminator="name" + CachedDataset(processing="cached"), discriminator="processing" ) data_path: Optional[str] = None From 93cd227570c0bd5bfeef50b52e526356ada302f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 7 Jun 2024 10:37:29 +0200 Subject: [PATCH 12/14] added unit test for transpose dict of lsit --- apax/train/run.py | 6 +++--- tests/unit_tests/utils/test_convert.py | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/apax/train/run.py b/apax/train/run.py index 9067e943..a5263a25 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -99,12 +99,12 @@ def initialize_datasets(config: Config): train_raw_ds, val_raw_ds = load_data_files(config.data) - Dataset = dataset_dict[config.data.dataset.name] + Dataset = dataset_dict[config.data.dataset.processing] dataset_kwargs = dict(config.data.dataset) - name = dataset_kwargs.pop("name") + processing = dataset_kwargs.pop("processing") - if name == "cached": + if processing == "cached": dataset_kwargs["cache_path"] = config.data.model_version_path train_ds = Dataset( diff --git a/tests/unit_tests/utils/test_convert.py b/tests/unit_tests/utils/test_convert.py index e69de29b..40d9714e 100644 --- a/tests/unit_tests/utils/test_convert.py +++ b/tests/unit_tests/utils/test_convert.py @@ -0,0 +1,21 @@ +import numpy as np + +from apax.utils.convert import transpose_dict_of_lists + +def test_transpose_dict_of_lists(): + + b = np.arange(8).reshape((4,2)) + a = [0,1,2,3] + inputs = { + "a": a, + "b": b, + } + + out = transpose_dict_of_lists(inputs) + assert len(out) == len(a) + + for ii, entry in enumerate(out): + assert "a" in entry.keys() + assert "b" in entry.keys() + assert entry["a"] == a[ii] + assert np.all(entry["b"] == b[ii]) From dcb44d63328a062daddf4fffd9a2c01410968a5d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jun 2024 08:47:23 +0000 Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unit_tests/utils/test_convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/utils/test_convert.py b/tests/unit_tests/utils/test_convert.py index 40d9714e..980be78a 100644 --- a/tests/unit_tests/utils/test_convert.py +++ b/tests/unit_tests/utils/test_convert.py @@ -3,7 +3,7 @@ from apax.utils.convert import transpose_dict_of_lists def test_transpose_dict_of_lists(): - + b = np.arange(8).reshape((4,2)) a = [0,1,2,3] inputs = { From d11dbfd61e32e8c106a2f7d4806bb4b2a3f091ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jun 2024 08:53:27 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes 
from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unit_tests/utils/test_convert.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/utils/test_convert.py b/tests/unit_tests/utils/test_convert.py index 980be78a..7c5f245c 100644 --- a/tests/unit_tests/utils/test_convert.py +++ b/tests/unit_tests/utils/test_convert.py @@ -2,10 +2,11 @@ from apax.utils.convert import transpose_dict_of_lists + def test_transpose_dict_of_lists(): - b = np.arange(8).reshape((4,2)) - a = [0,1,2,3] + b = np.arange(8).reshape((4, 2)) + a = [0, 1, 2, 3] inputs = { "a": a, "b": b,
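Taken together, the series adds a per-batch-padded alternative to the dataset-wide padding of `CachedInMemoryDataset`/`OTFInMemoryDataset`: `BatchProcessor` pads each batch only up to the next power of two of its own largest system and neighbor list, so small structures no longer carry the padding of the largest structure in the whole dataset. A rough, NumPy-only sketch of the atom-count part of that rule (the `pad_batch` helper below is illustrative, not part of apax; only `next_power_of_two` mirrors the patch):

```python
import numpy as np


def next_power_of_two(x):
    # Same rule as in the patch: smallest power of two >= x.
    return 1 << (int(x) - 1).bit_length()


def pad_batch(positions, numbers):
    """Pad a batch of variable-size systems to a shared power-of-two atom count."""
    max_atoms = next_power_of_two(max(len(z) for z in numbers))

    padded_pos, padded_num = [], []
    for pos, num in zip(positions, numbers):
        extra = max_atoms - len(num)
        padded_pos.append(np.pad(pos, ((0, extra), (0, 0)), "constant"))
        padded_num.append(np.pad(num, (0, extra), "constant").astype(np.int16))

    return np.stack(padded_pos), np.stack(padded_num)


# Two water-sized systems and one 11-atom system pad to 16 atoms,
# not to the largest system of the full dataset.
positions = [np.zeros((3, 3)), np.zeros((3, 3)), np.zeros((11, 3))]
numbers = [np.array([8, 1, 1]), np.array([8, 1, 1]), np.arange(1, 12)]
pos_b, num_b = pad_batch(positions, numbers)
assert pos_b.shape == (3, 16, 3) and num_b.shape == (3, 16)
```

Padding to powers of two also bounds how many distinct batch shapes have to be compiled, which is the "negligible cost of some recompilations" mentioned in the `PBPDatset` docstring.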