Conversation
See #3753 for why Zero3 won't be supported in this implementation.
Seems reasonable - minor comments
parlai/core/params.py
Outdated
@@ -772,6 +772,16 @@ def add_distributed_training_args(self):
        grp.add_argument(
            '--distributed-world-size', type=int, help='Number of workers.'
        )
        grp.add_argument(
            '--ddp-backend',
            choices=['ddp', 'zero2', 'zero3'],
Hmm, should we even give 'zero3' as an option for the time being? (Don't really care either way)
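For reference, the finished argument would presumably look roughly like the sketch below; the default and help text are illustrative rather than copied from the PR, and zero3 is omitted from the choices per the rest of this thread:

```python
# Sketch only: the default and help text are assumptions, not the PR's actual values.
grp.add_argument(
    '--ddp-backend',
    choices=['ddp', 'zero2'],  # 'zero3' dropped as a CLI choice per this review
    default='ddp',
    help='Distributed backend: vanilla DistributedDataParallel or FairScale FSDP (zero2).',
)
```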
parlai/utils/fsdp.py
Outdated
def should_sync_gradnorm(opt):
    """
    Indicates whether fp16 optimizer wrappers should cumulate over workers.
Nit: "accumulate"?
parlai/core/torch_agent.py
Outdated
        For models or optimizers that shard parameters, this ensures we sync.
        """
        if self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3'):
Nit: should we pull in DEFAULT_DDP_BACKEND here?
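For illustration, pulling the literal into a constant could look like the sketch below; only the name DEFAULT_DDP_BACKEND comes from the review, and its home module and the helper around it are assumptions:

```python
# Hypothetical sketch; the constant's home module (e.g. parlai/utils/fsdp.py) is an assumption.
DEFAULT_DDP_BACKEND = 'ddp'


def is_sharded_ddp(opt):
    """True when the configured backend shards parameters/gradients across workers."""
    return opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3')
```

The same constant would cover the identical `opt.get('ddp_backend', 'ddp')` lookup flagged below in torch_generator_agent.py.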
parlai/core/torch_generator_agent.py
Outdated
        if (
            shared is None
            and is_distributed()
            and opt.get('ddp_backend', 'ddp') == 'ddp'
(same here about maybe using DEFAULT_DDP_BACKEND instead)
really really cool. lots of nits though (and a few real questions 😄 )
@@ -1969,10 +1974,11 @@ def state_dict(self):
         """
         states = {}
         if hasattr(self, 'model'): # save model params
-            if hasattr(self.model, 'module'):
-                # did we wrap in a DistributedDataParallel
+            if hasattr(self.model, 'module') and not is_fsdp(self.model):
nit: could make this a helper function too? like should_sync_gradnorm (not necessary of course)
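For context, is_fsdp presumably reduces to an isinstance check against FairScale's wrapper, and the suggested helper could wrap the whole condition; everything except the is_fsdp name is a hypothetical sketch:

```python
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def is_fsdp(module):
    """Check whether a module is wrapped in FairScale's FullyShardedDataParallel."""
    return isinstance(module, FSDP)


def is_plain_dp_wrapper(module):
    """
    Hypothetical helper for the check above: True when the module is wrapped in
    DistributedDataParallel/DataParallel (i.e. exposes .module) but not FSDP,
    so state_dict should be read from module.module.
    """
    return hasattr(module, 'module') and not is_fsdp(module)
```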
parlai/core/torch_generator_agent.py
Outdated
-        self.model = self.build_model()
+        with fsdp_utils.maybe_fsdp_wrap(opt):
+            self.model = fsdp_utils.fsdp_wrap(self.build_model())
+            if self.fp16 and not fsdp_utils.should_use_fsdp(opt):
remember that bug with the instability stuff? is this not re-introducing it?
(because we moved the model.half() call?)
Okay, I think this needs to use my utility should_delay_halving. Forgot this.
We haven't really moved the moment of halving. The operations between these two points don't do much, and the original code path should be about the same.
- We now halve on CPU instead of GPU, and then transfer. That's probably a small speedup in initialization, with maybe some small numerical differences.
- We apply model parallelism after halving. Probably a small speedup at initialization.
- We synchronize parameters after halving. Again, a small initialization speedup.
The catch is that FSDP expects the model still in fp32 (not yet halved) if we're doing safe optimization, and already halved if we're doing memory-efficient. (Similar to the optimizer wrappers, it looks at the parameter types to decide what dtype the gradients will be.)
This is the desired pattern:
- If we're in Safe and using DDP, we SHOULD still halve, just as before
- If we're in MemEff and using DDP, we SHOULD still halve, just as before
- If we're in Safe and Zero2, we should NOT halve here
- If we're in MemEff and Zero2, we SHOULD halve here.
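A sketch of a should_delay_halving utility encoding the four cases above (the name appears earlier in the thread; the signature and placement are assumptions):

```python
def should_delay_halving(opt):
    """
    True when the model should be left in fp32 at build time so FSDP can
    handle the fp16 conversion itself.

    Only the Safe + zero2/zero3 combination delays halving; plain DDP and the
    memory-efficient fp16 implementation halve the model up front, as before.
    """
    return (
        opt['fp16']
        and opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3')
        and opt['fp16_impl'] == 'safe'
    )
```

At build time the agent would then halve only when not delaying, e.g. `if self.fp16 and not should_delay_halving(opt): self.model = self.model.half()` (again a sketch, not the PR's exact line).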
@@ -55,10 +54,12 @@ def multiprocess_train(
        raise


-def launch_and_train(opt, port):
+def launch_and_train(opt, port=None):
will we ever specify a port here?
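Background for the port=None default: the usual way to get an OS-assigned free port (as described in the patch summary below) is to bind to port 0 and read back whatever the OS picked. A generic sketch, not necessarily the helper used in the PR:

```python
import socket


def find_free_port():
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('', 0))
        return sock.getsockname()[1]
```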
@@ -543,7 +546,7 @@ def validate(self):
                )
            self.best_valid = new_valid
            self.impatience = 0
-            if opt.get('model_file') and is_primary_worker():
+            if opt.get('model_file'):
Just making sure I understand - we can get rid of this check because it's handled in save_model, right?
We need to be able to do save_on_nonprimary_worker, actually.
        if max_norm > 0:
            clip_coef = max_norm / (grad_norm + 1e-6)
            for p in params:
                p.grad.detach().mul_(clip_coef)
why do we detach here?
Don't want grads of grads! (This is in the original pytorch code too)
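For reference, the detach mirrors torch.nn.utils.clip_grad_norm_: the scaling should act on the raw gradient tensors rather than create new autograd history over them. A condensed, self-contained sketch of the pattern (not the PR's exact code):

```python
import torch


def clip_grad_norm(params, max_norm):
    """Scale gradients in place so their combined L2 norm is at most max_norm."""
    params = [p for p in params if p.grad is not None]
    if not params:
        return torch.tensor(0.0)
    grad_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach()) for p in params])
    )
    if max_norm > 0:
        clip_coef = max_norm / (grad_norm + 1e-6)
        if clip_coef < 1:
            # detach() so the in-place mul_ operates on the gradient data and
            # never becomes part of an autograd graph (no grads of grads).
            for p in params:
                p.grad.detach().mul_(clip_coef)
    return grad_norm
```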
        return

    # zero3 not supported at this time. Throw an exception
    if opt['ddp_backend'] == 'zero3':
I know this is just for overkill testing, but it's not even a choice in the param options, so we'll already error there if calling from the command line.
I'm leaving it for the future
parlai/utils/fsdp.py
Outdated
    return (
        self.fp16
        and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3')
        and self.opt['fp16_impl'] == 'safe'
But if we're using mem_efficient, we don't delay?
Correct, see main comment
Patch description
Add support for Fairscale's FullyShardedDataParallel (FSDP). This is an implementation of DeepSpeed's Zero2 optimization, wherein optimizer state and gradients are sharded across workers in order to reduce memory usage. Switching to --ddp-backend zero2 results in roughly a 25% speedup in UPS (without bg workers; it can probably be a bit higher) and about a 50% reduction in memory usage. It's recommended that everyone switch to this for distributed training and use the savings to increase batchsize or lower the number of GPUs.

We also carve out support for Zero3, but cannot support it at this time due to high-level design in ParlAI. See #3753 for a detailed description of why, and how we might overcome this in the future.
As a side change, this also makes our unit tests use OS-assigned free ports, instead of randomized ones, to slightly improve the reliability of running our test suites. I tried pulling this into another PR, but got tired of dealing with stacking.
Testing steps
Manual tests. New CI.
Here are some screenshots from a sweep that contained both --ddp-backend ddp and --ddp-backend zero2.