
Fixing batch_dim_name attribute #20674

Merged
merged 11 commits on Jan 7, 2025
12 changes: 8 additions & 4 deletions keras/src/backend/jax/distribution_lib.py
@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
     return global_value


-def distribute_data_input(per_process_batch, layout):
+def distribute_data_input(per_process_batch, layout, batch_dim_name):
     """Distribute the input data with the corresponding layout.

     Note that the inputs here is a local worker batch. Within the local worker,
@@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout):
     if not isinstance(layout, jax.sharding.Sharding):
         layout = _to_jax_layout(layout)

-    mesh_shape = list(layout.mesh.shape.values())
-    num_model_replicas_total = mesh_shape[0]  # batch dimension of the mesh
-    mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
+    num_model_replicas_total = layout.mesh.shape[batch_dim_name]
+
+    mesh_model_dim_size = 1
+    for name, dim_size in layout.mesh.shape.items():
+        if not name == batch_dim_name:
+            mesh_model_dim_size *= dim_size

     num_model_replicas_per_process = num_model_replicas_total / num_processes()
     per_process_batch_size = per_process_batch.shape[0]
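A worked illustration of what the new lines compute (my own example, not part of the PR), assuming an 8-device mesh with axes {"batch": 2, "model": 4}:

    # Stand-in for layout.mesh.shape, which maps mesh axis names to sizes.
    mesh_shape = {"batch": 2, "model": 4}
    batch_dim_name = "batch"

    num_model_replicas_total = mesh_shape[batch_dim_name]  # 2 data shards
    mesh_model_dim_size = 1
    for name, dim_size in mesh_shape.items():
        if name != batch_dim_name:
            mesh_model_dim_size *= dim_size  # 4 devices per model replica

    assert (num_model_replicas_total, mesh_model_dim_size) == (2, 4)

The previous code assumed the batch dimension was the first mesh axis; the new version looks it up by name instead.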
4 changes: 3 additions & 1 deletion keras/src/backend/jax/distribution_lib_test.py
@@ -337,7 +337,9 @@ def test_distribute_data_input(self):
             mesh, jax.sharding.PartitionSpec("batch", None)
         )

-        result = backend_dlib.distribute_data_input(per_process_batch, layout)
+        result = backend_dlib.distribute_data_input(
+            per_process_batch, layout, "batch"
+        )

         # Check the shape of the global batch array
         self.assertEqual(
8 changes: 6 additions & 2 deletions keras/src/backend/jax/trainer.py
@@ -1,5 +1,6 @@
 import collections
 import itertools
+from functools import partial

 import jax
 import numpy as np
@@ -988,15 +989,18 @@ def _get_jax_state(

 def _distribute_data(data, layouts=None):
     distribution = distribution_lib.distribution()
+
     if distribution is not None:
         if layouts is None:
             layouts = tree.map_structure(
                 lambda d: distribution.get_data_layout(d.shape),
                 data,
             )
-        return tree.map_structure(
-            jax_distribution_lib.distribute_data_input, data, layouts
+        jax_dist_data_input = partial(
+            jax_distribution_lib.distribute_data_input,
+            batch_dim_name=distribution._batch_dim_name,
Member (review comment on the line above):

Can we do this without accessing the private variable _batch_dim_name? Could we consider passing the batch_dim_name as an argument to the relevant functions? Or, maybe the distribution object provides a public method or property to access the batch dimension name?
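For illustration, such a public accessor could look roughly like this (a hypothetical sketch, not code from this PR or from Keras; the constructor signature is made up):

    class Distribution:
        def __init__(self, device_mesh, batch_dim_name=None):
            self._device_mesh = device_mesh
            self._batch_dim_name = batch_dim_name

        @property
        def batch_dim_name(self):
            # Read-only access, so the trainer can use
            # distribution.batch_dim_name instead of the private attribute.
            return self._batch_dim_name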

Contributor Author:

Yes, I'll think of a cleaner way.

The goal at this point is to get a second pair of eyes on this fix and validate that it is correct. See the use cases at the end of the intro paragraph. Also, since you implemented the multi-host code, could you check that this fix does not break it?

Member:

Sure, I'll run the internal multi-host test to make sure it still works.

Contributor Author:

This will also need more tests. The failure is not in a complex case; it should have been covered by tests. I can add a couple of tests on 8-core TPUs, but I'll let you extend them to multi-host settings.

But right now, what is your opinion on the case where model and data parallelism are used at the same time and the "batch" dimension is also a sharding dimension for the model, as is the default for Gemma and Llama? How should data batches be split in that case? (And I don't think my fix covers that case - I'm not sure I understand how that case makes sense.)

Member:

I ran the internal multi-host test with your changes and it passed! I think we should be able to merge this PR after updating private variable usage (_batch_dim_name).

Sharding along the batch dimension should work! Our multi-host tests cover that case and they pass! We test all the following configs:

  @parameterized.named_parameters([
      ("data_only", (8, 1), 2, False,),
      ("data_model", (4, 2), 2, False,),
      ("model_data", (2, 4), 4, False,),
      ("model_only", (1, 8), 8, True,),
  ])

Could you point me to the colab that shows sharding along the batch dimension doesn't work for a 2D mesh?

I think what is not supported yet is 3D+ mesh. I agree that this would be a great feature to have. Maybe we can create a feature request issue and plan for supporting it.

PS: US holidays will start tomorrow and I'll be back after the new year! Happy Holidays, Martin!

Contributor Author (@martin-gorner), Dec 27, 2024:

All the repro colabs are in the intro.

The tests may be passing but if we don't understand the use case, it could be by accident. The thing I do not understand and that the fix does not cover is:

    num_model_replicas_total = layout.mesh.shape[batch_dim_name]
    mesh_model_dim_size = nb_devices / num_model_replicas_total  # not actual code, but it amounts to this

It seems to me that these expressions assume that the model is NOT sharded on the 'batch' dimension. The expression num_model_replicas_total = layout.mesh.shape[batch_dim_name] is only true when the model is replicated on the 'batch' dimension and sharded on all other dimensions. If the model is also sharded on the 'batch' dimension, I'm not sure how many model replicas there are.
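To make this concrete, a small worked example (my own numbers, not from the discussion), assuming 8 devices on a mesh with axes {"batch": 2, "model": 4}:

    nb_devices = 8
    mesh = {"batch": 2, "model": 4}

    num_model_replicas_total = mesh["batch"]                       # 2
    mesh_model_dim_size = nb_devices // num_model_replicas_total   # 4

    # These values describe the layout only if the weights are replicated
    # across "batch" and sharded across "model". If a weight is also sharded
    # across "batch", a single copy of the model spans all 8 devices, so the
    # real replica count is 1, which the two expressions above cannot express.
    print(num_model_replicas_total, mesh_model_dim_size)  # 2 4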

Member:

The model should not be sharded on the batch dimension; it should only be sharded on the model dimension, while data is sharded on the batch dimension. The number of model replicas equals the size of the batch dimension (the number of data shards) because each data shard needs a full replica of the model to process it.

Contributor Author:

I tend to agree with you, but the default sharding for Gemma does shard the model on the "batch" dimension. See here. The code references this article as a rationale.

Contributor Author (@martin-gorner), Jan 3, 2025:

I went through the article and the only mention of sharding or replication I could find is this sentence: "Within a pod, we use 16-way model sharding and 16-way data replication for the 7B model." It does not say anything about sharding the model on the batch dimension.

The sharding on the batch dimension was added by Scott in PR1491. The PR discussion says "The new setting is based on the Gemma training script internally". Can you find those scripts to investigate? Are there other places we could check this? Maybe some JAX Gemma fine-tuning scripts?

Anyway, in the short term, I think we should just remove the batch dimension from the default layout map and then safely assume that the model is NOT sharded on the batch dim. This should work, I think, even for 3D sharding (model sharded on "model" and "sequence" dims, while data is sharded on the "batch" dim).
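A rough sketch of that proposal (my own illustration; the weight-path regexes are made up, not the actual Gemma defaults): a layout map that shards weights only on "model", leaving "batch" entirely to the data:

    import keras
    from keras import distribution

    mesh = distribution.DeviceMesh(
        shape=(2, 4),  # assumes 8 devices: 2-way data, 4-way model parallelism
        axis_names=("batch", "model"),
        devices=keras.distribution.list_devices(),
    )

    layout_map = distribution.LayoutMap(mesh)
    # Weights touch only the "model" axis; "batch" never appears here.
    layout_map["token_embedding/embeddings"] = (None, "model")
    layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (
        None, "model", None,
    )

    model_parallel = distribution.ModelParallel(
        layout_map=layout_map, batch_dim_name="batch"
    )
    keras.distribution.set_distribution(model_parallel)

With no weight sharded on "batch", num_model_replicas_total = layout.mesh.shape[batch_dim_name] holds by construction.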

         )
+        return tree.map_structure(jax_dist_data_input, data, layouts)

     return tree.map_structure(jax.device_put, data)

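As a side note on the mechanics of the trainer change, here is a toy example (my own, not Keras code) of the functools.partial pattern used there: the extra batch_dim_name argument is pinned so the resulting callable still fits tree.map_structure(fn, data, layouts):

    from functools import partial

    from keras import tree  # keras.tree maps a function over nested structures


    def distribute_data_input(batch, layout, batch_dim_name):
        # Stand-in for the real distribution function; just formats its inputs.
        return f"{batch} -> {layout} (batch axis: {batch_dim_name})"


    data = {"x": "x_batch", "y": "y_batch"}
    layouts = {"x": "x_layout", "y": "y_layout"}

    dist_fn = partial(distribute_data_input, batch_dim_name="batch")
    print(tree.map_structure(dist_fn, data, layouts))
    # {'x': 'x_batch -> x_layout (batch axis: batch)',
    #  'y': 'y_batch -> y_layout (batch axis: batch)'}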