diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 64ef0302e..ccd99e68d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -37,7 +37,7 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json wmt_pytorch: runs-on: ubuntu-latest steps: @@ -54,7 +54,7 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json imagenet_jax: runs-on: ubuntu-latest steps: @@ -71,8 +71,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json imagenet_pytorch: runs-on: ubuntu-latest steps: @@ -89,8 +89,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . 
- python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json # uncomment when https://github.com/mlcommons/algorithmic-efficiency/issues/339 is resolved. criteo_jax: runs-on: ubuntu-latest @@ -142,8 +142,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . - python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json speech_pytorch: runs-on: ubuntu-latest steps: @@ -160,8 +160,8 @@ jobs: pip install .[pytorch_cpu] pip install .[full] pip install -e . 
- python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json - python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json + python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ogbg: runs-on: ubuntu-latest steps: diff --git a/.gitignore b/.gitignore index 7d35f0ccc..916a29ff4 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv -algoperf/_version.py +algoperf/_version.py \ No newline at end of file diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index f4cb6c2db..75d8d59ea 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -11,7 +11,6 @@ from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -import jax import numpy as np from tensorflow.io import gfile # pytype: disable=import-error import torch @@ -193,10 +192,7 @@ def save_checkpoint(framework: str, train_state, eval_results, global_step, preemption_count). """ if framework == 'jax': - model_params = jax.device_get(jax_utils.unreplicate(model_params)) opt_state, _ = optimizer_state - opt_state = jax.device_get(jax_utils.unreplicate(opt_state)) - model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: if isinstance( model_params, diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..abd3f51b3 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -11,6 +11,7 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler +from algoperf import sharding_utils from algoperf import spec @@ -50,6 +51,7 @@ def shard_and_maybe_pad_np( weights = batch.get('weights') # The weights will also be padded. batch['weights'] = np.ones(mask_shape) if weights is None else weights + naive_sharding_spec = sharding_utils.get_naive_sharding_spec() def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. @@ -60,10 +62,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - # Reshape (global_batch_size, ...) to - # (local_device_count, per_device_batch_size, ...). - # Assumes that `global_batch_size % local_device_count == 0`. 
- return x.reshape((local_device_count, -1, *x.shape[1:])) + return jax.device_put(x, naive_sharding_spec) return jax.tree.map(_prepare, batch) diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 05d882404..24f981546 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -43,6 +43,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/sharding_utils.py b/algoperf/sharding_utils.py new file mode 100644 index 000000000..fc6b38d4a --- /dev/null +++ b/algoperf/sharding_utils.py @@ -0,0 +1,82 @@ +"""Utilities for dealing with sharding in JAX.""" + +import jax +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec + + +def get_mesh() -> jax.sharding.Mesh: + """Creates a mesh from all available GPUs. + Here, we simply create a one-dimensional mesh.""" + return jax.sharding.Mesh(jax.devices(), ("batch",)) + + +def get_replicated_sharding(mesh=None): + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + + +def shard_replicated(x, mesh=None): + """Shards a tensor across all devices.""" + if mesh is None: + mesh = get_mesh() + return jax.tree.map( + lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x) + + +def get_naive_sharding_spec(mesh=None): + """Returns a sharding spec that shards data along the first axis.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec("batch")) + + +def get_naive_sharding(x, mesh=None): + """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" + if mesh is None: + mesh = get_mesh() + grid_size = mesh.shape["batch"] + if len(x.shape) > 0 and x.shape[0] % grid_size == 0: + return NamedSharding(mesh, PartitionSpec("batch")) + else: + return NamedSharding(mesh, PartitionSpec()) + + +def shard_params(params, mesh=None): + """Shards a parameter tree across all devices + with naive sharding (see get_naive_sharding).""" + if mesh is None: + mesh = get_mesh() + return jax.tree.map(lambda x: jax.device_put(x, get_naive_sharding(x)), + params) + + +def shard_naive(x, mesh=None): + return shard_params(x, mesh) + + +def get_naive_sharding_tree(input_tree, mesh=None): + if mesh is None: + mesh = get_mesh() + return jax.tree.map(lambda x: get_naive_sharding(x, mesh), input_tree) + + +def get_sharding_tree(params, mesh=None): + """Returns a sharding tree for a parameter tree.""" + return jax.tree.map(lambda x: get_naive_sharding(x, mesh), params) + + +def get_empty_sharding(mesh=None): + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + + +def disp_shard_info(x: jax.Array): + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n") diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 728d05f29..18cb9ac5b 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -8,7 +8,6 @@ 
import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds @@ -171,5 +170,6 @@ def create_input_iter( functools.partial( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - it = jax_utils.prefetch_to_device(it, 2) + # FIXME(rka97): Figure out how to do prefetching+sharding. + # it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -3,7 +3,6 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax @@ -13,6 +12,7 @@ import tensorflow_datasets as tfds from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter @@ -28,15 +28,16 @@ def _build_cifar_dataset( data_dir: str, batch_size: int, cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: - ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) - train = split == 'train' + data_dir = data_dir + "/cifar10" + ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir) + train = split == "train" assert self.num_train_examples + self.num_validation_examples == 50000 - if split in ['train', 'eval_train']: - split = f'train[:{self.num_train_examples}]' - elif split == 'validation': - split = f'train[{self.num_train_examples}:]' + if split in ["train", "eval_train"]: + split = f"train[:{self.num_train_examples}]" + elif split == "validation": + split = f"train[{self.num_train_examples}:]" ds = create_input_iter( split, ds_builder, @@ -48,7 +49,8 @@ def _build_cifar_dataset( self.padding_size, train=train, cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset) + repeat_final_dataset=repeat_final_dataset, + ) return ds def _build_input_queue( @@ -59,7 +61,8 @@ def _build_input_queue( global_batch_size: int, cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches return self._build_cifar_dataset(data_rng, split, @@ -68,40 +71,28 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate - model_cls = getattr(models, 'ResNet18') + model_cls = getattr(models, "ResNet18") model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) - variables = jax.jit(model.init)({'params': rng}, + variables = jax.jit(model.init)({"params": rng}, jnp.ones(input_shape, model.dtype)) model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: - return param_key == 'Dense_0' + return param_key == "Dense_0" def model_fn( self, @@ -115,11 +106,11 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng - variables = {'params': params, **model_state} + variables = {"params": params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn) @@ -127,7 +118,7 @@ def model_fn( else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn) @@ -140,13 +131,15 @@ def loss_fn( label_batch: spec.Tensor, # Dense or one-hot labels. logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). - """ + Return {'summed': scalar summed loss, + 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). 
+ """ one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( @@ -159,51 +152,69 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + "summed": summed_loss, + "n_valid_examples": n_valid_examples, + "per_example": per_example_losses, } def _compute_metrics(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor) -> Dict[str, spec.Tensor]: - summed_loss = self.loss_fn(labels, logits, weights)['summed'] + summed_loss = self.loss_fn(labels, logits, weights)["summed"] # Number of correct predictions. accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) - metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, - } - metrics = lax.psum(metrics, axis_name='batch') - return metrics + return jnp.array(summed_loss), jnp.array(accuracy) - @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + ) + def _per_device_eval_model( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + weights = batch.get("weights") + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch["targets"], weights) + + losses, accuracies = _per_device_eval_model(params, batch, model_state, rng) + metrics = { + "loss": + jnp.mean(losses, axis=0) if losses.ndim > 0 else losses, + "accuracy": + (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies + ), + } + return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree_map(lambda x: x / num_examples, total_metrics) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..30ecd2b00 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -11,6 +11,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.workloads.criteo1tb.criteo1tb_jax import models +from algoperf import sharding_utils from 
algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload @@ -105,7 +106,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return sharding_utils.shard_replicated(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' @@ -129,13 +130,16 @@ def model_fn( return logits_batch, None @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_batch_pmapped(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding()) + def _eval_batch_jitted(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> spec.Tensor: logits, _ = self.model_fn( params, batch, @@ -156,8 +160,7 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return np.array( - self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..2709ef1e6 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -10,6 +10,7 @@ from algoperf import param_utils from algoperf import spec +from algoperf import sharding_utils import algoperf.random_utils as prng from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim @@ -39,7 +40,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_utils.replicate(params) + params = sharding_utils.shard_replicated(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -94,10 +95,12 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + sharding_utils.get_replicated_sharding()), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding()) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], @@ -126,7 +129,6 @@ def _eval_model(self, 'ssim': ssim_sum, 'loss': summed_loss, } - metrics = jax.lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -154,13 +156,12 @@ def _eval_model_on_split(self, num_batches=num_batches) total_metrics = {'ssim': 0., 'loss': 0.} - eval_rngs = prng.split(model_rng, jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) # We already sum these metrics across devices inside _eval_model. 
- synced_metrics = self._eval_model(params, batch, eval_rngs) + synced_metrics = self._eval_model(params, batch, model_rng) total_metrics = { - k: v + synced_metrics[k][0] for k, v in total_metrics.items() + k: v + synced_metrics[k] for k, v in total_metrics.items() } return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 66105335b..b1be5ac1f 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -399,6 +399,6 @@ def create_input_iter(split: str, ds) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. - it = jax_utils.prefetch_to_device(it, 2) + # it = jax_utils.prefetch_to_device(it, 2) return iter(it) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..87b9b82bd 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -20,6 +20,7 @@ from algoperf import param_utils from algoperf import random_utils as prng +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline @@ -71,17 +72,6 @@ def _build_dataset( use_randaug=use_randaug) return ds - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() # Create a shallow copy - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, @@ -113,18 +103,30 @@ def init_model_fn( model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + mesh = sharding_utils.get_mesh() + params = jax.tree_map( + lambda x: jax.device_put(x, + sharding_utils.get_replicated_sharding(mesh)), + params) + model_state = jax.tree_map( + lambda x: jax.device_put(x, + sharding_utils.get_replicated_sharding(mesh)), + model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_replicated_sharding(), # rng + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -218,7 +220,6 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -231,9 +232,6 @@ def _eval_model_on_split(self, data_dir: str, global_step: int = 0) -> Dict[str, float]: del global_step - if model_state is not None: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) data_rng, eval_rng = prng.split(rng, 2) # We already repeat the dataset indefinitely in tf.data. @@ -250,20 +248,14 @@ def _eval_model_on_split(self, eval_metrics = {} for bi in range(num_batches): eval_rng = prng.fold_in(eval_rng, bi) - step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) - # We already average these metrics across devices inside _compute_metrics. 
- synced_metrics = self._eval_model(params, - batch, - model_state, - step_eval_rngs) + synced_metrics = self._eval_model(params, batch, model_state, eval_rng) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), - eval_metrics) + eval_metrics = jax.tree.map(lambda x: x / num_examples, eval_metrics) return eval_metrics diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..1a2ce6342 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -2,13 +2,13 @@ from typing import Dict, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax import jax.numpy as jnp from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload @@ -46,8 +46,8 @@ def init_model_fn( params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + params = sharding_utils.shard_replicated(params) + model_state = sharding_utils.shard_replicated(model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..d1bf29ba4 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -2,11 +2,9 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax.core import pop import flax.linen as nn import jax -from jax import lax import jax.numpy as jnp import numpy as np import optax @@ -14,6 +12,7 @@ from algoperf import data_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.librispeech_conformer import metrics from algoperf.workloads.librispeech_conformer import workload @@ -93,8 +92,11 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + + # Add sharding + params = sharding_utils.shard_replicated(params) + model_state = sharding_utils.shard_replicated(model_state) + return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -180,6 +182,7 @@ def _build_input_queue( 'targets': (targets.numpy(), target_paddings.numpy()), } + # Use data_utils.shard_and_maybe_pad_np to handle sharding padded_batch = data_utils.shard_and_maybe_pad_np( numpy_batch, padding_value=1.0) yield padded_batch @@ -305,11 +308,16 @@ def greedy_decode( return hyp, hyp_paddings @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) - def eval_step_pmapped( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # 
params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_replicated_sharding(), # rng + ), + out_shardings=sharding_utils.get_naive_sharding_spec(), + static_argnums=(0,)) + def _eval_step( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -325,15 +333,45 @@ def eval_step_pmapped( decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings) loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) - targets, target_paddings = batch['targets'] - return self.metrics_bundle.gather_from_model_output( - loss_dict=loss, - decoded=decoded, - decoded_paddings=decoded_paddings, - targets=targets, - target_paddings=target_paddings, - axis_name='batch') + # Convert metrics bundle to dictionary + metrics_dict = { + 'loss_per_example': + loss['per_example'], + 'decoded': + decoded, + 'decoded_paddings': + decoded_paddings, + 'targets': + targets, + 'target_paddings': + target_paddings, + 'n_valid_examples': + jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + } + return metrics_dict + + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): + """Evaluates the model and returns a metrics bundle.""" + metrics_dict = self._eval_step(params, batch, model_state, rng) + + # Convert dictionary back to metrics bundle + metrics_bundle = self.metrics_bundle.single_from_model_output( + loss_dict={ + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + }, + decoded=metrics_dict['decoded'], + decoded_paddings=metrics_dict['decoded_paddings'], + targets=metrics_dict['targets'], + target_paddings=metrics_dict['target_paddings']) + + return metrics_bundle def _eval_model_on_split(self, split: str, @@ -346,9 +384,6 @@ def _eval_model_on_split(self, global_step: int = 0) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step - if model_state is not None and len(model_state) > 0: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: @@ -358,10 +393,7 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() + computed_metrics = self.eval_step(params, eval_batch, model_state, rng) if metrics_report is None: metrics_report = computed_metrics @@ -373,16 +405,6 @@ def _eval_model_on_split(self, return computed_metrics - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - class LibriSpeechConformerAttentionTemperatureWorkload( LibriSpeechConformerWorkload): diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..1cab9e89e 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -8,8 +8,13 @@ # webpage : https://bastings.github.io/ """ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Type, Mapping, Sequence +from absl import logging +import numpy as np + +import functools +import flax from flax import linen as nn from flax import struct import jax @@ -310,16 +315,12 @@ def __call__(self, inputs, input_paddings=None, train=False): count_v = jnp.sum( jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - sum_v = jax.lax.psum(sum_v, axis_name='batch') - count_v = jax.lax.psum(count_v, axis_name='batch') - count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v variance = (inputs - mean) * (inputs - mean) * mask sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True) - sum_vv = jax.lax.psum(sum_vv, axis_name='batch') var = sum_vv / count_v self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean @@ -389,6 +390,23 @@ def __call__( seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32) if use_cuda: + inputs_shape = np.shape(inputs) + h_0_shape = np.shape(h_0) + c_0_shape = np.shape(c_0) + weights_shape = np.shape(weights) + seq_lengths_np = np.shape(seq_lengths) + + n = jax.devices() + # logging.info(f"jax num devices {n}") + # logging.info(f'inputs shape {inputs_shape}') + # logging.info(f'h_0 shape {h_0_shape}') + # logging.info(f'c_0 shape {c_0_shape}') + # logging.info(f'seq_lengths shape {seq_lengths_np}') + # logging.info(f'weights_shape {weights_shape}') + # logging.info(f'input_size {input_size}') + # logging.info(f'hidden_size {self.features}') + # logging.info(f'num_layers {self.num_layers}') + y, h, c = rnn.lstm( x=inputs, h_0=h_0, c_0=c_0, weights=weights, seq_lengths=seq_lengths, input_size=input_size, @@ -425,6 +443,324 @@ def unpack_weights( ) +### Swap in regular LSTM layer for debuggin +@jax.vmap +def flip_sequences(inputs: Array, lengths: Array) -> Array: + """Flips a sequence of inputs along the time dimension. + + This function can be used to prepare inputs for the reverse direction of a + bidirectional LSTM. It solves the issue that, when naively flipping multiple + padded sequences stored in a matrix, the first elements would be padding + values for those sequences that were padded. This function keeps the padding + at the end, while flipping the rest of the elements. + + Example: + ```python + inputs = [[1, 0, 0], + [2, 3, 0] + [4, 5, 6]] + lengths = [1, 2, 3] + flip_sequences(inputs, lengths) = [[1, 0, 0], + [3, 2, 0], + [6, 5, 4]] + ``` + + Args: + inputs: An array of input IDs [batch_size, seq_length]. + lengths: The length of each sequence [batch_size]. + + Returns: + An ndarray with the flipped inputs. + """ + # Compute the indices to put the inputs in flipped order as per above example. 
+ max_length = inputs.shape[0] + idxs = (jnp.arange(max_length - 1, -1, -1) + lengths) % max_length + return inputs[idxs] + + +class GenericRNNSequenceEncoder(nn.Module): + """Encodes a single sequence using any RNN cell, for example `nn.LSTMCell`. + + The sequence can be encoded left-to-right (default) or right-to-left (by + calling the module with reverse=True). Regardless of encoding direction, + outputs[i, j, ...] is the representation of inputs[i, j, ...]. + + Attributes: + hidden_size: The hidden size of the RNN cell. + cell_type: The RNN cell module to use, for example, `nn.LSTMCell`. + cell_kwargs: Optional keyword arguments for the recurrent cell. + recurrent_dropout_rate: The dropout to apply across time steps. If this is + greater than zero, you must use an RNN cell that implements + `RecurrentDropoutCell` such as RecurrentDropoutOptimizedLSTMCell. + """ + hidden_size: int + cell_type: Type[nn.RNNCellBase] + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + recurrent_dropout_rate: float = 0.0 + + def setup(self): + self.cell = self.cell_type(features=self.hidden_size, **self.cell_kwargs) + + @functools.partial( # Repeatedly calls the below method to encode the inputs. + nn.transforms.scan, + variable_broadcast='params', + in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast), + out_axes=1, + split_rngs={'params': False}) + def unroll_cell(self, + cell_state: StateType, + inputs: Array, + recurrent_dropout_mask: Optional[Array], + deterministic: bool): + """Unrolls a recurrent cell over an input sequence. + + Args: + cell_state: The initial cell state, shape: [batch_size, + hidden_size] (or an n-tuple thereof). + inputs: The input sequence. [batch_size, seq_len, input_dim]. + recurrent_dropout_mask: An optional recurrent dropout mask to apply in + between time steps. [batch_size, hidden_size]. + deterministic: Disables recurrent dropout when set to True. + + Returns: + The cell state after processing the complete sequence (including padding), + and a tuple with all intermediate cell states and cell outputs. + """ + # We do not directly scan the cell itself, since it only returns the output. + # This returns both the state and the output, so we can slice out the + # correct final states later. + new_cell_state, output = self.cell(cell_state, inputs) + return new_cell_state, (new_cell_state, output) + + def __call__(self, + inputs: Array, + lengths: Array, + initial_state: StateType, + reverse: bool = False, + deterministic: bool = False): + """Unrolls the RNN cell over the inputs. + + Arguments: + inputs: A batch of sequences. Shape: [batch_size, seq_len, + input_dim]. + lengths: The lengths of the input sequences. + initial_state: The initial state for the RNN cell. Shape: [batch_size, + hidden_size]. + reverse: Process the inputs in reverse order, and reverse the outputs. + This means that the outputs still correspond to the order of the inputs, + but their contexts come from the right, and not from the left. + deterministic: Disables recurrent dropout if set to True. + + Returns: + The encoded sequence of inputs, shaped [batch_size, seq_len, + hidden_size], as well as the final hidden states of the RNN cell. For an + LSTM cell the final states are a tuple (c, h), each shaped [ + batch_size, hidden_size]. 
+ """ + if reverse: + inputs = flip_sequences(inputs, lengths) + + recurrent_dropout_mask = None + _, (cell_states, outputs) = self.unroll_cell(initial_state, + inputs, + recurrent_dropout_mask, + deterministic) + final_state = jax.tree.map( + lambda x: x[jnp.arange(inputs.shape[0]), lengths - 1], cell_states) + + if reverse: + outputs = flip_sequences(outputs, lengths) + + return outputs, final_state + + +class GenericRNN(nn.Module): + """Generic RNN class. + + This provides generic RNN functionality to encode sequences with any RNN cell. + The class provides unidirectional and bidirectional layers, and these are + stacked when asking for multiple layers. + + This class be used to create a specific RNN class such as LSTM or GRU. + + Attributes: + cell_type: An RNN cell class to use, e.g., `flax.linen.LSTMCell`. + hidden_size: The size of each recurrent cell. + num_layers: The number of stacked recurrent layers. The output of the first + layer, with optional dropout applied, feeds into the next layer. + dropout_rate: Dropout rate to be applied between LSTM layers. Only applies + when num_layers > 1. + recurrent_dropout_rate: Dropout rate to be applied on the hidden state at + each time step repeating the same dropout mask. + bidirectional: Process the sequence left-to-right and right-to-left and + concatenate the outputs from the two directions. + cell_kwargs: Optional keyword arguments to instantiate the cell with. + """ + cell_type: Type[nn.RNNCellBase] + hidden_size: int + num_layers: int = 1 + dropout_rate: float = 0. + recurrent_dropout_rate: float = 0. + bidirectional: bool = False + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + + @nn.compact + def __call__( + self, + inputs: Array, + lengths: Array, + initial_states: Optional[Sequence[StateType]] = None, + deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: + """Processes the input sequence using the recurrent cell. + + Args: + inputs: The input sequence [batch_size, sequence_length, ...] + lengths: The lengths of each sequence in the batch. [batch_size] + initial_states: The initial states for the cells. You must provide + `num_layers` initial states (when using bidirectional, `num_layers * + 2`). + These must be ordered in the following way: (layer_0_forward, + layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, + all initial states will be initialized with zeros. + deterministic: Disables dropout between layers when set to True. + Returns: + The sequence of all outputs for the final layer, and a list of final + states for each cell and direction. Directions are alternated (first + forward, then backward, if bidirectional). For example for a bidirectional + cell this would be: layer 1 forward, layer 1 backward, layer 2 forward, + layer 2 backward, etc.. + For some cells like LSTMCell a state consists of an (c, h) tuple, while + for others cells it only contains a single vector (h,). + """ + batch_size = inputs.shape[0] + final_states = [] + num_directions = 2 if self.bidirectional else 1 + num_cells = self.num_layers * num_directions + + # Construct initial states. + if initial_states is None: # Initialize with zeros. 
+ rng = jax.random.PRNGKey(0) + initial_states = [ + self.cell_type(self.hidden_size).initialize_carry( + rng, (batch_size, 1)) for _ in range(num_cells) + ] + if len(initial_states) != num_cells: + raise ValueError( + f'Please provide {self.num_cells} (`num_layers`, *2 if bidirectional)' + 'initial states.') + + # For each layer, apply the forward and optionally the backward RNN cell. + cell_idx = 0 + for _ in range(self.num_layers): + # Unroll an RNN cell (forward direction) for this layer. + outputs, final_state = GenericRNNSequenceEncoder( + cell_type=self.cell_type, + cell_kwargs=self.cell_kwargs, + hidden_size=self.hidden_size, + recurrent_dropout_rate=self.recurrent_dropout_rate, + name=f'{self.name}SequenceEncoder_{cell_idx}')( + inputs, + lengths, + initial_state=initial_states[cell_idx], + deterministic=deterministic) + final_states.append(final_state) + cell_idx += 1 + + # Unroll an RNN cell (backward direction) for this layer. + if self.bidirectional: + backward_outputs, backward_final_state = GenericRNNSequenceEncoder( + cell_type=self.cell_type, + cell_kwargs=self.cell_kwargs, + hidden_size=self.hidden_size, + recurrent_dropout_rate=self.recurrent_dropout_rate, + name=f'{self.name}SequenceEncoder_{cell_idx}')( + inputs, + lengths, + initial_state=initial_states[cell_idx], + reverse=True, + deterministic=deterministic) + outputs = jnp.concatenate([outputs, backward_outputs], axis=-1) + final_states.append(backward_final_state) + cell_idx += 1 + + inputs = outputs + + return outputs, final_states + + +class LSTM(nn.Module): + """LSTM. + + Attributes: + hidden_size: The size of each recurrent cell. + num_layers: The number of stacked recurrent layers. The output of the first + layer, with optional dropout applied, feeds into the next layer. + dropout_rate: Dropout rate to be applied between LSTM layers. Only applies + when num_layers > 1. + recurrent_dropout_rate: Dropout rate to be applied on the hidden state at + each time step repeating the same dropout mask. + bidirectional: Process the sequence left-to-right and right-to-left and + concatenate the outputs from the two directions. + cell_type: The LSTM cell class to use. Default: + `flax.linen.OptimizedLSTMCell`. If you use hidden_size of >2048, consider + using `flax.linen.LSTMCell` instead, since the optimized LSTM cell works + best for hidden sizes up to 2048. + cell_kwargs: Optional keyword arguments to instantiate the cell with. + """ + hidden_size: int + num_layers: int = 1 + dropout_rate: float = 0. + recurrent_dropout_rate: float = 0. + bidirectional: bool = False + cell_type: Any = nn.OptimizedLSTMCell + cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() + + @nn.compact + def __call__( + self, + inputs: Array, + lengths: Array, + initial_states: Optional[Sequence[StateType]] = None, + deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: + """Processes an input sequence with an LSTM cell. + + Example usage: + ``` + inputs = np.random.normal(size=(2, 3, 4)) + lengths = np.array([1, 3]) + outputs, final_states = LSTM(hidden_size=10).apply(rngs, inputs, lengths) + ``` + + Args: + inputs: The input sequence [batch_size, sequence_length, ...] + lengths: The lengths of each sequence in the batch. [batch_size] + initial_states: The initial states for the cells. You must provide + `num_layers` initial states (when using bidirectional, `num_layers * + 2`). These must be ordered in the following way: (layer_0_forward, + layer_0_backward, layer_1_forward, layer_1_backward, ...). 
If None, + all initial states will be initialized with zeros. + deterministic: Disables dropout between layers when set to True. + + Returns: + The sequence of all outputs for the final layer, and a list of final + states (h, c) for each cell and direction, ordered first by layer number + and then by direction (first forward, then backward, if bidirectional). + """ + return GenericRNN( + cell_type=self.cell_type, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + dropout_rate=self.dropout_rate, + recurrent_dropout_rate=self.recurrent_dropout_rate, + bidirectional=self.bidirectional, + cell_kwargs=self.cell_kwargs, + name='LSTM')( + inputs, + lengths, + initial_states=initial_states, + deterministic=deterministic) + + class BatchRNN(nn.Module): """Implements a single deepspeech encoder layer. """ @@ -443,6 +779,18 @@ def __call__(self, inputs, input_paddings, train): config.batch_norm_epsilon)(inputs, input_paddings, train) + + # For regular LSTM + hidden_size = ( + config.encoder_dim // 2 if config.bidirectional else config.encoder_dim) + lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32) + + # output, _ = LSTM( + # hidden_size=hidden_size, + # bidirectional=config.bidirectional, + # num_layers=1, + # )(inputs, lengths) + output = CudnnLSTM( features=config.encoder_dim // 2, bidirectional=config.bidirectional, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..c29e952cb 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -4,10 +4,13 @@ from flax import jax_utils import jax import jax.numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P import numpy as np from algoperf import param_utils from algoperf import spec +from algoperf import sharding_utils from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models @@ -51,8 +54,8 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + model_state = sharding_utils.shard_replicated(model_state) + params = sharding_utils.shard_replicated(params) return params, model_state def model_fn( @@ -65,6 +68,21 @@ def model_fn( update_batch_norm: bool, use_running_average_bn: Optional[bool] = None ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + + model_fn_sharded = shard_map(model_fn_ref, + self.mesh, + ) + + def model_fn_ref( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/lm/__init__.py b/algoperf/workloads/lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/input_pipeline.py 
b/algoperf/workloads/lm/input_pipeline.py new file mode 100644 index 000000000..db345700e --- /dev/null +++ b/algoperf/workloads/lm/input_pipeline.py @@ -0,0 +1,154 @@ +"""Input pipeline for a LM dataset.""" +import functools +import os +from typing import Optional + +import jax +import jax.numpy as jnp +import tensorflow as tf +import torch +import torch.nn.functional as F +from transformers import GPT2Tokenizer + +from algoperf import data_utils +from algoperf.pytorch_utils import pytorch_setup +from datasets import load_dataset +from datasets import load_from_disk + +RANK = pytorch_setup()[1] +# Avoid multithreading in all processes but the first (rank 0). +# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# automatic optimization (AUTOTUNE), while other processes disable it (None). +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine +# the optimal number of elements to prefetch or parallelize for dataset +# operations, improving performance. +AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None + + +def get_hf_dataloader(cache_dir: str, + data_rng: jax.random.PRNGKey, + batch_size: int = 8, + seq_len: int = 32, + framework: str = "torch", + split="train"): + """ + Create a data loader from HuggingFace's FineWeb dataset. + + Args: + cache_dir: Directory to cache the dataset + batch_size: Number of sequences per batch + seq_len: Length of each sequence + framework: Either "torch" or "jax" to specify output tensor type + split: Dataset split to load + """ + # Initialize tokenizer and get vocab size + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + # Load the FineWeb dataset in streaming mode + fw = load_dataset( + "HuggingFaceFW/fineweb-edu", + name="sample-10BT", + split=split, + streaming=True, + cache_dir=cache_dir) + fw = fw.batch(batch_size=batch_size, drop_last_batch=True) + if split in ['train', 'eval_train']: + fw = fw.shuffle(seed=int(data_rng[-1])) + + def _tokenize(x): + """Tokenize and pad text to seq_len+1 tokens.""" + if framework == "torch": + tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) + elif framework == "jax": + tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = jnp.pad( + tokens, + pad_length, + mode="constant", + constant_values=tokenizer.pad_token_id) + return tokens[:seq_len + 1] + + def batch_iterator(): + for doc in fw: + if framework == "torch": + token_ids = torch.stack([_tokenize(x) for x in doc['text']]) + # Take first seq_len+1 tokens and convert to one-hot + tokens = F.one_hot(token_ids, num_classes=vocab_size).float() + # Split into input/target + inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] + inputs, targets = inputs.to("cuda"), targets.to("cuda") + elif framework == "jax": + token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) + tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) + inputs, targets = tokens[:, :-1], tokens[:, 1:] + inputs, targets = jax.device_put(inputs), jax.device_put(targets) + yield {'inputs': inputs, 'targets': targets} + + return batch_iterator() + + +def get_lm_dataset(data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None): + """Load HF dataset and return a TF dataset.""" + + dataset_path = 
os.path.join(data_dir, split) + dataset = load_from_disk(dataset_path) + + is_training = split == "train" + shuffle = split in ['train', 'eval_train'] + + dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? + + def tf_generator(): + """Generates data in a TensorFlow-friendly format.""" + for example in dataset: + yield { + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], + } + + # Create a TensorFlow dataset + ds = tf.data.Dataset.from_generator( + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + }) + + # Avoid creating too many threads when using PyTorch DDP. + # Limits TensorFlow's threading for non-primary processes (RANK != 0) + if RANK != 0: + options = tf.data.Options() + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) + + if shuffle: + ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) + + if is_training: + ds = ds.repeat() + + # Batch the dataset, grouping consecutive elements into fixed-size chunks. + ds = ds.batch(global_batch_size, drop_remainder=is_training) + ds = ds.prefetch(AUTOTUNE) + + # Limit the dataset to a fixed number of batches if `num_batches` is specified + if num_batches: + ds = ds.take(num_batches) + + # Shard the dataset across multiple GPUs/TPUs if necessary + ds = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) + + return ds diff --git a/algoperf/workloads/lm/lm_jax/__init__.py b/algoperf/workloads/lm/lm_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py new file mode 100644 index 000000000..72ee5bd83 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -0,0 +1,19 @@ +from flax import linen as nn +import jax.numpy as jnp + +class LinearModel(nn.Module): + vocab_size: int + + @nn.compact + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense( + 10, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(inputs) + return nn.Dense( + self.vocab_size, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros, + name="output" + )(x) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py new file mode 100644 index 000000000..d21fd5090 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -0,0 +1,345 @@ +# Self-contained version of the DecoderOnly Transformer from NanoDO + +import dataclasses +from functools import partial + +from flax import linen as nn +import jax +import jax.numpy as jnp + +# =========== Transformer Decoder-only Model ========== + + + +@dataclasses.dataclass +class DoConfig: + """Hyper-parameters for Transformer decoder-only.""" + + D: int # model/embed dim = qkv dim + H: int # num attention heads + L: int # max context/sequence length + N: int # number of transformer block layers + V: int # vocab size + F: int # FF inner dimension + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() + embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", out_axis=0 + ) + dtype: jnp.dtype = jnp.float32 + rmsnorm_epsilon: float = 1e-6 + multiple_of: int = 256 + tie_embeddings: bool = True # Whether to tie input and output embeddings + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + 
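+  # Illustrative sizing sketch (our own example, not from the source): with
+  # e.g. D=128, F=4*128=512 and multiple_of=256 (the demo config in main()
+  # below), hidden_dim rounds up to 256 * ceil(512 / 256) = 512, the first
+  # Dense emits 2 * 512 = 1024 features, the GLU halves that back to 512,
+  # and the final Dense projects 512 -> D = 128.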
+ cfg: DoConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + # Use Xavier uniform initialization explicitly + xavier_init = nn.initializers.xavier_uniform() + linear = partial( + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + ) + hidden_dim = cfg.multiple_of * ( + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = linear(cfg.D)(x_BxLxF) + return x_BxLxD + +@partial(jax.jit, static_argnums=(0,1,2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack([ + jnp.cos(freqs)[None, :, None, :], + jnp.sin(freqs)[None, :, None, :] + ], axis=3) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack([ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] + ], axis=-1) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: DoConfig + + def setup(self): + cfg = self.cfg + assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" + self.Dh = cfg.D // cfg.H + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.H, self.Dh), + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + self.multilinear_query = self.multilinear(name="query") + self.multilinear_key = self.multilinear(name="key") + self.multilinear_value = self.multilinear(name="value") + self.output_projection = nn.DenseGeneral( + features=cfg.D, + name="attn_out_proj", + # axis=(-2, -1), # + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Scale queries + q_BxLxHxDh /= self.Dh**0.5 + + # Compute attention scores + att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", 
att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: DoConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: DoConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.V, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name="output_proj" + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. + + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + batch_size = y_BxL.shape[0] + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.L: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + model = TransformerDo(cfg) + + # Print model info + print(f"\nModel Configuration:") + print(f" - Model dimension (D): {cfg.D}") + print(f" - Number of heads (H): {cfg.H}") + print(f" - Max sequence length (L): {cfg.L}") + print(f" - Number of layers (N): {cfg.N}") + print(f" - Vocabulary size (V): {cfg.V}") + print(f" - Feed forward dimension (F): {cfg.F}") + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token 
IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + ) + + # Initialize model parameters + print("\nInitializing model parameters...") + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f"Total parameters: {param_count:,}") + + # Make a prediction (forward pass) + print("\nRunning forward pass...") + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") + print(f"Output data type: {logits.dtype}") + + # Print sample logits (first 5 positions of the first sequence) + print("\nSample logits (first sequence, first 5 positions, first 5 values):") + for position in range(min(5, L)): + print(f" Position {position}: {logits[0, position, :5]}") + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + # Test the predict function + print("\nTesting predict function...") + # Use a shorter + short_seq = x_BxL[:, :10] + print(f"Input sequence shape: {short_seq.shape}") + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py new file mode 100644 index 000000000..e73a5bfaf --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -0,0 +1,150 @@ +"""LM workload implemented in Jax.""" + +from typing import Dict, Optional, Tuple + +import jax +import jax.numpy as jnp +import optax +from flax import jax_utils +from algoperf import param_utils +from algoperf import sharding_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + TransformerDo, DoConfig, init_rope, apply_rope) + + +class LmWorkload(BaseLmWorkload): + """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using pre-cached FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_lm_dataset( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + return loader + + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using HuggingFace FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="jax", + split=split) + return loader + + def init_model_fn( + self, + rng: spec.RandomState, + 
dropout_rate: Optional[float] = None,
+      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+
+    # Initialize NanoDO transformer model
+    cfg = DoConfig(
+        D=512,  # model dim
+        H=8,  # num heads
+        L=self._seq_len,
+        N=6,  # num layers
+        V=self._vocab_size,
+        F=2048,  # feedforward dim
+        dtype=jnp.float32
+    )
+    self._model = TransformerDo(cfg)
+    input_shape = (1, self._seq_len)  # For token IDs
+
+    params_rng, init_rng = jax.random.split(rng)
+    variables = jax.jit(self._model.init)({'params': params_rng},
+                                          jnp.ones(input_shape, jnp.int32))
+    params = variables['params']
+    self._param_shapes = param_utils.jax_param_shapes(params)
+    self._param_types = param_utils.jax_param_types(self._param_shapes)
+    params = sharding_utils.shard_replicated(params)
+    model_state = None
+    return params, model_state
+
+  def model_fn(
+      self,
+      params: spec.ParameterContainer,
+      batch: Dict[str, spec.Tensor],
+      model_state: spec.ModelAuxiliaryState,
+      mode: spec.ForwardPassMode,
+      rng: spec.RandomState,
+      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+    del mode, rng, update_batch_norm, model_state
+    inputs = batch['inputs']
+    # Convert one-hot inputs to token IDs if needed
+    if inputs.ndim == 3:  # one-hot encoded
+      inputs = jnp.argmax(inputs, axis=-1)
+    logits = self._model.apply({'params': params}, inputs)
+    return logits, None
+
+  def loss_fn(
+      self,
+      label_batch: spec.Tensor,
+      logits_batch: spec.Tensor,
+      mask_batch: Optional[spec.Tensor] = None,
+      label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:
+    """Compute cross-entropy loss for language modeling in JAX."""
+    # Convert one-hot labels to token IDs if needed
+    if len(label_batch.shape) == len(logits_batch.shape):  # one-hot
+      label_batch = jnp.argmax(label_batch, axis=-1)
+
+    # Reshape for sequence modeling
+    logits = logits_batch.reshape(-1, logits_batch.shape[-1])
+    labels = label_batch.reshape(-1)
+
+    # Per-token negative log-likelihood.
+    per_example_losses = -jax.nn.log_softmax(logits)[
+        jnp.arange(labels.shape[0]), labels]
+
+    if mask_batch is not None:
+      mask = mask_batch.reshape(-1)
+      per_example_losses = per_example_losses * mask
+      n_valid = mask.sum()
+    else:
+      n_valid = labels.shape[0]
+
+    return {
+        'summed': per_example_losses.sum(),
+        'n_valid_examples': n_valid,
+        'per_example': per_example_losses  # Per-token losses.
+    }
+
+  def is_output_params(self, param_name: str) -> bool:
+    """Return whether the given parameter is an output parameter."""
+    return 'output' in param_name
+
+  def _eval_batch(self,
+                  params: spec.ParameterContainer,
+                  batch: Dict[str, spec.Tensor],
+                  model_state: spec.ModelAuxiliaryState,
+                  rng: spec.RandomState) -> spec.Tensor:
+    """Evaluate the model on a single batch."""
+    logits, _ = self.model_fn(
+        params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
+    targets = batch['targets']
+
+    # Calculate cross-entropy loss
+    loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1))
+    return loss
diff --git a/algoperf/workloads/lm/lm_pytorch/__init__.py b/algoperf/workloads/lm/lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb
diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py new file mode 100644 index 000000000..545763924 --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/models.py @@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+
+class LinearLayer(nn.Module):
+  def __init__(self, vocab_size: int):
+    super().__init__()
+    self.bottleneck = nn.Linear(vocab_size, 512)
+    self.output = nn.Linear(512, vocab_size)
+    self.reset_parameters()
+
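+  # Shape sketch (illustrative, based on the test file added later in this
+  # PR): inputs are one-hot tensors of shape (batch, seq_len, vocab_size);
+  # `bottleneck` maps vocab_size -> 512 and `output` maps 512 -> vocab_size,
+  # so the logits keep the input shape.
+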
def reset_parameters(self): + nn.init.normal_(self.bottleneck.weight, std=0.02) + nn.init.zeros_(self.bottleneck.bias) + nn.init.normal_(self.output.weight, std=0.02) + nn.init.zeros_(self.output.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output(self.bottleneck(x)) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py new file mode 100644 index 000000000..627a0e16d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -0,0 +1,298 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn +from dataclasses import dataclass +from typing import Tuple + + + +@dataclass +class ModelConfig: + vocab_size: int + seq_len: int + dim: int + expand: float + n_layers: int + n_heads: int + rmsnorm_eps: float = 1e-6 + tie_embeddings: bool = False + + +class MLP(nn.Module): + + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + + # Initialize with Xavier uniform + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0, + condense_ratio: int = 1): + inv_freqs = 1.0 / (theta**(torch.arange( + 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) + t = torch.arange(end, dtype=torch.float32, + device=inv_freqs.device) / condense_ratio + freqs = torch.outer(t, inv_freqs).float() + return torch.stack([ + torch.cos(freqs)[None, :, None, :], + torch.sin(freqs)[None, :, None, :] + ], + dim=4) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, + freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - + qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.dim % cfg.n_heads == 0 + self.dim = cfg.dim + self.n_heads = cfg.n_heads + self.head_dim = cfg.dim // cfg.n_heads + + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + k = k.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + v = v.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) + + out = 
out.transpose(1, 2).contiguous().view(bsz, seqlen, + d) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, cfg): + super().__init__() + self.n_layers = cfg.n_layers + self.cfg = cfg + head_dim = cfg.dim // cfg.n_heads + assert cfg.dim % cfg.n_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.n_layers)]) + self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer('freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len], + persistent=False) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.dim // self.cfg.n_heads + new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) + self.register_buffer('freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + + def predict(self, x, k=1): + """Generate k tokens autoregressively. 
+ + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + # For debugging + predictions = [] + + batch_size = x.shape[0] + seq_len = x.shape[1] + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Print top 5 tokens for debugging + if i == 0: + print("\nPyTorch detailed prediction:") + top5_values, top5_indices = torch.topk(next_token_logits[0], 5) + for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): + prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() + print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + predictions.append(next_token.item()) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + print(f" Full predictions step by step: {predictions}") + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith("fc2.weight"): # mlp/glu output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + if n.endswith("w_out.weight"): # attn output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print("Initializing transformer model and running forward pass...") + + seq_length = 512 + + # Define model configuration + config = ModelConfig( + vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + dim=768, # Embedding dimension + expand=4.0, # MLP expansion factor + n_layers=12, # Number of transformer layers + n_heads=12, # Number of attention heads + rmsnorm_eps=1e-6, # RMSNorm epsilon + tie_embeddings=True # Tie embedding and output weights + ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no 
weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py new file mode 100644 index 000000000..36e441e7e --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -0,0 +1,179 @@ +"""LM workload implemented in PyTorch.""" + +from typing import Dict, Iterator, Optional, Tuple + +import jax +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + + if hasattr(self, '_model'): + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() + return self._model, None + + torch.manual_seed(rng[0]) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + dim=512, # Model dimension + expand=4, # MLP expansion factor + n_layers=6, # Number of transformer layers + n_heads=8, # Number of attention heads + rmsnorm_eps=1e-6, + tie_embeddings=True + ) + self._model = Transformer(cfg) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + + del model_state, rng, update_batch_norm + model = params + + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + + logits = model(inputs) + return logits, None + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len + weights = None + + dtype = torch.long + is_train = split == 'train' + + for batch in loader: + inputs = batch['inputs'] + targets = batch['targets'] + + if USE_PYTORCH_DDP: + if not is_train: + # During eval, the batch size of the remainder might be different + per_device_batch_size = torch.tensor( + targets.shape[0], dtype=dtype, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + + # Broadcast to all devices 
+ dist.broadcast(inputs, src=0) + dist.broadcast(targets, src=0) + + if weights is None: + batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item() + weights = torch.ones((batch_size, seq_len), device=DEVICE) + batch = { + 'inputs': inputs, + 'targets': targets, + 'weights': weights, + } + yield batch + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + + # Handle both one-hot and token ID targets + targets = batch['targets'] + if targets.dim() == 3: # one-hot + loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) + else: # token IDs + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction='sum' + ) + return loss + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in PyTorch.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) + loss = -torch.sum(label_batch * log_probs, dim=-1) + else: + # Dense labels + loss = torch.nn.functional.cross_entropy( + logits_batch, + label_batch, + reduction='none') + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py new file mode 100644 index 000000000..639e71491 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -0,0 +1,81 @@ +import jax +import torch + +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # batch = next(input_queue) + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs: {inputs}") + + # Start test. 
+ for _ in range(100): + + batch = next(input_queue) + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:, 1:], targets[:, :-1]) + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py new file mode 100644 index 000000000..36bab0d02 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py @@ -0,0 +1,116 @@ +"""Tests for LM HuggingFace input pipeline.""" +import os + +import jax +import jax.numpy as jnp +import torch +from transformers import GPT2Tokenizer + +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + +def main(): + # Setup test environment + cache_dir = "/home/ak4605/data" + if not os.path.exists(cache_dir): + raise FileNotFoundError(f"Cache directory {cache_dir} not found") + + data_rng = jax.random.PRNGKey(42) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + + print("Running JAX output shapes and types test...") + batch_size = 8 + seq_len = 32 + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == jnp.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == jnp.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert jnp.all(jnp.sum(inputs, axis=-1) == 1), "Inputs should be one-hot encoded" + assert jnp.all(jnp.sum(targets, axis=-1) == 1), "Targets should be one-hot encoded" + print("✓ JAX test passed") + + print("\nRunning Torch output shapes and types test...") + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="torch", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == torch.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == torch.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert torch.all(torch.sum(inputs, dim=-1) == 1), "Inputs should be one-hot encoded" + assert torch.all(torch.sum(targets, dim=-1) == 1), "Targets should be one-hot encoded" + print("✓ Torch test passed") + + print("\nTesting consistent batching with same 
seed...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Input batches should be identical with same seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Target batches should be identical with same seed" + print("✓ Consistent batching test passed") + + print("\nTesting eval split doesn't shuffle...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(999)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Eval inputs should be identical regardless of seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Eval targets should be identical regardless of seed" + print("✓ Eval no shuffling test passed") + + print("\nAll tests passed successfully!") + + +if __name__ == "__main__": + main() diff --git a/algoperf/workloads/lm/tests/test_linear_model.py b/algoperf/workloads/lm/tests/test_linear_model.py new file mode 100644 index 000000000..31cd1d577 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_linear_model.py @@ -0,0 +1,39 @@ +import jax +import jax.numpy as jnp +import torch + +TEST_SEQ_LEN = 512 + +def test_pytorch_linear(): + from algoperf.workloads.lm.lm_pytorch.models import LinearLayer + vocab_size = 32000 + model = LinearLayer(vocab_size) + + batch_size = 8 + seq_len = TEST_SEQ_LEN + inputs = torch.randn(batch_size, seq_len, vocab_size) + outputs = model(inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not torch.isnan(outputs).any() + +def test_jax_linear(): + from algoperf.workloads.lm.lm_jax.models import LinearModel + + vocab_size = 32000 + seq_len = TEST_SEQ_LEN + batch_size = 8 + model = LinearModel(vocab_size) + rng = jax.random.PRNGKey(0) + params = model.init(rng, jnp.ones((1, seq_len, vocab_size))) + + inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size)) + outputs = model.apply(params, inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not jnp.isnan(outputs).any() + +if __name__ == '__main__': + test_pytorch_linear() + test_jax_linear() + print("All tests passed!") diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py new file mode 100644 index 000000000..6b71c7952 --- /dev/null +++ b/algoperf/workloads/lm/workload.py @@ -0,0 +1,179 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Dict, Optional + +from absl import flags +import jax +import torch.distributed as dist + +from algoperf import spec +from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + +FLAGS = flags.FLAGS + +USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ + + +class BaseLmWorkload(spec.Workload): + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 5 + warmup_factor: float = 0.1 + + def __init__(self) -> None: + super().__init__() + self._param_shapes = None + 
self._param_types = None
+
+  @property
+  def target_metric_name(self) -> str:
+    """The name of the target metric (useful for scoring/processing code)."""
+    return 'ppl'
+
+  def has_reached_validation_target(self, eval_result: Dict[str, float]) -> bool:
+    return eval_result['validation/ppl'] <= self.validation_target_value
+
+  @property
+  def validation_target_value(self) -> float:
+    return 20.0  # Target perplexity
+
+  def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
+    return eval_result['test/ppl'] <= self.test_target_value
+
+  @property
+  def test_target_value(self) -> float:
+    return 20.0  # Target perplexity
+
+  @property
+  def loss_type(self) -> spec.LossType:
+    return spec.LossType.SOFTMAX_CROSS_ENTROPY
+
+  @property
+  def num_train_examples(self) -> int:
+    return 1000000  # Example size
+
+  @property
+  def num_eval_train_examples(self) -> int:
+    return 10000  # Subset for evaluation
+
+  @property
+  def num_validation_examples(self) -> int:
+    return 50000
+
+  @property
+  def num_test_examples(self) -> int:
+    return 50000
+
+  @property
+  def eval_batch_size(self) -> int:
+    return 8
+
+  @property
+  def train_mean(self):
+    raise NotImplementedError
+
+  @property
+  def train_stddev(self):
+    raise NotImplementedError
+
+  @property
+  def max_allowed_runtime_sec(self) -> int:
+    return 3600 * 4  # 4 hours
+
+  @property
+  def eval_period_time_sec(self) -> int:
+    return 600  # 10 minutes
+
+  @property
+  def step_hint(self) -> int:
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 100000
+
+  @property
+  def pre_ln(self) -> bool:
+    return True
+
+  @property
+  def attention_temp(self) -> float:
+    return 1.0
+
+  @property
+  def activation(self) -> str:
+    return 'silu'
+
+  @property
+  def glu(self) -> bool:
+    return True
+
+  @abc.abstractmethod
+  def _build_input_queue(self,
+                         data_rng: jax.random.PRNGKey,
+                         split: str,
+                         data_dir: str,
+                         global_batch_size: int,
+                         num_batches: Optional[int] = None,
+                         repeat_final_dataset: bool = False):
+    """Build an input queue for the given split."""
+
+  def _eval_batch(self,
+                  params: spec.ParameterContainer,
+                  batch: Dict[str, spec.Tensor],
+                  model_state: spec.ModelAuxiliaryState,
+                  rng: spec.RandomState) -> spec.Tensor:
+    """Evaluate the model on a single batch."""
+    logits, _ = self.model_fn(
+        params,
+        batch,
+        model_state,
+        spec.ForwardPassMode.EVAL,
+        rng,
+        update_batch_norm=False)
+
+    loss_dict = self.loss_fn(batch['targets'], logits)
+    return loss_dict['summed']
+
+  def _eval_model_on_split(self,
+                           split: str,
+                           num_examples: int,
+                           global_batch_size: int,
+                           params: spec.ParameterContainer,
+                           model_state: spec.ModelAuxiliaryState,
+                           rng: spec.RandomState,
+                           data_dir: str,
+                           global_step: int = 0) -> Dict[str, float]:
+    """Run a full evaluation of the model."""
+    num_batches = int(math.ceil(num_examples / global_batch_size))
+    if split not in self._eval_iters:
+      # These iterators will repeat indefinitely.
+      self._eval_iters[split] = self._build_input_queue(
+          rng,
+          split,
+          data_dir,
+          global_batch_size,
+          num_batches,
+          repeat_final_dataset=True)
+
+    loss = 0.0
+    for _ in range(num_batches):
+      eval_batch = next(self._eval_iters[split])
+      loss += self._eval_batch(params, eval_batch, model_state, rng)
+    if USE_PYTORCH_DDP:
+      dist.all_reduce(loss)
+    mean_loss = loss.item() / num_examples
+    return {'loss': mean_loss}
+
+  # Does NOT apply regularization, which is left to the submitter to do in
+  # `update_params`.
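+  # Illustrative relation between the summed eval loss above and the 'ppl'
+  # target metric (a sketch, not part of the spec): perplexity is the
+  # exponential of the mean per-token loss, e.g.
+  #   mean_loss = total_summed_loss / total_valid_tokens
+  #   ppl = math.exp(mean_loss)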
+ @abc.abstractmethod + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..15fc6ff89 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -6,11 +6,11 @@ from flax import jax_utils from flax import linen as nn import jax -from jax import lax import jax.numpy as jnp import optax from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.mnist.workload import BaseMnistWorkload @@ -46,7 +46,7 @@ def init_model_fn( train=True)['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_1' @@ -101,10 +101,14 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + static_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, @@ -125,11 +129,14 @@ def _eval_model( (jnp.argmax(logits, axis=-1) == batch['targets']) * weights) summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed'] metrics = {'accuracy': accuracy, 'loss': summed_loss} - metrics = lax.psum(metrics, axis_name='batch') return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) + total_metrics = { + 'accuracy': total_metrics['accuracy'].item() / num_examples, + 'loss': total_metrics['loss'].item() / num_examples + } + return total_metrics diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index 3cb6f51de..9506a5343 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -148,10 +148,15 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): weights_shards.append(weights) if count == num_shards: + # yield { + # 'inputs': jraph.batch(graphs_shards), + # 'targets': np.vstack(labels_shards), + # 'weights': np.vstack(weights_shards) + # } def f(x): - return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) - + return jax.tree.map(lambda *vals: np.concatenate(vals, axis=0), x[0], *x[1:]) + graphs_shards = f(graphs_shards) labels_shards = f(labels_shards) weights_shards = f(weights_shards) diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..9607a14e1 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -2,6 +2,7 @@ # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. 
from typing import Optional, Tuple +import jax from flax import linen as nn import jax.numpy as jnp import jraph @@ -78,7 +79,8 @@ def __call__(self, graph, train): self.hidden_dims, dropout=dropout, activation_fn=activation_fn), update_global_fn=_make_mlp( self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) - + # jax.debug.print(str(graph)) + graph = net(graph) # Map globals to represent the final result diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..347f89721 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -8,6 +8,7 @@ import jraph import optax +from algoperf import sharding_utils from algoperf import param_utils from algoperf import spec from algoperf.workloads.ogbg import metrics @@ -45,7 +46,8 @@ def init_model_fn( params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(params), None + params = sharding_utils.shard_replicated(params) + return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_17' @@ -106,11 +108,20 @@ def _eval_metric(self, labels, logits, masks): return metrics.EvalMetrics.single_from_model_output( loss=loss['per_example'], logits=logits, labels=labels, mask=masks) + # @functools.partial( + # jax.pmap, + # axis_name='batch', + # in_axes=(None, 0, 0, 0, None), + # static_broadcasted_argnums=(0,)) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), + sharding_utils.get_naive_sharding_spec(), + sharding_utils.get_replicated_sharding(), + sharding_utils.get_replicated_sharding()), + static_argnums=(0,), + out_shardings=sharding_utils.get_replicated_sharding(), + ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) @@ -119,7 +130,8 @@ def _normalize_eval_metrics( Any]) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples - total_metrics = total_metrics.reduce() + # total_metrics = total_metrics.reduce() + print(total_metrics) return {k: float(v) for k, v in total_metrics.compute().items()} diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 971e7f0f6..14666c081 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -161,6 +161,7 @@ def _eval_batch(self, spec.ForwardPassMode.EVAL, rng, update_batch_norm=False) + jax.debug.print(str(logits)) return self._eval_metric(batch['targets'], logits, batch['weights']) def _eval_model_on_split(self, diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py index ad314a7d3..5e175320a 100644 --- a/algoperf/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -283,8 +283,7 @@ def ref_stats(output, refs): closest_diff = diff closest_len = reflen elif diff == closest_diff: - if reflen < closest_len: - closest_len = reflen + closest_len = min(reflen, closest_len) ngrams_ref = extract_ngrams(ref) for ngram in ngrams_ref: diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..dce08a677 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -14,6 +14,7 @@ import optax from algoperf 
import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_jax import decode @@ -69,10 +70,16 @@ def compute_weighted_cross_entropy( } @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) - def eval_step_pmapped( - self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + ), + static_argnums=(0,), # self + ) + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" inputs = batch['inputs'] targets = batch['targets'] @@ -90,29 +97,29 @@ def eval_step_pmapped( 'denominator': weight_sum, } - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: - replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) - @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_naive_sharding_spec(), # inputs + ), + static_argnums=( + 0, + 2, + )) def initialize_cache(self, inputs: spec.Tensor, max_decode_len: int = 256) -> Dict[str, spec.Tensor]: """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] - initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), - jnp.ones(inputs.shape, jnp.float32), + dummy_inputs = sharding_utils.shard_naive( + jnp.ones(inputs.shape, jnp.float32)) + dummy_targets = sharding_utils.shard_naive( jnp.ones(target_shape, jnp.float32)) + initial_variables = models.Transformer(config).init( + jax.random.PRNGKey(0), dummy_inputs, dummy_targets) return initial_variables['cache'] - # eos_id, max_decode_len are constant. - @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)) def predict_step(self, inputs: spec.Tensor, params: spec.ParameterContainer, @@ -180,20 +187,36 @@ def translate_and_calculate_bleu(self, """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] + jitted_predict_step = None for _ in range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - predicted = self.predict_step(pred_batch['inputs'], - params, - cache, - decode.EOS_ID, - max_predict_length) - predicted = _to_host(predicted) - targets = _to_host(pred_batch['targets']) + cache = sharding_utils.shard_naive(cache) + if jitted_predict_step is None: + jitted_predict_step = jax.jit( + self.predict_step, + in_shardings=( + sharding_utils.get_naive_sharding_spec(), # inputs + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_tree(cache), # cache + ), + static_argnums=( + 3, # eos_id + 4, # max_decode_len, + 5, # beam_size + )) + predicted = jitted_predict_step(pred_batch['inputs'], + params, + cache, + decode.EOS_ID, + max_predict_length) + # predicted = _to_host(predicted) + # targets = _to_host(pred_batch['targets']) + targets = pred_batch['targets'] # Find actual batch size, ignoring the potential padding. 
weights = pred_batch.get('weights') if weights is not None: - weights = _to_host(weights) + # weights = _to_host(weights) actual_batch_size = int(weights.sum(0)[0].item()) else: actual_batch_size = len(predicted) @@ -213,7 +236,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" - init_fake_batch_size = 2 + init_fake_batch_size = 8 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -235,15 +258,21 @@ def init_model_fn( eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) params_rng, dropout_rng = jax.random.split(rng) + inputs = jnp.ones(input_shape, jnp.float32) + targets = jnp.ones(target_shape, jnp.float32) + sharded_inputs = sharding_utils.shard_naive(inputs) + sharded_targets = sharding_utils.shard_naive(targets) + initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + sharded_inputs, + sharded_targets) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + params = sharding_utils.shard_replicated(initial_params) + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding' diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4712f4e25..6b99a25a6 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -114,6 +114,7 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, + 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' }, @@ -150,6 +151,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', + 'lm', 'ogbg', 'wmt' ] diff --git a/datasets/README.md b/dataset/README.md similarity index 100% rename from datasets/README.md rename to dataset/README.md diff --git a/datasets/dataset_setup.py b/dataset/dataset_setup.py similarity index 89% rename from datasets/dataset_setup.py rename to dataset/dataset_setup.py index efe923dbe..747d06d27 100644 --- a/datasets/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -74,15 +74,20 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import \ normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer + +import datasets as hf_datasets +from transformers import AutoTokenizer import functools +import itertools import os import shutil import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -120,6 +125,9 @@ flags.DEFINE_boolean('fastmri', False, 'If --all=false, whether or not to download FastMRI.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') 
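+# Hypothetical example invocation for the new flag (paths are placeholders,
+# mirroring the example command at the top of this file):
+#   python3 dataset/dataset_setup.py --data_dir=~/data \
+#     --temp_dir=/tmp/mlcommons_data --finewebedu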
flags.DEFINE_boolean('imagenet', False, 'If --all=false, whether or not to download Imagenet.') @@ -183,6 +191,7 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean('skip_tokenization', False, 'Skip Fineweb-edu tokenization.') FLAGS = flags.FLAGS @@ -699,6 +708,93 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) +def download_finewebedu(data_dir, + tmp_dir=None, + skip_download=False, + skip_tokenization=False): + """Download FineWebEdu-10B.""" + + if not skip_download: + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ["TMPDIR"] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info("Tokenizing...") + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ], + batched=True, + batch_size=1024, + num_proc=8) + + tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + else: + tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + + # Convert to tensorflow_datasets.Dataset objects + tokenized_dataset = tokenized_dataset.to_tf_dataset() + + # Shuffle dataset + dataset_size = tokenized_dataset.cardinality().numpy() + shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0) + train_size = int(0.9 * dataset_size) + train_dataset = shuffled_dataset.take(train_size) + val_dataset = shuffled_dataset.skip(train_size) + + # Split in train and valid. 
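# Illustrative sketch (not from this patch): `shuffle().take()/.skip()` only
# produces a disjoint train/val split if the shuffle order is identical on
# every pass over the pipeline. Since the two `save()` calls below each
# iterate the dataset separately, pinning `reshuffle_each_iteration=False` is
# one way to keep the split stable:
import tensorflow as tf


def split_train_val(ds: tf.data.Dataset, train_fraction=0.9, seed=0):
  """Deterministic train/val split of a finite tf.data.Dataset."""
  n = ds.cardinality().numpy()
  shuffled = ds.shuffle(n, seed=seed, reshuffle_each_iteration=False)
  train_size = int(train_fraction * n)
  return shuffled.take(train_size), shuffled.skip(train_size)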
+ train_dataset.save(os.path.join(data_dir, "train")) + val_dataset.save(os.path.join(data_dir, "val")) + + return + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -781,6 +877,10 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 99% rename from datasets/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index a8c5cae1d..b96881332 100644 --- a/datasets/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets import librispeech_tokenizer +from dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py diff --git a/pyproject.toml b/pyproject.toml index cc404f4b5..5e9c21f47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ version_file = "algoperf/_version.py" [project.optional-dependencies] # All workloads full = [ - "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", ] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] @@ -96,6 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] +lm = ["transformers==4.25.4", "datasets==3.6.0"] # Frameworks jax_core_deps = [ @@ -106,15 +107,11 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", + "jax", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "jax-cuda12-plugin[with_cuda]==0.4.28", - "jax-cuda12-pjrt==0.4.28", + "jax[cuda12]", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index dde41fa6d..dca9a6b95 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -7,8 +7,11 @@ import jax from jax import lax import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax +from algoperf import sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -50,24 +53,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + 
label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -77,7 +74,7 @@ def _loss_fn(params): model_state, spec.ForwardPassMode.TRAIN, rng, - update_batch_norm=True) + update_batch_norm=True,) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -90,9 +87,8 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + + # Compute local loss and gradients loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -105,7 +101,7 @@ def _loss_fn(params): grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + current_param_container) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm @@ -130,7 +126,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -139,23 +134,50 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Set up mesh and sharding + mesh = sharding_utils.get_mesh() + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings= ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ), + out_shardings=( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + )) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) # Log loss, grad_norm. 
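# Illustrative sketch (not from this patch): in the jitted train step above,
# `workload` and `opt_update_fn` are plain Python objects, so they are passed
# through `static_argnums`, and `donate_argnums=(2, 3, 4)` lets XLA reuse the
# buffers of the previous model state, optimizer state and parameters for the
# new outputs, reducing peak device memory. The same machinery in miniature:
import jax
import jax.numpy as jnp


def sgd_step(lr, params, grads):  # `lr` is a static (hashable) Python float
  return jax.tree.map(lambda p, g: p - lr * g, params, grads)


jitted_sgd_step = jax.jit(sgd_step, static_argnums=(0,), donate_argnums=(1,))
params = {'w': jnp.ones(4)}
grads = {'w': jnp.full(4, 0.1)}
params = jitted_sgd_step(0.01, params, grads)  # old `params` buffers donated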
- if global_step % 100 == 0 and workload.metrics_logger is not None: + if global_step % 1 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -198,13 +220,15 @@ def get_batch_size(workload_name): elif workload_name == 'librispeech_conformer': return 256 elif workload_name == 'librispeech_deepspeech': - return 256 + return 32 elif workload_name == 'ogbg': return 512 elif workload_name == 'wmt': return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 21d9b6b57..bdeaaf95b 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -173,6 +173,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0e53aae42..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -7,8 +7,11 @@ import jax from jax import lax import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax +from algoperf import sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -37,7 +40,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( @@ -87,21 +90,15 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -124,9 +121,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # # Get correct global mean loss and grad. 
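# Illustrative sketch (not from this patch): under pmap each device only saw
# its local shard, so per-device sums had to be combined explicitly with
# `lax.psum(..., axis_name='batch')`. Under jit the batch is a single global
# sharded array, and a reduction such as `jnp.sum` already spans every shard
# (the compiler inserts the cross-device collective), which is why the psum
# lines are dropped here. Assuming the leading dim divides the device count:
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P('batch')))
total = jax.jit(jnp.sum)(x)  # global sum across all shards, no explicit psum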
loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -164,7 +159,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -173,23 +167,54 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + mesh = sharding_utils.get_mesh() + # Create shardings for each argument + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rngs + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) # Log loss, grad_norm. 
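# Illustrative sketch (not from this patch): `sharding_utils` itself is not
# shown in this diff. Given how `get_mesh()` is combined with `P('batch')`
# above, hypothetical equivalents of the helpers used here would be:
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_mesh() -> Mesh:
  """1-D mesh over all devices; its single axis is the data-parallel axis."""
  return Mesh(np.array(jax.devices()), axis_names=('batch',))


def get_replicated_sharding(mesh=None) -> NamedSharding:
  return NamedSharding(mesh if mesh is not None else get_mesh(), P())


def get_naive_sharding_spec(mesh=None) -> NamedSharding:
  return NamedSharding(mesh if mesh is not None else get_mesh(), P('batch'))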
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -239,6 +264,10 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index b64f0dfd6..b99d1fd94 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -41,4 +41,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index a6c3d853b..9da67e8f9 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=False) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 597a43c9e..1ba56bbda 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -168,4 +168,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 0c11044fc..533e23f2c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 217228935..1cfa5deca 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,26 +7,21 @@ import jax.numpy as jnp import optax +from algoperf import sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, 
- model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -49,9 +44,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # Compute mean loss and grad loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -98,9 +91,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step( # pylint: disable=line-too-long + mesh = sharding_utils.get_mesh() + # Create shardings for each argument + replicated = sharding_utils.get_replicated_sharding(mesh) # No partitioning + sharded = sharding_utils.get_naive_sharding_spec( + mesh) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step( workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, grad_clip, + current_param_container, batch, rng, grad_clip, label_smoothing) # Log loss, grad_norm. 
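# Illustrative sketch (not from this patch): pmap returned one value per
# device, so logging indexed `loss[0]`; the jitted train step returns a single
# replicated scalar, and `.item()` / `float()` pulls it onto the host, which
# is what the logging hunks below switch to:
import jax
import jax.numpy as jnp

loss = jax.jit(jnp.mean)(jnp.arange(4.0))
print(loss.item(), float(loss))  # 1.5 1.5 -- host-side Python scalars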
@@ -108,8 +135,8 @@ def update_params( workload.metrics_logger is not None): workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/submission_runner.py b/submission_runner.py index 468a04c7c..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,6 +31,14 @@ from absl import logging import jax import tensorflow as tf + +# New PRNG implementation for correct sharding +jax.config.update('jax_default_prng_impl', 'threefry2x32') +jax.config.update('jax_threefry_partitionable', True) +# JAX compilation caching +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) import torch import torch.distributed as dist @@ -242,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -384,8 +393,9 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if False: + # if ((train_step_end_time - train_state['last_eval_time']) >= + # workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). if prepare_for_eval is not None: @@ -703,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index c4ca514a8..f576d136b 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -225,7 +225,7 @@ def _build_input_queue(self, *args, **kwargs): del kwargs np.random.seed(42) - if framework == 'jax' or USE_PYTORCH_DDP: + if USE_PYTORCH_DDP: batch_shape = (n_gpus, global_batch_size // n_gpus) else: batch_shape = (global_batch_size,) @@ -422,6 +422,10 @@ def _test_submission(workload_name, global_batch_size = FLAGS.global_batch_size if FLAGS.global_batch_size < 0: raise ValueError('Must set --global_batch_size.') + elif global_batch_size < n_gpus and FLAGS.framework == 'jax': + raise ValueError( + 'Global batch size cannot be smaller than the number of GPUs when using JAX sharding.' + ) workload = _make_one_batch_workload(workload_class, workload_name, framework,