[WIP] Support DiLoCo training #1018

Draft: wants to merge 1 commit into base: main

@jonb377 (Collaborator) commented Nov 7, 2024

DiLoCo training paper: https://arxiv.org/pdf/2311.08105

Adds preliminary support for DiLoCo training through DrJax.

"""
import drjax

# TODO(jonbolin): Keep this as part of DiLoCoTrainState?

This is a good question; some configurability here would likely be nice. People may well want to experiment with other outer optimizers, though SGD with Nesterov momentum is a really competitive baseline.
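One possible shape for that configurability, as a rough sketch rather than a proposal for this PR: pass the outer optimizer in as an `init`/`update` pair, with Nesterov SGD as one instance. Everything here is an assumption for illustration: the names `OuterOptimizer`, `nesterov_sgd`, and `outer_step` are hypothetical, the outer gradient is taken to be the pre-round parameters minus the replica-averaged parameters, and the Nesterov update uses the `g + momentum * v` formulation.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class OuterOptimizer(NamedTuple):
    """Hypothetical container: any outer optimizer exposes init and update."""
    init: Callable[[Any], Any]
    # (grads, opt_state, params) -> (new_params, new_opt_state)
    update: Callable[[Any, Any, Any], Any]


def nesterov_sgd(learning_rate: float, momentum: float) -> OuterOptimizer:
    """SGD with Nesterov momentum, written out in plain JAX."""

    def init(params):
        # One velocity buffer per parameter leaf, starting at zero.
        return jax.tree.map(jnp.zeros_like, params)

    def update(grads, velocity, params):
        # v <- momentum * v + g
        velocity = jax.tree.map(lambda v, g: momentum * v + g, velocity, grads)
        # p <- p - lr * (g + momentum * v)   (Nesterov look-ahead)
        new_params = jax.tree.map(
            lambda p, g, v: p - learning_rate * (g + momentum * v),
            params, grads, velocity,
        )
        return new_params, velocity

    return OuterOptimizer(init, update)


def outer_step(outer_params, replica_params, opt: OuterOptimizer, opt_state):
    """One outer update. replica_params leaves carry a leading replica axis."""
    avg_params = jax.tree.map(lambda x: jnp.mean(x, axis=0), replica_params)
    # Assumed outer gradient: old parameters minus the replica average.
    outer_grad = jax.tree.map(lambda old, new: old - new, outer_params, avg_params)
    return opt.update(outer_grad, opt_state, outer_params)
```

Swapping in a different outer optimizer would then mean constructing a different `OuterOptimizer`, with `outer_step` left unchanged. Whether its state belongs in `DiLoCoTrainState` is a separate question that this sketch does not settle.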

# to vmap over.
per_replica_batch = config.global_batch_size_to_train_on // config.num_diloco_replicas
batch_shape = (config.num_diloco_replicas, per_replica_batch, -1)
batch = jax.tree.map(lambda x: x.reshape(batch_shape), batch)

In some versions I've written, I explicitly add this extra axis before passing the batch to the train step, so that I can ensure it is sharded over the DiLoCo axis. I'm not certain that's necessary, but wanted to flag it for posterity.
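A minimal sketch of that alternative, assuming a mesh with an axis named `"diloco"` (the axis name and the helpers `add_replica_axis` and `shard_over_replicas` are hypothetical, not taken from this PR). Unlike the `-1` reshape in the diff, this version keeps trailing dimensions as they are.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def add_replica_axis(batch, num_replicas):
    """Reshape each leaf from (global_batch, ...) to (num_replicas, per_replica, ...)."""

    def _split(x):
        if x.shape[0] % num_replicas:
            raise ValueError(
                f"batch dim {x.shape[0]} is not divisible by {num_replicas} replicas"
            )
        return x.reshape((num_replicas, x.shape[0] // num_replicas) + x.shape[1:])

    return jax.tree.map(_split, batch)


def shard_over_replicas(batch, mesh, axis_name="diloco"):
    """Place the leading (replica) axis of every leaf on the given mesh axis."""
    # Only the first dimension is partitioned; the rest are replicated.
    sharding = NamedSharding(mesh, PartitionSpec(axis_name))
    return jax.device_put(batch, sharding)


# Single-device mesh so the example runs anywhere; a real run would lay out
# its devices so that the "diloco" axis size divides the replica count.
mesh = Mesh(np.array(jax.devices()[:1]), ("diloco",))
example_batch = {"inputs": jnp.arange(8 * 16).reshape(8, 16)}
example_batch = shard_over_replicas(add_replica_axis(example_batch, 4), mesh)
```

Doing this outside the train step makes the placement explicit instead of leaving it to the compiler to infer from a reshape inside the step. Whether that makes a difference in practice is the open question raised above.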

train_step,
(state.inner_state, batch, broadcast_rng)
)
avg_metrics = typed_reduce_mean(metrics)

One thing to call out is that differentiating between metrics that need to be summed and metrics that need to be averaged is kind of a pain.

E.g. if you're keeping track of the total number of tokens seen, then a reduce_sum here is probably what you want. I don't have a great solution, though, without better understanding the structure of the metrics.
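One low-tech option, sketched under the assumption that the metrics form a flat dict whose values carry a leading replica axis: keep an explicit set of keys to sum and average everything else. The key `"total_tokens"` and the helper `reduce_metrics` are hypothetical; the real metrics structure may be nested and need a different approach.

```python
import jax.numpy as jnp

# Hypothetical: metric names that are counts and should be summed, not averaged.
SUM_KEYS = frozenset({"total_tokens"})


def reduce_metrics(metrics, sum_keys=SUM_KEYS):
    """Reduce a flat dict of per-replica metrics over the leading replica axis."""
    return {
        name: jnp.sum(value, axis=0) if name in sum_keys else jnp.mean(value, axis=0)
        for name, value in metrics.items()
    }


per_replica = {
    "loss": jnp.array([1.0, 3.0]),         # averaged across replicas
    "total_tokens": jnp.array([100, 200])  # summed across replicas
}
reduced = reduce_metrics(per_replica)
```

The drawback is that the list of summed keys has to be maintained by hand, which is the pain point described above.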
