Support replica groups for distributed batchnorm #42
Conversation
The changes look good to me.
@@ -65,6 +66,8 @@ def apply(self,
    scale_init: initializer for scale, by default, one.
    axis_name: the axis name used to combine batch statistics from multiple
      devices. See `jax.pmap` for a description of axis names (default: None).
+   replica_groups: the custom replica groups used to combine batch statistics
Maybe this is known to everyone but me, but can we document better what "replica groups" are?
I think it should refer to lax.psum/pmean for more details
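For reference, `lax.pmean` with an `axis_name` averages across every device participating in the `pmap`; the point of this PR is to optionally restrict that reduction to subgroups of devices. A minimal sketch of the existing, ungrouped behavior (function and variable names are illustrative only):

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

# Default behavior: pmean reduces over *all* devices in the pmap.
@partial(jax.pmap, axis_name='batch')
def global_batch_mean(x):
    return lax.pmean(x.mean(axis=0), axis_name='batch')

n_dev = jax.local_device_count()
x = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)  # [devices, per-device batch]
print(global_batch_mean(x))  # every device sees the same global mean
```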
I looked up the documentation and source code of `lax.psum`, but I can't find the term "replica groups" there.
Apologies if I'm just "new here", but given that JAX lacks internal docstrings, I think we need to own explaining this in the Flax API, unless we can point people to other JAX references (which would be better!).
The `replica_groups` kwarg was added in jax-ml/jax#2382, though we might end up using a different name. (The name and concept are definitely not known to everyone, but the idea is something many people want to express: doing batch normalization over more than just the examples on each accelerator alone, but less than the entire global batch.) TF code often uses a `distributed_group_size` keyword argument and then converts that into XLA replica groups later.
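A sketch of that conversion, assuming devices are grouped contiguously and the device count divides evenly by the group size (the helper name is hypothetical, not an existing API):

```python
def group_size_to_replica_groups(num_devices, group_size):
    """Turn a flat group size into explicit replica groups, e.g.
    (num_devices=8, group_size=2) -> [[0, 1], [2, 3], [4, 5], [6, 7]]."""
    assert num_devices % group_size == 0, "group_size must divide num_devices"
    return [list(range(start, start + group_size))
            for start in range(0, num_devices, group_size)]
```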
Got it. Perhaps a link to that pull request in the docstring is a simple solution?
Actually I have one more question about this: What is the benefit of using …
Yes; the difference is only that distributed BN is a library feature that ideally wouldn't affect the top-level training loop. But I'll try seeing if I can do what I need with nested pmap; it looks like that might be easier than wiring through replica group support in a way we're fully happy with.
Hi @jekbradbury -- what's the latest on this? Have you been able to use nested `pmap`?
Looks like this PR may be stale. I'll close it for now, but feel free to re-open with additional context as appropriate.
Requires jax-ml/jax#2382. Needed for large-scale (i.e., small per-device batch size) training of ResNets, where the ideal number of examples to normalize over seems to be about 128 (normalizing over the whole pmap is both unnecessarily slow and gives worse results).
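For illustration, here is a rough sketch of grouped batch-statistic reduction using the `axis_index_groups` keyword that `jax.lax.pmean` exposes today (the thread above discusses it under the name `replica_groups`; the group size, shapes, and epsilon below are made up, and this is not the PR's actual implementation):

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.local_device_count()
group_size = 2  # illustrative; the PR targets ~128 examples per normalization group
assert n_dev % group_size == 0, "requires a device count divisible by group_size"
groups = [list(range(i, i + group_size)) for i in range(0, n_dev, group_size)]

@partial(jax.pmap, axis_name='batch')
def grouped_batch_norm(x):
    # Statistics are shared only among devices within the same group,
    # not across the entire pmap.
    mean = lax.pmean(x.mean(axis=0), 'batch', axis_index_groups=groups)
    var = lax.pmean(((x - mean) ** 2).mean(axis=0), 'batch', axis_index_groups=groups)
    return (x - mean) * lax.rsqrt(var + 1e-5)

x = jnp.ones((n_dev, 8, 16))  # [devices, per-device batch, features]
y = grouped_batch_norm(x)
```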