Fixing batch_dim_name attribute #20674
Merged
Commits (11)
6ecc55c fixing wrong trainer assumption that batch dim is always the first on… (martin-gorner)
830eed2 need functools partial (martin-gorner)
2da7fdc lint (martin-gorner)
fb6c5ae fix test failure when distribution=None (martin-gorner)
63c1f15 lint2 (martin-gorner)
a64e093 fix for test failure (martin-gorner)
64293de added data sharding for 3D+ meshes (martin-gorner)
d635f47 lint3 (martin-gorner)
5e7b344 Merge branch 'keras-team:master' into master (martin-gorner)
32f5806 added @property for batch_dim_name + refactoring (martin-gorner)
a7afa74 fix typo (martin-gorner)
Conversations
Can we do this without accessing the private variable `_batch_dim_name`? Could we consider passing the `batch_dim_name` as an argument to the relevant functions? Or, maybe the `distribution` object could provide a public method or property to access the batch dimension name?
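For reference, a minimal sketch of what such a public accessor could look like (illustrative only; the class layout and attribute placement are assumptions, not the actual Keras `Distribution` code):

```python
class Distribution:
    """Illustrative stand-in for keras.distribution.Distribution."""

    def __init__(self, device_mesh, batch_dim_name=None):
        self._device_mesh = device_mesh
        self._batch_dim_name = batch_dim_name

    @property
    def batch_dim_name(self):
        # Public, read-only accessor so callers (e.g. the JAX trainer)
        # no longer need to reach into the private `_batch_dim_name`.
        return self._batch_dim_name
```

Callers would then read `distribution.batch_dim_name` instead of `distribution._batch_dim_name`.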
Yes, I'll think of a cleaner way.
The goal at this point is to get a second pair of eyes on this fix and validate that it is correct. See the use cases at the end of the intro paragraph. Also, since you implemented the multi-host code, could you check that this fix does not break it?
Sure, I'll run the internal multi-host test to make sure it still works.
This will also need more tests. The failure is not in a complex case and should have been covered by tests. I can add a couple of tests on 8-core TPUs, but I'll let you extend them to multi-host settings.
But right now, what is your opinion on the case where model and data parallelism are used at the same time and the "batch" dimension is also a sharding dimension for the model, as is the default for Gemma and Llama? How should data batches be split in that case? (I don't think my fix covers that case; I'm not sure I understand how that case makes sense.)
I ran the internal multi-host test with your changes and it passed! I think we should be able to merge this PR after updating the private variable usage (`_batch_dim_name`).

Sharding along the batch dimension should work! Our multi-host tests cover that and they pass. We test all the following configs:

Could you point me to the colab that shows sharding along the batch dimension doesn't work for a 2D mesh?

I think what is not supported yet is a 3D+ mesh. I agree that this would be a great feature to have. Maybe we can create a feature request issue and plan for supporting it.

PS: US holidays start tomorrow and I'll be back after the new year! Happy Holidays, Martin!
All the repro colabs are in the intro.
The tests may be passing but if we don't understand the use case, it could be by accident. The thing I do not understand, and that the fix does not cover, is:

```python
num_model_replicas_total = layout.mesh.shape[batch_dim_name]
mesh_model_dim_size = nb_devices / num_model_replicas_total  # not actual code but it amounts to this
```

It seems to me that these expressions assume that the model is NOT sharded on the 'batch' dimension. It is only when the model is replicated on the 'batch' dimension and sharded on all other dimensions that the expression `num_model_replicas_total = layout.mesh.shape[batch_dim_name]` is true. If the model is also sharded on the 'batch' dimension, I'm not sure how many model replicas there are.
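To make the concern concrete, a purely illustrative numeric sketch (the mesh shape and variable names are assumptions, not trainer code):

```python
# An 8-device mesh laid out as ("batch", "model") = (4, 2).
nb_devices = 8
mesh_shape = {"batch": 4, "model": 2}

num_model_replicas_total = mesh_shape["batch"]                # 4
mesh_model_dim_size = nb_devices // num_model_replicas_total  # 2

# These values are only meaningful if the model variables are replicated
# across the 4 "batch" slices and sharded only along "model". If a weight
# is also sharded along "batch", the 4 slices no longer hold identical
# copies of the model, and "4 model replicas" is no longer well defined.
```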
Model should not be sharded on the batch dimension. Model should only be sharded on the model dimension. Data should be sharded on the batch dimension. The reason the number of model replicas equals the size of the batch dimension (the number of data shards) is that when we shard the data, each shard needs a full replica of the model to process it.
I tend to agree with you, but the default sharding for Gemma does shard the model on the "batch" dimension. See here. The code references this article as a rationale.
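For readers following along, a hypothetical sketch of what "sharding the model on the 'batch' dimension" means in layout-map terms (the variable path and axis assignment are illustrative, not Gemma's actual layout map):

```python
import keras
from keras.distribution import DeviceMesh, LayoutMap

devices = keras.distribution.list_devices()  # e.g. 8 devices
mesh = DeviceMesh(shape=(4, 2), axis_names=["batch", "model"], devices=devices)

layout_map = LayoutMap(mesh)
# Hypothetical entry: the embedding table is sharded across BOTH mesh axes,
# including "batch", so the 4 "batch" slices do not hold identical copies
# of this weight and counting model replicas via the size of the "batch"
# axis becomes ambiguous.
layout_map["token_embedding/embeddings"] = ("model", "batch")
```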
I went through the article and the only mention of sharding or replication I could find is this sentence: "Within a pod, we use 16-way model sharding and 16-way data replication for the 7B model." It does not say anything about sharding the model on the batch dimension.
The sharding on the batch dimension was added by Scott in PR1491. The PR discussion says "The new setting is based on the Gemma training script internally". Can you find those scripts to investigate? Are there other places we could check this? Maybe some JAX Gemma fine-tuning scripts?
Anyway, in the short term, I think we should just remove the batch dimension from the default layout map and then safely assume that the model is NOT sharded on the batch dim. This should work, I think, even for 3D sharding (model sharded on the "model" and "sequence" dims, while data is sharded on the "batch" dim).
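To make that short-term proposal concrete, a hedged sketch of such a 3D setup using the public `keras.distribution` API (axis names match the discussion; the layout patterns are illustrative and constructor signatures may differ slightly between Keras 3 releases):

```python
import keras
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel

devices = keras.distribution.list_devices()  # e.g. 16 devices
# 2-way data parallelism x (2 x 4)-way model parallelism.
mesh = DeviceMesh(
    shape=(2, 2, 4),
    axis_names=["batch", "sequence", "model"],
    devices=devices,
)

layout_map = LayoutMap(mesh)
# Weights are sharded only on "sequence" and "model"; "batch" never appears
# in a weight layout, so the model is fully replicated across the "batch"
# axis and data batches can safely be split along it.
layout_map[".*ffw_up.*kernel"] = ("sequence", "model")
layout_map[".*ffw_down.*kernel"] = ("model", "sequence")

distribution = ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(distribution)
```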