-
Notifications
You must be signed in to change notification settings - Fork 324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Support DiLoCo training #1018
base: main
Are you sure you want to change the base?
Conversation
""" | ||
import drjax | ||
|
||
# TODO(jonbolin): Keep this as part of DiLoCoTrainState? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good question - some configurability here is likely nice. I think it's possible that people might want to experiment with other outer optimizers, though sgd + nesterov momentum is a really competitive baseline.
# to vmap over. | ||
per_replica_batch = config.global_batch_size_to_train_on // config.num_diloco_replicas | ||
batch_shape = (config.num_diloco_replicas, per_replica_batch, -1) | ||
batch = jax.tree.map(lambda x: x.reshape(batch_shape), batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In some versions I've written, I explicitly add this extra axis before passing to the train step, so that I can ensure that it is sharded over the diloco axis. I'm not certain if that's necessary though, but wanted to flag it for posterity.
train_step, | ||
(state.inner_state, batch, broadcast_rng) | ||
) | ||
avg_metrics = typed_reduce_mean(metrics) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing to call out is that differentiating between metrics that need to be summed and metrics that need to be averaged is kind of a pain.
E.g. if you're keeping track of total number of tokens seen, then a reduce_sum here is probably desired. I don't have any great solutions though without understanding better the structure of metrics.
DiLoCo training paper: https://arxiv.org/pdf/2311.08105
Adds preliminary support for DiLoCo training through DrJax.