diff --git a/apax/__init__.py b/apax/__init__.py
index de4b9e6e..e3e81556 100644
--- a/apax/__init__.py
+++ b/apax/__init__.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 
 import jax
 
@@ -8,3 +9,5 @@ from apax.utils.helpers import setup_ase
 
 
 setup_ase()
+
+warnings.filterwarnings("ignore", message=r".*os\.fork\(\).*")
diff --git a/apax/cli/templates/train_config_full.yaml b/apax/cli/templates/train_config_full.yaml
index 16b8db61..abc8c4e1 100644
--- a/apax/cli/templates/train_config_full.yaml
+++ b/apax/cli/templates/train_config_full.yaml
@@ -16,8 +16,11 @@ data:
   #train_data_path:
   #val_data_path:
   #test_data_path:
+  dataset:
+    processing: cached
+    shuffle_buffer_size: 1000
+
   additional_properties_info: {}
-  ds_type: cached
 
   n_train: 1000
   n_valid: 100
@@ -31,8 +34,6 @@ data:
   scale_method: "per_element_force_rms_scale"
   scale_options: {}
 
-  shuffle_buffer_size: 1000
-
   pos_unit: Ang
   energy_unit: eV
diff --git a/apax/config/train_config.py b/apax/config/train_config.py
index bfd4dfe5..57bb7510 100644
--- a/apax/config/train_config.py
+++ b/apax/config/train_config.py
@@ -22,6 +22,66 @@
 log = logging.getLogger(__name__)
 
 
+class DatasetConfig(BaseModel, extra="forbid"):
+    processing: str
+
+
+class CachedDataset(DatasetConfig, extra="forbid"):
+    """Dataset which pads everything (atoms, neighbors)
+    to the largest system in the dataset.
+    The NL is computed on the fly during the first epoch and stored to disk using
+    tf.data's cache.
+    Most performant option for datasets with samples of very similar size.
+
+    Parameters
+    ----------
+    shuffle_buffer_size : int
+    | Size of the buffer that is shuffled by tf.data.
+    | Larger values require more RAM.
+    """
+
+    processing: Literal["cached"] = "cached"
+    shuffle_buffer_size: PositiveInt = 1000
+
+
+class OTFDataset(DatasetConfig, extra="forbid"):
+    """Dataset which pads everything (atoms, neighbors)
+    to the largest system in the dataset.
+    The NL is computed on the fly and fed into a tf.data generator.
+    Mostly for internal purposes.
+
+    Parameters
+    ----------
+    shuffle_buffer_size : int
+    | Size of the buffer that is shuffled by tf.data.
+    | Larger values require more RAM.
+    """
+
+    processing: Literal["otf"] = "otf"
+    shuffle_buffer_size: PositiveInt = 1000
+
+
+class PBPDatset(DatasetConfig, extra="forbid"):
+    """Dataset which pads everything (atoms, neighbors)
+    to the next largest power of two.
+    This limits the compute wasted due to padding at the (negligible)
+    cost of some recompilations.
+    The NL is computed on the fly, in parallel, for `num_workers` batches at a time.
+    Does not use tf.data.
+
+    Most performant option for datasets with significantly differently sized systems
+    (e.g. MP, SPICE).
+
+    Parameters
+    ----------
+    num_workers : int
+    | Number of batches to be processed in parallel.
+    """
+
+    processing: Literal["pbp"] = "pbp"
+    num_workers: PositiveInt = 10
+
+
 class DataConfig(BaseModel, extra="forbid"):
     """
     Configuration for data loading, preprocessing and training.
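The new `dataset` key resolves through a Pydantic discriminated union keyed on `processing`. A minimal sketch of the intended behavior (all values are illustrative; `directory` and `experiment` are simply the other required `DataConfig` fields):

```python
from apax.config.train_config import DataConfig

# illustrative values; only the `dataset` block matters here
cfg = DataConfig(
    directory="models",
    experiment="example",
    dataset={"processing": "pbp", "num_workers": 4},
)
assert cfg.dataset.num_workers == 4

# omitting the key falls back to the cached pipeline
default = DataConfig(directory="models", experiment="example")
assert default.dataset.processing == "cached"
```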
@@ -59,7 +119,10 @@ class DataConfig(BaseModel, extra="forbid"):
     directory: str
     experiment: str
 
-    ds_type: Literal["cached", "otf"] = "cached"
+    dataset: Union[CachedDataset, OTFDataset, PBPDatset] = Field(
+        CachedDataset(processing="cached"), discriminator="processing"
+    )
+
     data_path: Optional[str] = None
     train_data_path: Optional[str] = None
     val_data_path: Optional[str] = None
@@ -69,7 +132,6 @@ class DataConfig(BaseModel, extra="forbid"):
     n_valid: PositiveInt = 100
     batch_size: PositiveInt = 32
     valid_batch_size: PositiveInt = 100
-    shuffle_buffer_size: PositiveInt = 1000
     additional_properties_info: dict[str, str] = {}
 
     shift_method: str = "per_element_regression_shift"
diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py
index 9d75b4fd..70aaca39 100644
--- a/apax/data/input_pipeline.py
+++ b/apax/data/input_pipeline.py
@@ -1,9 +1,11 @@
 import logging
+import multiprocessing
 import uuid
 from collections import deque
+from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
 from random import shuffle
-from typing import Dict, Iterator
+from typing import Dict, Iterator, Optional
 
 import jax
 import jax.numpy as jnp
@@ -11,7 +13,12 @@
 import tensorflow as tf
 
 from apax.data.preprocessing import compute_nl, prefetch_to_single_device
-from apax.utils.convert import atoms_to_inputs, atoms_to_labels, unit_dict
+from apax.utils.convert import (
+    atoms_to_inputs,
+    atoms_to_labels,
+    transpose_dict_of_lists,
+    unit_dict,
+)
 
 log = logging.getLogger(__name__)
@@ -69,24 +76,26 @@ def find_largest_system(inputs, r_max) -> tuple[int]:
 
 
 class InMemoryDataset:
+    """Base class for all datasets which store data in memory."""
+
     def __init__(
         self,
         atoms_list,
         cutoff,
         bs,
         n_epochs,
-        buffer_size=1000,
         n_jit_steps=1,
         pos_unit: str = "Ang",
         energy_unit: str = "eV",
         pre_shuffle=False,
+        shuffle_buffer_size=1000,
         ignore_labels=False,
         cache_path=".",
     ) -> None:
         self.n_epochs = n_epochs
         self.cutoff = cutoff
         self.n_jit_steps = n_jit_steps
-        self.buffer_size = buffer_size
+        self.buffer_size = shuffle_buffer_size
         self.n_data = len(atoms_list)
         self.batch_size = self.validate_batch_size(bs)
         self.pos_unit = pos_unit
@@ -224,6 +233,13 @@ def cleanup(self):
 
 
 class CachedInMemoryDataset(InMemoryDataset):
+    """Dataset which pads everything (atoms, neighbors)
+    to the largest system in the dataset.
+    The NL is computed on the fly during the first epoch and stored to disk using
+    tf.data's cache.
+    Most performant option for datasets with samples of very similar size.
+    """
+
     def __iter__(self):
         while self.count < self.n_data or len(self.buffer) > 0:
             yield self.buffer.popleft()
@@ -283,6 +299,12 @@ def cleanup(self):
 
 
 class OTFInMemoryDataset(InMemoryDataset):
+    """Dataset which pads everything (atoms, neighbors)
+    to the largest system in the dataset.
+    The NL is computed on the fly and fed into a tf.data generator.
+    Mostly for internal purposes.
+ """ + def __iter__(self): outer_count = 0 max_iter = self.n_data * self.n_epochs @@ -332,7 +354,204 @@ def batch(self, sharding=None) -> Iterator[jax.Array]: return ds +def next_power_of_two(x): + return 1 << (int(x) - 1).bit_length() + + +class BatchProcessor: + def __init__(self, cutoff, forces=True, stress=False) -> None: + self.cutoff = cutoff + self.forces = forces + self.stress = stress + + def __call__(self, samples: list[dict]): + inputs = { + "numbers": [], + "n_atoms": [], + "positions": [], + "box": [], + "idx": [], + "offsets": [], + } + + labels = { + "energy": [], + } + + if self.forces: + labels["forces"] = [] + if self.stress: + labels["stress"] = [] + + for sample in samples: + inp, lab = sample + + inputs["numbers"].append(inp["numbers"]) + inputs["n_atoms"].append(inp["n_atoms"]) + inputs["positions"].append(inp["positions"]) + inputs["box"].append(inp["box"]) + idx, offsets = compute_nl(inp["positions"], inp["box"], self.cutoff) + inputs["idx"].append(idx) + inputs["offsets"].append(offsets) + + labels["energy"].append(lab["energy"]) + if self.forces: + labels["forces"].append(lab["forces"]) + if self.stress: + labels["stress"].append(lab["stress"]) + + max_atoms = np.max(inputs["n_atoms"]) + max_nbrs = np.max([idx.shape[1] for idx in inputs["idx"]]) + + max_atoms = next_power_of_two(max_atoms) + max_nbrs = next_power_of_two(max_nbrs) + + for i in range(len(inputs["n_atoms"])): + inputs["idx"][i], inputs["offsets"][i] = pad_nl( + inputs["idx"][i], inputs["offsets"][i], max_nbrs + ) + + zeros_to_add = max_atoms - inputs["numbers"][i].shape[0] + inputs["positions"][i] = np.pad( + inputs["positions"][i], ((0, zeros_to_add), (0, 0)), "constant" + ) + inputs["numbers"][i] = np.pad( + inputs["numbers"][i], (0, zeros_to_add), "constant" + ).astype(np.int16) + + if "forces" in labels: + labels["forces"][i] = np.pad( + labels["forces"][i], ((0, zeros_to_add), (0, 0)), "constant" + ) + + inputs = {k: np.array(v) for k, v in inputs.items()} + labels = {k: np.array(v) for k, v in labels.items()} + return inputs, labels + + +class PerBatchPaddedDataset(InMemoryDataset): + """Dataset which pads everything (atoms, neighbors) + to the next larges power of two. + This limits the compute wasted due to padding at the (negligible) + cost of some recompilations. + The NL is computed on-the-fly in parallel for `num_workers` of batches. + Does not use tf.data. + + Most performant option for datasets with significantly differently sized systems + (e.g. MaterialsProject, SPICE). 
+ """ + + def __init__( + self, + atoms_list, + cutoff, + bs, + n_epochs, + n_jit_steps=1, + num_workers: Optional[int] = None, + pos_unit: str = "Ang", + energy_unit: str = "eV", + pre_shuffle=False, + ) -> None: + self.cutoff = cutoff + + if n_jit_steps > 1: + raise NotImplementedError( + "PerBatchPaddedDataset is not yet compatible with multi step jit" + ) + + self.n_jit_steps = n_jit_steps + self.n_epochs = n_epochs + self.n_data = len(atoms_list) + self.batch_size = self.validate_batch_size(bs) + self.pos_unit = pos_unit + + if num_workers: + self.num_workers = num_workers + else: + self.num_workers = multiprocessing.cpu_count() + self.buffer_size = num_workers * 2 + self.batch_size = bs + + self.sample_atoms = atoms_list[0] + self.inputs = atoms_to_inputs(atoms_list, pos_unit) + + self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit) + label_keys = self.labels.keys() + + self.data = list( + zip( + transpose_dict_of_lists(self.inputs), transpose_dict_of_lists(self.labels) + ) + ) + + forces = "forces" in label_keys + stress = "stress" in label_keys + self.prepare_batch = BatchProcessor(cutoff, forces, stress) + + self.count = 0 + self.max_count = self.n_epochs * self.steps_per_epoch() + self.buffer = deque() + + self.process_pool = ProcessPoolExecutor(self.num_workers) + + def enqueue(self, num_batches): + start = self.count * self.batch_size + + dataset_chunks = [ + self.data[start + self.batch_size * i : start + self.batch_size * (i + 1)] + for i in range(0, num_batches) + ] + for batch in self.process_pool.map(self.prepare_batch, dataset_chunks): + self.buffer.append(batch) + + self.count += num_batches + + def __iter__(self): + for n in range(self.n_epochs): + self.count = 0 + self.buffer = deque() + + if self.should_shuffle: + shuffle(self.data) + + self.enqueue(min(self.buffer_size, self.n_data // self.batch_size)) + + for i in range(self.steps_per_epoch()): + batch = self.buffer.popleft() + yield batch + + current_buffer_len = len(self.buffer) + space = self.buffer_size - current_buffer_len + + if space >= self.num_workers: + more_data = min(space, self.steps_per_epoch() - self.count) + more_data = max(more_data, 0) + if more_data > 0: + self.enqueue(more_data) + + def shuffle_and_batch(self, sharding): + self.should_shuffle = True + + ds = prefetch_to_single_device( + iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1 + ) + + return ds + + def batch(self, sharding) -> Iterator[jax.Array]: + self.should_shuffle = False + ds = prefetch_to_single_device( + iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1 + ) + return ds + + def make_signature(self) -> None: + pass + + dataset_dict = { "cached": CachedInMemoryDataset, "otf": OTFInMemoryDataset, + "pbp": PerBatchPaddedDataset, } diff --git a/apax/train/run.py b/apax/train/run.py index fec30fa7..a5263a25 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -99,19 +99,24 @@ def initialize_datasets(config: Config): train_raw_ds, val_raw_ds = load_data_files(config.data) - Dataset = dataset_dict[config.data.ds_type] + Dataset = dataset_dict[config.data.dataset.processing] + + dataset_kwargs = dict(config.data.dataset) + processing = dataset_kwargs.pop("processing") + + if processing == "cached": + dataset_kwargs["cache_path"] = config.data.model_version_path train_ds = Dataset( train_raw_ds, config.model.r_max, config.data.batch_size, config.n_epochs, - config.data.shuffle_buffer_size, config.n_jitted_steps, pos_unit=config.data.pos_unit, energy_unit=config.data.energy_unit, pre_shuffle=True, - 
+        **dataset_kwargs,
     )
     val_ds = Dataset(
         val_raw_ds,
@@ -120,7 +125,7 @@
         config.n_epochs,
         pos_unit=config.data.pos_unit,
         energy_unit=config.data.energy_unit,
-        cache_path=config.data.model_version_path,
+        **dataset_kwargs,
     )
     ds_stats = compute_scale_shift_parameters(
         train_ds.inputs,
diff --git a/apax/utils/convert.py b/apax/utils/convert.py
index a7d4cc45..af6e5335 100644
--- a/apax/utils/convert.py
+++ b/apax/utils/convert.py
@@ -153,3 +153,14 @@ def atoms_to_labels(
     labels = prune_dict(labels)
 
     return labels
+
+
+def transpose_dict_of_lists(dict_of_lists: dict):
+    # {"a": [1, 2], "b": [3, 4]} -> [{"a": 1, "b": 3}, {"a": 2, "b": 4}]
+    list_of_dicts = []
+    keys = list(dict_of_lists.keys())
+
+    for i in range(len(dict_of_lists[keys[0]])):
+        data = {k: dict_of_lists[k][i] for k in keys}
+        list_of_dicts.append(data)
+
+    return list_of_dicts
diff --git a/tests/integration_tests/bal/config.yaml b/tests/integration_tests/bal/config.yaml
index 60a9c0b9..98c058c1 100644
--- a/tests/integration_tests/bal/config.yaml
+++ b/tests/integration_tests/bal/config.yaml
@@ -9,7 +9,6 @@ data:
   n_valid: 2
   batch_size: 2
   valid_batch_size: 2
-  shuffle_buffer_size: 4
 
 model:
   nn: [32,32]
diff --git a/tests/integration_tests/md/config.yaml b/tests/integration_tests/md/config.yaml
index 2371a2fe..feef8938 100644
--- a/tests/integration_tests/md/config.yaml
+++ b/tests/integration_tests/md/config.yaml
@@ -8,7 +8,6 @@ data:
   n_valid: 2
   batch_size: 2
   valid_batch_size: 2
-  shuffle_buffer_size: 4
 
 model:
   nn: [32,32]
diff --git a/tests/unit_tests/utils/test_convert.py b/tests/unit_tests/utils/test_convert.py
index e69de29b..7c5f245c 100644
--- a/tests/unit_tests/utils/test_convert.py
+++ b/tests/unit_tests/utils/test_convert.py
@@ -0,0 +1,22 @@
+import numpy as np
+
+from apax.utils.convert import transpose_dict_of_lists
+
+
+def test_transpose_dict_of_lists():
+
+    b = np.arange(8).reshape((4, 2))
+    a = [0, 1, 2, 3]
+    inputs = {
+        "a": a,
+        "b": b,
+    }
+
+    out = transpose_dict_of_lists(inputs)
+    assert len(out) == len(a)
+
+    for ii, entry in enumerate(out):
+        assert "a" in entry.keys()
+        assert "b" in entry.keys()
+        assert entry["a"] == a[ii]
+        assert np.all(entry["b"] == b[ii])
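Since the padding helper in `input_pipeline.py` now determines batch shapes, a companion unit test pinning down its rounding may be worth adding. A sketch only, not part of the patch:

```python
from apax.data.input_pipeline import next_power_of_two


def test_next_power_of_two():
    # rounds up to the next power of two; exact powers are returned unchanged
    assert next_power_of_two(1) == 1
    assert next_power_of_two(5) == 8
    assert next_power_of_two(64) == 64
```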