Dataparallel Training #257

Merged
merged 14 commits on Apr 8, 2024
Changes from 12 commits
3 changes: 3 additions & 0 deletions apax/config/train_config.py
@@ -331,13 +331,16 @@ class Config(BaseModel, frozen=True, extra="forbid"):
callbacks: List of :class:`CallbackConfig <config.CallbackConfig>` configurations.
progress_bar: Progressbar configuration.
checkpoints: Checkpoint configuration.
data_parallel: Automatically uses all available GPUs for data-parallel training.
Set to False to force single-device training.
"""

n_epochs: PositiveInt
patience: Optional[PositiveInt] = None
seed: int = 1
n_models: int = 1
n_jitted_steps: int = 1
data_parallel: bool = True

data: DataConfig
model: ModelConfig = ModelConfig()
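For context, the new flag is a plain boolean on the training config with a default of True. A minimal standalone sketch of its behaviour, reproducing only this field (the real Config has many more required sections):

from pydantic import BaseModel


class TrainConfigSketch(BaseModel, frozen=True, extra="forbid"):
    # Set to False to force single-device training even when several GPUs are visible.
    data_parallel: bool = True


print(TrainConfigSketch())                     # data_parallel=True
print(TrainConfigSketch(data_parallel=False))  # disabled, e.g. via the config file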
24 changes: 16 additions & 8 deletions apax/data/input_pipeline.py
@@ -205,7 +205,7 @@ def __iter__(self):
space = self.n_data - self.count
self.enqueue(space)

def shuffle_and_batch(self):
def shuffle_and_batch(self, sharding=None):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.

@@ -227,10 +227,12 @@ def shuffle_and_batch(self):
).batch(batch_size=self.batch_size)
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
ds = prefetch_to_single_device(
ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds

def batch(self) -> Iterator[jax.Array]:
def batch(self, sharding=None) -> Iterator[jax.Array]:
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
@@ -239,7 +241,9 @@ def batch(self) -> Iterator[jax.Array]:
.repeat(self.n_epochs)
)
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
ds = prefetch_to_single_device(
ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds

def cleanup(self):
@@ -265,7 +269,7 @@ def __iter__(self):
self.count = 0
self.enqueue(space)

def shuffle_and_batch(self):
def shuffle_and_batch(self, sharding=None):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.

@@ -283,15 +287,19 @@ def shuffle_and_batch(self):
).batch(batch_size=self.batch_size)
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
ds = prefetch_to_single_device(
ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds

def batch(self) -> Iterator[jax.Array]:
def batch(self, sharding=None) -> Iterator[jax.Array]:
ds = tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
ds = prefetch_to_single_device(
ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds


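For reference, the batching pattern used in shuffle_and_batch — shuffle, batch, optionally stack n_jitted_steps batches into an extra leading axis, then expose a NumPy iterator to the prefetch helper — can be reproduced in isolation. The sketch below uses toy data and illustrative names, not the apax dataset classes:

import numpy as np
import tensorflow as tf

n_data, batch_size, n_jitted_steps = 32, 8, 2

ds = tf.data.Dataset.from_tensor_slices(np.arange(n_data, dtype=np.float32))
ds = ds.shuffle(buffer_size=n_data).batch(batch_size=batch_size)
if n_jitted_steps > 1:
    # Extra leading axis so one jitted call can step over several batches on device.
    ds = ds.batch(batch_size=n_jitted_steps)

it = ds.as_numpy_iterator()
print(next(it).shape)  # (2, 8) for this toy 1D data: (n_jitted_steps, batch_size)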
22 changes: 19 additions & 3 deletions apax/data/preprocessing.py
@@ -53,16 +53,32 @@ def get_shrink_wrapped_cell(positions):
return cell, cell_origin


def prefetch_to_single_device(iterator, size: int):
def prefetch_to_single_device(iterator, size: int, sharding=None, n_step_jit=False):
"""
inspired by
https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
except it does not shard the data.
"""
queue = collections.deque()

def _prefetch(x):
return jnp.asarray(x)
if sharding:
n_devices = len(sharding._devices)
slice_start = 1
shape = [n_devices]
if n_step_jit:
# replicate over multi-batch axis
# data shape: njit x bs x ...
slice_start = 2
shape.insert(0, 1)

def _prefetch(x: jax.Array):
if sharding:
remaining_axes = [1] * len(x.shape[slice_start:])
final_shape = tuple(shape + remaining_axes)
x = jax.device_put(x, sharding.reshape(final_shape))
else:
x = jnp.asarray(x)
return x

def enqueue(n):
for data in itertools.islice(iterator, n):
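The sharded prefetch boils down to reshaping a 1D PositionalSharding so that the leading batch axis is split across devices while every remaining axis stays replicated, then placing each batch with jax.device_put before it enters the queue. A self-contained sketch under those assumptions — it reads the public sharding.shape rather than sharding._devices, and it also runs on a single device:

import collections
import itertools

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding


def prefetch_sharded(iterator, size: int, sharding):
    """Prefetch `size` batches, placing each one according to `sharding`."""
    queue = collections.deque()
    n_devices = sharding.shape[0]

    def _put(x):
        # Split the leading (batch) axis across devices, replicate all other axes.
        final_shape = (n_devices,) + (1,) * (x.ndim - 1)
        return jax.device_put(x, sharding.reshape(final_shape))

    def enqueue(n):
        for data in itertools.islice(iterator, n):
            queue.append(_put(data))

    enqueue(size)
    while queue:
        yield queue.popleft()
        enqueue(1)


n_dev = len(jax.devices())
sharding = PositionalSharding(mesh_utils.create_device_mesh((n_dev,)))
batches = (np.ones((n_dev * 4, 3), dtype=np.float32) for _ in range(3))
for batch in prefetch_sharded(batches, 2, sharding):
    print(batch.shape, len(batch.devices()))  # batch axis is spread over all devices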
1 change: 1 addition & 0 deletions apax/train/run.py
@@ -150,5 +150,6 @@ def run(user_config, log_level="error"):
disable_pbar=config.progress_bar.disable_epoch_pbar,
disable_batch_pbar=config.progress_bar.disable_batch_pbar,
is_ensemble=config.n_models > 1,
data_parallel=config.data_parallel,
)
log.info("Finished training")
17 changes: 15 additions & 2 deletions apax/train/trainer.py
@@ -8,6 +8,8 @@
import jax.numpy as jnp
import numpy as np
from clu import metrics
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
from tqdm import trange

from apax.data.input_pipeline import InMemoryDataset
@@ -31,6 +33,7 @@ def fit(
disable_pbar: bool = False,
disable_batch_pbar: bool = True,
is_ensemble=False,
data_parallel=True,
):
log.info("Beginning Training")
callbacks.on_train_begin()
@@ -51,12 +54,19 @@
f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})"
)

devices = len(jax.devices())
if devices > 1 and data_parallel:
sharding = PositionalSharding(mesh_utils.create_device_mesh((devices,)))
state = jax.device_put(state, sharding.replicate())
else:
sharding = None

train_steps_per_epoch = train_ds.steps_per_epoch()
batch_train_ds = train_ds.shuffle_and_batch()
batch_train_ds = train_ds.shuffle_and_batch(sharding)

if val_ds is not None:
val_steps_per_epoch = val_ds.steps_per_epoch()
batch_val_ds = val_ds.batch()
batch_val_ds = val_ds.batch(sharding)

best_loss = np.inf
early_stopping_counter = 0
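On the trainer side the setup is the mirror image: the train state is replicated on every device, batches arrive split along their leading axis, and jax.jit partitions the computation according to those input shardings. A minimal sketch with a toy parameter pytree (illustrative names, not the apax train state):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

n_devices = len(jax.devices())
sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))

# Replicate the parameters: every device holds an identical copy.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}
params = jax.device_put(params, sharding.replicate())

# Shard the batch along its leading axis; the feature axis stays replicated.
batch = jnp.ones((n_devices * 8, 4))
batch = jax.device_put(batch, sharding.reshape((n_devices, 1)))


@jax.jit
def forward(params, x):
    return x @ params["w"] + params["b"]


out = forward(params, batch)  # jit follows the input shardings, so the matmul runs data parallel
print(out.shape, len(out.devices()))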
@@ -85,6 +95,9 @@
callbacks.on_train_batch_begin(batch=batch_idx)

batch = next(batch_train_ds)

# print(jax.tree_map(lambda x: x.devices(), batch))

Contributor: comment

(
(state, train_batch_metrics),
batch_loss,
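The commented-out print above is a quick way to verify where a batch actually landed; a standalone version, with an illustrative pytree in place of a real batch, looks like this:

import jax
import jax.numpy as jnp

# Toy stand-in for a batch pytree; real batches are nested containers of arrays.
batch = {"positions": jnp.ones((8, 3)), "numbers": jnp.zeros((8,), dtype=jnp.int32)}

# tree_util.tree_map is the non-deprecated spelling of jax.tree_map.
print(jax.tree_util.tree_map(lambda x: x.devices(), batch))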