
Support replica groups for distributed batchnorm #42

Closed

Conversation

jekbradbury
Contributor

Requires jax-ml/jax#2382. Needed for large-scale training of ResNets (i.e., with small per-device batch sizes), where the ideal number of examples to normalize over seems to be about 128, and normalizing over the whole pmap is both unnecessarily slow and gives worse results.

@jheek
Member

jheek commented Mar 9, 2020

The changes look good to me.
Do you know if it's possible to test this on the Travis CPU runners?

@@ -65,6 +66,8 @@ def apply(self,
scale_init: initializer for scale, by default, one.
axis_name: the axis name used to combine batch statistics from multiple
devices. See `jax.pmap` for a description of axis names (default: None).
replica_groups: the custom replica groups used to combine batch statistics
Contributor

Maybe this is known to everyone but me, but could we document better what "replica groups" are?

Member

I think it should refer to lax.psum/pmean for more details.

Contributor

I looked up the documentation and source code of lax.psum but I can't find the term "replica groups" there.

Apologies if I'm just "new here", but given that JAX lacks internal docstrings, I think we need to own explaining this in the Flax API, unless we can point people to other JAX references (which would be better!).

Contributor Author

The replica_groups kwarg is added in jax-ml/jax#2382. We might end up using a different name, though. (The name and concept are definitely not known to everyone, but the idea is something many people want to express: doing batch normalization over more than just the examples on each accelerator alone, but less than the entire global batch). TF code often uses a distributed_group_size keyword argument and then converts that into XLA replica groups later.
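For illustration, the TF-style conversion mentioned above can be sketched in plain Python. `make_replica_groups` is a hypothetical helper (not part of this PR or of JAX): it turns a `distributed_group_size` into an explicit partition of device ids, which is what XLA replica groups are.

```python
def make_replica_groups(num_devices, group_size):
    """Hypothetical helper: partition device ids [0, num_devices) into
    contiguous XLA-style replica groups of size group_size."""
    assert num_devices % group_size == 0, "group size must divide device count"
    return [list(range(i, i + group_size))
            for i in range(0, num_devices, group_size)]

# e.g. 8 devices, batchnorm statistics shared within pairs of devices:
groups = make_replica_groups(8, 2)
# groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Each inner list is one group of devices whose batch statistics get averaged together; devices in different groups never communicate during normalization.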

Contributor

Got it. Perhaps a link to that pull request in the docstring is a simple solution?

@jheek
Member

jheek commented Mar 13, 2020

Actually, I have one more question about this: what is the benefit of using replica_groups over a nested pmap? You could do something like pmap(pmap(train_step, 'bn_group'), 'batch') and use BatchNorm(axis_name='bn_group') to get a similar result, correct?

@jekbradbury
Contributor Author

Yes. The difference is only that distributed BatchNorm is a library feature that ideally wouldn't affect the top-level training loop. But I'll try seeing whether I can do what I need with nested pmap; that looks like it might be easier than wiring through replica-group support in a way we're fully happy with.
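A minimal sketch of the nested-pmap alternative discussed above, assuming 8 virtual CPU devices forced via XLA_FLAGS so it runs on a single host. The function name `normalize` and the axis layout (4 groups of 2) are illustrative, not from the PR; only the mean subtraction of batchnorm is shown.

```python
# Must be set before JAX is imported: pretend the host has 8 CPU devices.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

def normalize(x):
    # Statistics are combined only over the inner 'bn_group' axis,
    # i.e. within each group of 2 devices, as BatchNorm(axis_name='bn_group')
    # would do.
    mean = jax.lax.pmean(x, 'bn_group')
    return x - mean

# Outer 'batch' axis over 4 groups, inner 'bn_group' axis over 2 devices each.
fn = jax.pmap(jax.pmap(normalize, axis_name='bn_group'), axis_name='batch')

x = jnp.arange(8.0).reshape(4, 2)
out = fn(x)  # each row is centered independently: rows become [-0.5, 0.5]
```

The trade-off matches the comment above: this works without new JAX features, but the group structure leaks into the top-level training loop (the caller must reshape data and nest the pmaps), whereas replica_groups would keep it inside the BatchNorm layer.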

@jheek jheek closed this Mar 27, 2020
@jheek jheek reopened this Mar 27, 2020
@jheek jheek changed the base branch from prerelease to master March 27, 2020 11:40
@avital
Contributor

avital commented Mar 29, 2020

Hi @jekbradbury -- what's the latest on this? Have you been able to use nested pmaps or is this PR still necessary for the training that you're doing?

@avital
Contributor

avital commented Apr 24, 2020

Looks like this PR may be stale. I'll close it for now, but feel free to re-open with additional context as appropriate.

@avital avital closed this Apr 24, 2020
