From 986315401017cc20240dc7015e9fc54acb584602 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?=
Date: Mon, 18 Mar 2024 15:47:29 +0100
Subject: [PATCH 1/9] successful sharding sketch

---
 apax/__init__.py            | 20 ++++++++++++++++++++
 apax/data/input_pipeline.py | 16 ++++++++--------
 apax/data/preprocessing.py  |  8 ++++++--
 apax/train/trainer.py       | 17 +++++++++++++++--
 4 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/apax/__init__.py b/apax/__init__.py
index 7438bf4a..4fd5cdb2 100644
--- a/apax/__init__.py
+++ b/apax/__init__.py
@@ -1,5 +1,25 @@
 import os
 
+# Set this to True to run the model on CPU only.
+USE_CPU_ONLY = True
+
+flags = os.environ.get("XLA_FLAGS", "")
+if USE_CPU_ONLY:
+    flags += " --xla_force_host_platform_device_count=2"  # Simulate 8 devices
+    # Enforce CPU-only execution
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+else:
+    # GPU flags
+    flags += (
+        "--xla_gpu_enable_triton_softmax_fusion=true "
+        "--xla_gpu_triton_gemm_any=false "
+        "--xla_gpu_enable_async_collectives=true "
+        "--xla_gpu_enable_latency_hiding_scheduler=true "
+        "--xla_gpu_enable_highest_priority_async_stream=true "
+    )
+os.environ["XLA_FLAGS"] = flags
+
+
 import jax
 
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py
index 889aad87..1f4db7e7 100644
--- a/apax/data/input_pipeline.py
+++ b/apax/data/input_pipeline.py
@@ -315,7 +315,7 @@ def find_largest_system2(inputs: dict[str, np.ndarray], r_max) -> tuple[int]:
     return max_atoms, max_nbrs
 
 class Dataset:
-    def __init__(self, atoms, cutoff, bs, n_epochs, buffer_size, n_jit_steps= 1, name="train", pre_shuffle=False) -> None:
+    def __init__(self, atoms, cutoff, bs, n_jit_steps= 1, name="train", pre_shuffle=False) -> None:
         if pre_shuffle:
             shuffle(atoms)
         self.sample_atoms = atoms[0]
@@ -324,8 +324,8 @@ def __init__(self, atoms, cutoff, bs, n_epochs, buffer_size, n_jit_steps= 1, nam
         finputs.update({k: v for k,v in inputs["ragged"].items()})
 
         self.inputs = finputs
-        self.n_epochs = n_epochs
-        self.buffer_size = buffer_size
+        self.n_epochs = 100
+        self.buffer_size = 100
 
         max_atoms, max_nbrs = find_largest_system2(self.inputs, cutoff)
         self.max_atoms = max_atoms
@@ -434,7 +434,7 @@ def init_input(self) -> Dict[str, np.ndarray]:
         inputs = jax.tree_map(lambda x: jnp.array(x), inputs)
         return inputs, np.array(box)
 
-    def shuffle_and_batch(self):
+    def shuffle_and_batch(self, sharding=None):
         """Shuffles and batches the inputs/labels. This function prepares the
         inputs and labels for the whole training and prefetches the data.
 
@@ -455,10 +455,10 @@ def shuffle_and_batch(self):
         )
         if self.n_jit_steps > 1:
             ds = ds.batch(batch_size=self.n_jit_steps)
-        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2, sharding)
         return ds
 
-    def batch(self) -> Iterator[jax.Array]:
+    def batch(self, sharding) -> Iterator[jax.Array]:
         gen = lambda: self
         ds = tf.data.Dataset.from_generator(gen, output_signature=self.make_signature())
         ds = (ds
@@ -466,9 +466,9 @@ def batch(self) -> Iterator[jax.Array]:
             .repeat(self.n_epochs)
             .batch(batch_size=self.batch_size)
         )
-        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2, sharding)
        return ds
-    
+
     def cleanup(self):
        """Removes cache files from disk.
Used after training diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index 67f67303..a0d26cbf 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -108,7 +108,7 @@ def get_shrink_wrapped_cell(positions): return cell, cell_origin -def prefetch_to_single_device(iterator, size: int): +def prefetch_to_single_device(iterator, size: int, sharding = None): """ inspired by https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device @@ -117,7 +117,11 @@ def prefetch_to_single_device(iterator, size: int): queue = collections.deque() def _prefetch(x): - return jnp.asarray(x) + x = jnp.asarray(x) + if sharding: + shape = (2, *([1]*len(x.shape[1:]))) + x = jax.device_put(x, sharding.reshape(shape)) + return x def enqueue(n): for data in itertools.islice(iterator, n): diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 138efbc0..8451a98b 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -50,12 +50,22 @@ def fit( f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})" ) + + from jax.experimental import mesh_utils + from jax.sharding import PositionalSharding + sharding = PositionalSharding(mesh_utils.create_device_mesh((len(jax.devices()),))) + + jax.device_put(state, sharding.replicate()) + train_steps_per_epoch = train_ds.steps_per_epoch() - batch_train_ds = train_ds.shuffle_and_batch() + batch_train_ds = train_ds.shuffle_and_batch(sharding) if val_ds is not None: val_steps_per_epoch = val_ds.steps_per_epoch() - batch_val_ds = val_ds.batch() + batch_val_ds = val_ds.batch(sharding) + + + best_loss = np.inf early_stopping_counter = 0 @@ -74,6 +84,9 @@ def fit( callbacks.on_train_batch_begin(batch=batch_idx) batch = next(batch_train_ds) + + # print(jax.tree_map(lambda x: x.devices(), batch)) + ( (state, train_batch_metrics), batch_loss, From 0c9a931227a9cea20e2cbba35a95508013c84777 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 22 Mar 2024 10:29:34 +0100 Subject: [PATCH 2/9] sketch of sharding compatible with njit --- apax/data/preprocessing.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index a0d26cbf..80dc157c 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -116,11 +116,29 @@ def prefetch_to_single_device(iterator, size: int, sharding = None): """ queue = collections.deque() - def _prefetch(x): - x = jnp.asarray(x) + n_devices = 2 + multistep_jit = True + slice_start = 1 + shape = [n_devices] + if multistep_jit: + # replicate over multi-batch axis + # data shape: njit x bs x ... 
+ slice_start = 2 + shape.insert(0, 1) + + def _prefetch(x: jax.Array): + + print(x.shape) + # quit() + shape if sharding: - shape = (2, *([1]*len(x.shape[1:]))) + remaining_axes = [1]*len(x.shape[slice_start:]) + shape = tuple(shape + remaining_axes) x = jax.device_put(x, sharding.reshape(shape)) + print(x.devices()) + quit() + else: + x = jnp.asarray(x) return x def enqueue(n): From d75ef8888c6db93a156fc2f173b26e3ee56a410e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 4 Apr 2024 11:21:35 +0200 Subject: [PATCH 3/9] removed debugging XLA options --- apax/__init__.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/apax/__init__.py b/apax/__init__.py index 2e2de880..5c5646ce 100644 --- a/apax/__init__.py +++ b/apax/__init__.py @@ -1,27 +1,8 @@ import os -# Set this to True to run the model on CPU only. -USE_CPU_ONLY = True - -flags = os.environ.get("XLA_FLAGS", "") -if USE_CPU_ONLY: - flags += " --xla_force_host_platform_device_count=2" # Simulate 8 devices - # Enforce CPU-only execution - os.environ["CUDA_VISIBLE_DEVICES"] = "" -else: - # GPU flags - flags += ( - "--xla_gpu_enable_triton_softmax_fusion=true " - "--xla_gpu_triton_gemm_any=false " - "--xla_gpu_enable_async_collectives=true " - "--xla_gpu_enable_latency_hiding_scheduler=true " - "--xla_gpu_enable_highest_priority_async_stream=true " - ) -os.environ["XLA_FLAGS"] = flags - - import jax + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" jax.config.update("jax_enable_x64", True) From 159edc337217e3be53543448590ded4d5c8c4f42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 4 Apr 2024 11:22:38 +0200 Subject: [PATCH 4/9] added automatic dataparallel training --- apax/config/train_config.py | 3 +++ apax/data/input_pipeline.py | 24 ++++++++++++++++-------- apax/data/preprocessing.py | 32 +++++++++++++------------------- apax/train/run.py | 1 + apax/train/trainer.py | 23 +++++++++++++---------- 5 files changed, 46 insertions(+), 37 deletions(-) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 3cd5f57c..f8e083a9 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -298,6 +298,8 @@ class Config(BaseModel, frozen=True, extra="forbid"): callbacks: List of :class: `callback` configurations. progress_bar: Progressbar configuration. checkpoints: Checkpoint configuration. + data_parallel: Automatically uses all available GPUs for data parallel training. + Set to false to force single device training. """ n_epochs: PositiveInt @@ -305,6 +307,7 @@ class Config(BaseModel, frozen=True, extra="forbid"): seed: int = 1 n_models: int = 1 n_jitted_steps: int = 1 + data_parallel: int = True data: DataConfig model: ModelConfig = ModelConfig() diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 539d5ed8..dd794933 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -201,7 +201,7 @@ def __iter__(self): space = self.n_data - self.count self.enqueue(space) - def shuffle_and_batch(self): + def shuffle_and_batch(self, sharding=None): """Shuffles and batches the inputs/labels. This function prepares the inputs and labels for the whole training and prefetches the data. 
@@ -223,10 +223,12 @@ def shuffle_and_batch(self): ).batch(batch_size=self.batch_size) if self.n_jit_steps > 1: ds = ds.batch(batch_size=self.n_jit_steps) - ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2, sharding) + ds = prefetch_to_single_device( + ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1 + ) return ds - def batch(self, sharding) -> Iterator[jax.Array]: + def batch(self, sharding=None) -> Iterator[jax.Array]: ds = ( tf.data.Dataset.from_generator( lambda: self, output_signature=self.make_signature() @@ -235,7 +237,9 @@ def batch(self, sharding) -> Iterator[jax.Array]: .repeat(self.n_epochs) ) ds = ds.batch(batch_size=self.batch_size) - ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + ds = prefetch_to_single_device( + ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1 + ) return ds def cleanup(self): @@ -261,7 +265,7 @@ def __iter__(self): self.count = 0 self.enqueue(space) - def shuffle_and_batch(self): + def shuffle_and_batch(self, sharding=None): """Shuffles and batches the inputs/labels. This function prepares the inputs and labels for the whole training and prefetches the data. @@ -279,15 +283,19 @@ def shuffle_and_batch(self): ).batch(batch_size=self.batch_size) if self.n_jit_steps > 1: ds = ds.batch(batch_size=self.n_jit_steps) - ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + ds = prefetch_to_single_device( + ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1 + ) return ds - def batch(self) -> Iterator[jax.Array]: + def batch(self, sharding=None) -> Iterator[jax.Array]: ds = tf.data.Dataset.from_generator( lambda: self, output_signature=self.make_signature() ) ds = ds.batch(batch_size=self.batch_size) - ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + ds = prefetch_to_single_device( + ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1 + ) return ds diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index 715cff24..dc2da7a4 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -53,7 +53,7 @@ def get_shrink_wrapped_cell(positions): return cell, cell_origin -def prefetch_to_single_device(iterator, size: int, sharding = None): +def prefetch_to_single_device(iterator, size: int, sharding=None, n_step_jit=False): """ inspired by https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device @@ -61,27 +61,21 @@ def prefetch_to_single_device(iterator, size: int, sharding = None): """ queue = collections.deque() - n_devices = 2 - multistep_jit = True - slice_start = 1 - shape = [n_devices] - if multistep_jit: - # replicate over multi-batch axis - # data shape: njit x bs x ... - slice_start = 2 - shape.insert(0, 1) + if sharding: + n_devices = len(sharding._devices) + slice_start = 1 + shape = [n_devices] + if n_step_jit: + # replicate over multi-batch axis + # data shape: njit x bs x ... 
+            slice_start = 2
+            shape.insert(0, 1)
 
     def _prefetch(x: jax.Array):
-
-        print(x.shape)
-        # quit()
-        shape
         if sharding:
-            remaining_axes = [1]*len(x.shape[slice_start:])
-            shape = tuple(shape + remaining_axes)
-            x = jax.device_put(x, sharding.reshape(shape))
-            print(x.devices())
-            quit()
+            remaining_axes = [1] * len(x.shape[slice_start:])
+            final_shape = tuple(shape + remaining_axes)
+            x = jax.device_put(x, sharding.reshape(final_shape))
         else:
             x = jnp.asarray(x)
         return x
diff --git a/apax/train/run.py b/apax/train/run.py
index fe408ab6..163e2849 100644
--- a/apax/train/run.py
+++ b/apax/train/run.py
@@ -145,4 +145,5 @@ def run(user_config, log_level="error"):
         patience=config.patience,
         disable_pbar=config.progress_bar.disable_epoch_pbar,
         is_ensemble=config.n_models > 1,
+        data_parallel=config.data_parallel,
     )
diff --git a/apax/train/trainer.py b/apax/train/trainer.py
index f9a99924..913b4351 100644
--- a/apax/train/trainer.py
+++ b/apax/train/trainer.py
@@ -30,6 +30,7 @@ def fit(
     patience: Optional[int] = None,
     disable_pbar: bool = False,
     is_ensemble=False,
+    data_parallel=True,
 ):
     log.info("Beginning Training")
     callbacks.on_train_begin()
@@ -50,12 +51,15 @@ def fit(
             f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})"
         )
 
-
     from jax.experimental import mesh_utils
     from jax.sharding import PositionalSharding
-    sharding = PositionalSharding(mesh_utils.create_device_mesh((len(jax.devices()),)))
 
-    jax.device_put(state, sharding.replicate())
+    devices = len(jax.devices())
+    if devices > 1 and data_parallel:
+        sharding = PositionalSharding(mesh_utils.create_device_mesh((devices,)))
+        state = jax.device_put(state, sharding.replicate())
+    else:
+        sharding = None
 
     train_steps_per_epoch = train_ds.steps_per_epoch()
     batch_train_ds = train_ds.shuffle_and_batch(sharding)
@@ -64,9 +68,6 @@ def fit(
         val_steps_per_epoch = val_ds.steps_per_epoch()
         batch_val_ds = val_ds.batch(sharding)
 
-
-
-
     best_loss = np.inf
     early_stopping_counter = 0
     epoch_loss = {}
@@ -120,10 +121,12 @@ def fit(
                 epoch_loss["val_loss"] /= val_steps_per_epoch
                 epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
 
-            epoch_metrics.update({
-                f"val_{key}": float(val)
-                for key, val in val_batch_metrics.compute().items()
-            })
+            epoch_metrics.update(
+                {
+                    f"val_{key}": float(val)
+                    for key, val in val_batch_metrics.compute().items()
+                }
+            )
 
         epoch_metrics.update({**epoch_loss})
 

From 5933ded410efaa24e85ea05a4448a6acaec0a082 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?=
Date: Thu, 4 Apr 2024 11:24:32 +0200
Subject: [PATCH 5/9] remove debug num epochs and buffer size in input pipeline

---
 apax/data/input_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py
index dd794933..d4d6b439 100644
--- a/apax/data/input_pipeline.py
+++ b/apax/data/input_pipeline.py
@@ -53,8 +53,8 @@ def __init__(
         self.sample_atoms = atoms[0]
         self.inputs = atoms_to_inputs(atoms)
 
-        self.n_epochs = 100
-        self.buffer_size = 100
+        self.n_epochs = n_epochs
+        self.buffer_size = buffer_size
 
         max_atoms, max_nbrs = find_largest_system(self.inputs, cutoff)
         self.max_atoms = max_atoms

From 2b66ddf9e032a61d8ab7a1a3e9de578088c3b33b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 4 Apr 2024 09:27:06 +0000
Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 apax/train/trainer.py | 10 ++++------
 1 file changed, 4
insertions(+), 6 deletions(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 913b4351..17959eca 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -121,12 +121,10 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update( - { - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - } - ) + epoch_metrics.update({ + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + }) epoch_metrics.update({**epoch_loss}) From 59c0bd02000aaae3d81a9d8e1f1556ed352e15ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 4 Apr 2024 11:28:50 +0200 Subject: [PATCH 7/9] moved sharding imports to the top --- apax/train/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 913b4351..e0e138e6 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -8,6 +8,8 @@ import jax.numpy as jnp import numpy as np from clu import metrics +from jax.experimental import mesh_utils +from jax.sharding import PositionalSharding from tqdm import trange from apax.data.input_pipeline import InMemoryDataset @@ -51,9 +53,6 @@ def fit( f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})" ) - from jax.experimental import mesh_utils - from jax.sharding import PositionalSharding - devices = len(jax.devices()) if devices > 1 and data_parallel: sharding = PositionalSharding(mesh_utils.create_device_mesh((devices,))) From f1a57aacfcca2d9423e23c9241856f9781c43f67 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Apr 2024 09:29:03 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apax/__init__.py b/apax/__init__.py index 5c5646ce..de4b9e6e 100644 --- a/apax/__init__.py +++ b/apax/__init__.py @@ -2,7 +2,6 @@ import jax - os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" jax.config.update("jax_enable_x64", True) From c119992aea3c73b4c4b1147987f6b698d1a8625b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 8 Apr 2024 10:13:53 +0200 Subject: [PATCH 9/9] remove debug comment --- apax/train/trainer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 00b5b18d..39e6c8fc 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -95,9 +95,6 @@ def fit( callbacks.on_train_batch_begin(batch=batch_idx) batch = next(batch_train_ds) - - # print(jax.tree_map(lambda x: x.devices(), batch)) - ( (state, train_batch_metrics), batch_loss,
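
Note: the standalone sketch below is not part of the patch series; it only illustrates the data-parallel pattern these commits converge on (simulate extra host devices with the same XLA flag PATCH 1/9 uses, replicate the parameters, and split each batch along its leading axis). It assumes a JAX version contemporary with these patches (early 2024; newer releases deprecate PositionalSharding in favor of NamedSharding), and the parameter pytree and array shapes are invented for illustration, not taken from apax.

    # Illustrative sketch only; parameter names and shapes are made up.
    import os

    # Simulate two host devices so the sharding path runs without GPUs
    # (same flag as in PATCH 1/9). Must be set before importing jax.
    os.environ["XLA_FLAGS"] = (
        os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
    )

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import PositionalSharding

    devices = jax.devices()
    sharding = PositionalSharding(mesh_utils.create_device_mesh((len(devices),)))

    # Replicate the (hypothetical) train state on every device ...
    params = {"w": jnp.ones((16, 16)), "b": jnp.zeros((16,))}
    params = jax.device_put(params, sharding.replicate())

    # ... and split each batch along its leading batch axis; the batch size
    # must be divisible by the device count, trailing axes stay whole.
    batch = jnp.ones((4 * len(devices), 32, 3))  # batch_size x n_atoms x 3
    shape = (len(devices),) + (1,) * (batch.ndim - 1)
    batch = jax.device_put(batch, sharding.reshape(shape))

    print(batch.sharding)  # four samples end up on each device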
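A second short sketch, with the same JAX-version caveat, of the shape handling in prefetch_to_single_device when n_jitted_steps > 1: arrays then arrive with an extra leading n_jit_steps axis, which is left unsplit while the batch axis is divided across devices. The array sizes are assumed for illustration; only the reshape logic mirrors the patch.

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import PositionalSharding

    n_devices = len(jax.devices())
    sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))

    # Assumed layout after the extra ds.batch(...) call:
    # n_jit_steps x batch_size x n_atoms x 3
    x = jnp.ones((4, 2 * n_devices, 32, 3))

    # Leave the n_jit_steps axis unsplit (factor 1), split the batch axis
    # across devices, and keep every per-sample axis whole.
    slice_start = 2
    shape = tuple([1, n_devices] + [1] * len(x.shape[slice_start:]))
    x = jax.device_put(x, sharding.reshape(shape))

    print(x.sharding)  # global shape is unchanged; only placement differs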