diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py
index edb1dc1184a6..5dc5c057d29e 100644
--- a/keras/src/backend/jax/distribution_lib.py
+++ b/keras/src/backend/jax/distribution_lib.py
@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
         return global_value
 
 
-def distribute_data_input(per_process_batch, layout):
+def distribute_data_input(per_process_batch, layout, batch_dim_name):
     """Distribute the input data with the corresponding layout.
 
     Note that the inputs here is a local worker batch. Within the local worker,
@@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout):
     if not isinstance(layout, jax.sharding.Sharding):
         layout = _to_jax_layout(layout)
 
-    mesh_shape = list(layout.mesh.shape.values())
-    num_model_replicas_total = mesh_shape[0]  # batch dimension of the mesh
-    mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
+    num_model_replicas_total = layout.mesh.shape[batch_dim_name]
+
+    mesh_model_dim_size = 1
+    for name, dim_size in layout.mesh.shape.items():
+        if not name == batch_dim_name:
+            mesh_model_dim_size *= dim_size
+
     num_model_replicas_per_process = num_model_replicas_total / num_processes()
     per_process_batch_size = per_process_batch.shape[0]
 
diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py
index 5ab8eeb41332..81ceddfd305b 100644
--- a/keras/src/backend/jax/distribution_lib_test.py
+++ b/keras/src/backend/jax/distribution_lib_test.py
@@ -337,7 +337,9 @@ def test_distribute_data_input(self):
             mesh, jax.sharding.PartitionSpec("batch", None)
         )
 
-        result = backend_dlib.distribute_data_input(per_process_batch, layout)
+        result = backend_dlib.distribute_data_input(
+            per_process_batch, layout, "batch"
+        )
 
         # Check the shape of the global batch array
         self.assertEqual(
diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py
index f5ae91ea8d81..c127f7f83344 100644
--- a/keras/src/backend/jax/trainer.py
+++ b/keras/src/backend/jax/trainer.py
@@ -1,5 +1,6 @@
 import collections
 import itertools
+from functools import partial
 
 import jax
 import numpy as np
@@ -988,15 +989,18 @@ def _get_jax_state(
 
 def _distribute_data(data, layouts=None):
     distribution = distribution_lib.distribution()
+
     if distribution is not None:
         if layouts is None:
             layouts = tree.map_structure(
                 lambda d: distribution.get_data_layout(d.shape),
                 data,
             )
-        return tree.map_structure(
-            jax_distribution_lib.distribute_data_input, data, layouts
+        jax_dist_data_input = partial(
+            jax_distribution_lib.distribute_data_input,
+            batch_dim_name=distribution.batch_dim_name,
         )
+        return tree.map_structure(jax_dist_data_input, data, layouts)
 
     return tree.map_structure(jax.device_put, data)
 
diff --git a/keras/src/distribution/distribution_lib.py b/keras/src/distribution/distribution_lib.py
index b4736426afec..1528fa8fc151 100644
--- a/keras/src/distribution/distribution_lib.py
+++ b/keras/src/distribution/distribution_lib.py
@@ -287,8 +287,9 @@ class Distribution:
         device_mesh: A `DeviceMesh` instance.
     """
 
-    def __init__(self, device_mesh):
+    def __init__(self, device_mesh, batch_dim_name=None):
         self._device_mesh = device_mesh
+        self._batch_dim_name = batch_dim_name
 
     def get_data_layout(self, data_shape):
         """Retrieve the `TensorLayout` for the input data.
@@ -341,6 +342,10 @@ def scope(self):
     def device_mesh(self):
         return self._device_mesh
 
+    @property
+    def batch_dim_name(self):
+        return self._batch_dim_name
+
     def distribute_dataset(self, dataset):
         """Create a distributed dataset instance from the original user
         dataset.
@@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
         else:
             self._initialize_mesh_from_list_devices()
 
-        self._batch_dim_name = self.device_mesh.axis_names[0]
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
         self._process_id = distribution_lib.process_id()
@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
                 "Expect `mesh` to be an instance of `DeviceMesh`. "
                 f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
             )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, device_mesh.axis_names[0])
         if self.device_mesh.devices.ndim != 1:
             warnings.warn(
                 "Expect the input mesh to be 1D, but received "
@@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
 
     def _initialize_mesh_from_list_devices(self):
         devices = np.array(list_devices())
@@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
 
     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)
 
     def get_variable_layout(self, variable):
@@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
 
     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)
 
     def get_variable_layout(self, variable):
@@ -631,7 +635,7 @@ def distribute_dataset(self, dataset):
         # Note that this might be smaller than one if model replicas are sharded
         # across multiple processes.
         mesh_batch_dim_index = self.device_mesh.axis_names.index(
-            self._batch_dim_name
+            self.batch_dim_name
        )
         num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
         if num_model_replicas == 1:
diff --git a/keras/src/distribution/distribution_lib_test.py b/keras/src/distribution/distribution_lib_test.py
index 8fd0988aec32..fba998fae461 100644
--- a/keras/src/distribution/distribution_lib_test.py
+++ b/keras/src/distribution/distribution_lib_test.py
@@ -186,7 +186,7 @@ def test_create_with_device_mesh(self):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["data"])
-        self.assertEqual(distribution._batch_dim_name, "data")
+        self.assertEqual(distribution.batch_dim_name, "data")
 
         self.assertFalse(distribution._is_multi_process)
         self.assertEqual(distribution._process_id, 0)
@@ -197,7 +197,7 @@ def test_create_with_devices(self):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")
 
     @mock.patch.object(
         distribution_lib,
@@ -211,7 +211,7 @@ def test_create_with_list_devices(self, mock_list_devices):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")
 
     def test_get_data_layout(self):
         distribution = distribution_lib.DataParallel(
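
Reviewer note (not part of the patch): the core of the change is that `distribute_data_input` now looks the replica axis up by name instead of assuming the batch axis is the first mesh dimension, and folds every remaining mesh axis into the model dimension. The sketch below restates that arithmetic in isolation; the plain dict stands in for `layout.mesh.shape`, and the helper name `split_mesh_dims` and the example sizes are illustrative assumptions, not code from this PR.

# Illustrative sketch only; mirrors the new logic in distribute_data_input().
def split_mesh_dims(mesh_shape, batch_dim_name):
    """Return (num_model_replicas_total, mesh_model_dim_size)."""
    # The batch axis, looked up by name, gives the total number of model
    # replicas across all processes.
    num_model_replicas_total = mesh_shape[batch_dim_name]
    # Every other mesh axis contributes to the model dimension.
    mesh_model_dim_size = 1
    for name, dim_size in mesh_shape.items():
        if name != batch_dim_name:
            mesh_model_dim_size *= dim_size
    return num_model_replicas_total, mesh_model_dim_size


# A 2x4 mesh whose batch axis is not the first axis name. The old code read
# mesh_shape[0] as the batch dimension, which here would give 2 replicas;
# keying by batch_dim_name gives the intended 4.
print(split_mesh_dims({"model": 2, "batch": 4}, batch_dim_name="batch"))
# -> (4, 2)

The `functools.partial` change in `trainer.py` exists only to thread `distribution.batch_dim_name` through `tree.map_structure`, so every input leaf is distributed with the same batch axis name.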