Refactor LightningDistributedDataParallel #5185
Conversation
```python
    def reducer_prepare_for_backwards(self, output):
```
I'm not sure where this should go, it requires the reducer from ddp
@SeanNaren do you remember the conversation we started in #4976 about this? You had an idea there, I'm trying to understand it, maybe you can explain again? :)
I have been thinking about this one maybe too much ahah, but I didn't find a better way to do it, as backward and optimizer.step are being called in training_step, and the DDP reducer is being called on the training_step output.
could open a PR in PyTorch to at least move it to a function we can use.
fyi @pritamdamania87
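For context, a rough sketch of the kind of hook being discussed. It reaches into DDP internals (the reducer attribute and the private `_find_tensors` helper), which is exactly why an upstream PyTorch function would be nicer; the names below are internal and may differ between PyTorch versions, so treat this as an illustration rather than the exact code in the PR.

```python
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors  # private PyTorch helper


class DDPWithReducerHook(DistributedDataParallel):
    """Sketch: re-run DDP's reducer preparation on the training_step output,
    needed when backward() happens outside of DistributedDataParallel.forward()."""

    def reducer_prepare_for_backwards(self, output):
        if not (torch.is_grad_enabled() and self.require_backward_grad_sync):
            return
        self.require_forward_param_sync = True
        if self.find_unused_parameters:
            # register autograd hooks for every tensor reachable from the output
            self.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            self.reducer.prepare_for_backward([])
```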
exciting progress! are there other spots in lightning where LightningDistributedDataParallel is referenced? i imagine this will also provide a perf benefit as we can take advantage of improvements in DDP's forward automatically now.
- Does the same pattern apply for data parallel?
- (down the line) what if someone wants to wrap only some modules in their lightning module with DDP? what further changes would that require here?
```diff
     @contextmanager
-    def block_backward_sync(self, model: LightningDistributedDataParallel):
+    def block_backward_sync(self, model: DistributedDataParallel):
         """
         Blocks ddp sync gradients behaviour on backwards pass.
         This is useful for skipping sync when accumulating gradients, reducing communication overhead
         Returns: context manager with sync behaviour off
         """
         yield model.no_sync()
```
do we need block_backward_sync still? can we directly call model.no_sync()?
the only reason I can think of why it is there is so the user can override this method in their own plugin, though there is not much customization they can do to this context manager :)
We basically added this since we did not want anything DDP-specific (i.e. any typechecks against the prior LightningDistributedDataParallel) within the trainer/training loop, as that should be backend agnostic.
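For reference, a minimal sketch (plain PyTorch, not Lightning's actual loop) of the gradient-accumulation pattern that no_sync() enables, whether it is reached through block_backward_sync or called directly:

```python
import contextlib

from torch.nn.parallel import DistributedDataParallel


def train_with_accumulation(model: DistributedDataParallel, batches, optimizer, loss_fn, accumulate_every: int = 4):
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches):
        last_in_window = (i + 1) % accumulate_every == 0
        # no_sync() disables the backward all-reduce hooks, so gradients only
        # accumulate locally; the all-reduce happens on the last micro-batch.
        ctx = contextlib.nullcontext() if last_in_window else model.no_sync()
        with ctx:
            loss = loss_fn(model(x), y)
            loss.backward()
        if last_in_window:
            optimizer.step()
            optimizer.zero_grad()
```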
```diff
-    def parallel_apply(self, replicas, inputs, kwargs):
-        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+class LightningDistributedWrapper(torch.nn.Module):
```
for later: we should add comments for how this class is to be used
done, with example.
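That example isn't visible in this view; roughly, the wrapper dispatches forward() to the appropriate *_step method based on the module's state flags. A sketch only, the real class in the PR may differ in details:

```python
import torch


class LightningDistributedWrapper(torch.nn.Module):
    """Sketch: let plain DistributedDataParallel wrap a LightningModule by
    routing forward() to the right step method."""

    def __init__(self, pl_module):
        super().__init__()
        self.module = pl_module

    def forward(self, *inputs, **kwargs):
        if self.module.training:
            return self.module.training_step(*inputs, **kwargs)
        if self.module.testing:
            return self.module.test_step(*inputs, **kwargs)
        return self.module.validation_step(*inputs, **kwargs)
```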
```diff
-        model = LightningDistributedDataParallel(
-            model,
+        model = DistributedDataParallel(
+            module=LightningDistributedWrapper(model),
```
the docstring above needs to be updated for the return type
done
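One consequence worth spelling out (the helper below is hypothetical, just to illustrate the nesting): with the new composition the LightningModule sits one level deeper than before, so code that used to read model.module now needs model.module.module.

```python
from torch.nn.parallel import DistributedDataParallel


def unwrap_lightning_module(ddp_model: DistributedDataParallel):
    # old: LightningDistributedDataParallel(model, ...)  -> ddp_model.module is the LightningModule
    # new: DistributedDataParallel(LightningDistributedWrapper(model), ...)
    #      -> ddp_model.module is the wrapper, ddp_model.module.module is the LightningModule
    return ddp_model.module.module
```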
```diff
@@ -63,8 +64,8 @@ def configure_ddp(self, model, device_ids):
         self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
             "find_unused_parameters", True
```
i'm not sure why this is the default. this incurs a perf hit and is different from the DDP default
It was added by you in #4382; I'm not sure if it's necessary.
Hey @awaelchli. Not sure Will looked into the default for find_unused_parameters. Let's stick to the PyTorch default, which is False, right?
It should be False; enabling it is only recommended when necessary: https://pytorch.org/docs/stable/notes/ddp.html#internal-design
ok, made it default False to be in line with pytorch DDP: #5435
in #4382 I was preserving the prior behavior without digging into the full history behind the setting :/
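A sketch of the direction agreed on here (default False, matching PyTorch, with an explicit opt-in); the function and argument names are illustrative, not the exact plugin code:

```python
from torch.nn.parallel import DistributedDataParallel


def configure_ddp(wrapped_module, device_ids, ddp_kwargs):
    # Match the PyTorch default: find_unused_parameters=False skips the extra
    # traversal of the autograd graph that DDP performs after every forward
    # pass. Users whose models really have unused parameters can opt in.
    ddp_kwargs["find_unused_parameters"] = ddp_kwargs.get("find_unused_parameters", False)
    return DistributedDataParallel(module=wrapped_module, device_ids=device_ids, **ddp_kwargs)
```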
This could be a nice speedup for distributed training jobs. @SeanNaren n00b question: is there a way to estimate the possible gains using the lightning benchmarks?
in the end I will do a search across the codebase, but I think it's more or less isolated to the plugins and accelerators and in most places in Lightning we operate on model.module.
yes I believe so. currently the biggest difference I see is the custom gather function for Results objects, but since we dropped support we should also be able to get rid of it and just use the native torch implementation for DP.
oh, that sounds adventurous! Not sure if in this scenario the wrapper in this PR would stand in the way, because of the order of nesting it implies. Need to think about it more.

At the moment the PR is blocked by the problem of manual optimization. It seems the custom code for calling the reducer is necessary for Lightning's manual backward. I don't fully understand it, I hope @SeanNaren can give me some hints.
Nice! I like the overall flow of this. Just to make it clear for me, all we're doing is adding a higher level torch module on top of the DDP/DP modules? I have to ask; was there any reason this wasn't done initially? I may have missed something in this case.

Regarding manual optimization, the additional logic is due to the backward all-reduce hooks being added after forward is finished. Right now we assume that if you override … We added the additional logic as a temp fix for this, but I really don't like the additional logic. I haven't had time to figure out a cleaner solution than what we added, but our current solution boils down to calling …
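For readers following along: manual optimization means backward() runs inside training_step, i.e. while DDP's own forward() is still executing, which is why the reducer hooks have to be prepared separately. A rough sketch of such a module (Lightning's manual-optimization API around the 1.1/1.2 releases; exact signatures may differ):

```python
import torch
import pytorch_lightning as pl


class ManualOptimModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        return False  # Lightning will not call backward()/step() for us

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        # backward() runs here, before DistributedDataParallel.forward() has
        # returned, so DDP's reducer hooks must already be in place.
        self.manual_backward(loss)
        opt.step()
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```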
Hey @awaelchli, I can help you out for …

For PipeRPCPlugin, we just need to replace … For … And update the Accelerator.backward function. Still hacky ...

And it could be used this way: …

However, we will have to be careful if the user provides a closure, as the accumulated_gradient logic is already within the LightningOptimizer.

NB: …

Best,
@awaelchli is this ready to land/review? 🐰
Hey @awaelchli, resolved. Checkout from this branch: #5415. Best,
@tchaton amazing, thanks! Any preferences/suggestions for a better name?
Codecov Report
```diff
@@              Coverage Diff               @@
##           release/1.2-dev   #5185   +/-  ##
===============================================
- Coverage               93%     93%     -0%
===============================================
  Files                  152     151      -1
  Lines                10737   10616    -121
===============================================
- Hits                  9950    9828    -122
- Misses                 787     788      +1
```
```diff
-    Override the forward call in lightning so it goes to training and validation step respectively
-    """
-    PREPARE_FOR_BACKWARDS = True
+class LightningDistributedWrapper(torch.nn.Module):
```
btw, this smells like api change ;]
ok, deprecation warning and remove in 1.4?
@Borda I also added a deprecation test, but it is not so simple because it needs torch.distributed to be initialized, so it will add about 3-4 seconds just for that test. I don't know how to do it more simply.
Love this, great cleanup @awaelchli. Something similar could potentially happen to the ShardedDataParallel class in override/fairscale.py but that can be a separate PR.
I also don't mind the name, but I think it's more pytorch-esque if we go for LightningDistributedModule. I don't really mind, either way is clear :)
This reverts commit 8e45151.
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
```python
    pl_module.training = True
    pl_module.testing = False
```
curious how the module has these attributes, isn't it Trainer?
What does this PR do?
Fixes #4630
cc @ananthsub
The motivation behind this refactor of the DDP wrapper is that we get all (future) improvements from upstream PyTorch DDP, and the user can comfortably subclass the PyTorch wrapper if they want to.
TODO: