Avoid wrapping LightningModule in *DataParallel overrides when not fitting #8632
Conversation
Hello @ninginthecloud! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found: There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-08-12 18:20:15 UTC
Codecov Report
@@            Coverage Diff            @@
##           master    #8632    +/-   ##
=========================================
- Coverage      89%      89%     -0%
=========================================
  Files         176      176
  Lines       14268    14291     +23
=========================================
+ Hits        12679    12687      +8
- Misses       1589     1604     +15
@ananthsub @SeanNaren @ninginthecloud
force-pushed from 3caea11 to 4a15a61
Note: with this change, self.model may not be a DistributedDataParallel instance. Therefore, in L400-410 of the DDP plugin, for training_step, validation_step, test_step, and predict_step, we must check isinstance(self.model, DistributedDataParallel).
If it is, we should call self.model(*args, **kwargs). This works because self.model's forward function internally routes to the LightningModule's *_step function: https://github.com/PyTorchLightning/pytorch-lightning/blob/1f01db8b303e647b102f48c36e91ddb17784414f/pytorch_lightning/overrides/base.py#L77-L99
However, if isinstance(self.model, LightningModule), then we should simply call self.model.training_step / validation_step / test_step / predict_step directly.
More context on this design can be found here: #4630
@awaelchli @justusschock to double-check this approach
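For illustration, a minimal sketch of the dispatch described in this comment, as it might look inside one of the DDP plugin's step methods. This is not the PR's merged code; the *args/**kwargs signature and the plugin attribute layout are assumptions, only the isinstance check and the two call paths come from the comment above.

```python
# Sketch only: illustrates the isinstance-based dispatch described above.
from torch.nn.parallel import DistributedDataParallel


def validation_step(self, *args, **kwargs):
    if isinstance(self.model, DistributedDataParallel):
        # Wrapped for fitting: DDP's forward routes to the LightningModule's
        # validation_step via the Lightning wrapper module (overrides/base.py).
        return self.model(*args, **kwargs)
    # Not wrapped (validate/test/predict only): call the hook directly.
    return self.model.validation_step(*args, **kwargs)
```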
@justusschock Could you describe why the processes spawned would affect this? Do you mean loading weights on rank 0 only, broadcasting weights via DDP's synchronization, and then running evaluation? I think applying the no-sync wrapper regardless for validation/test/predict should be done, and that might resolve the uneven end of data in a simpler way. What's the clearest way to add the no_sync for these? Is that handled by the plugin or the loops? And is that already accountedted for by the no_grad context manager applied at the start of those loops?
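For reference, the no_sync mechanism the question refers to is DDP's no_sync() context manager, which suppresses the gradient all-reduce for anything run inside the block. Below is a hedged sketch using a hypothetical helper (run_eval_step is not a Lightning function), just to make the mechanism concrete:

```python
import contextlib

import torch
from torch.nn.parallel import DistributedDataParallel


def run_eval_step(model, batch):
    # Hypothetical helper, not Lightning loop/plugin code.
    # no_sync() only exists on DDP-wrapped modules; otherwise use a no-op context.
    maybe_no_sync = (
        model.no_sync() if isinstance(model, DistributedDataParallel) else contextlib.nullcontext()
    )
    # Evaluation already runs under no_grad, so no autograd graph is built and
    # there is no backward pass (hence no gradient all-reduce) to suppress;
    # no_sync is included here only to show how it would be applied.
    with maybe_no_sync, torch.no_grad():
        return model(batch)
```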
force-pushed from 4a15a61 to 5f081bc
force-pushed from 5f081bc to 3fc1e16
force-pushed from 6fe7e55 to 705d663
force-pushed from 705d663 to 98fe2b0
@ninginthecloud It's because the tests are failing, and if they do, the CI job won't submit coverage. Coverage is merged from all the different jobs, so if, say, the GPU code path fails, coverage for that part will be missing. If the tests pass, the codecov bot will also update its message on this PR.
force-pushed from 3869780 to a8e1fe6
for more information, see https://pre-commit.ci
force-pushed from a8e1fe6 to 4bab65f
Closing this out in favor of #9096
What does this PR do?
Avoid wrapping LightningModule in *DataParallel overrides when not fitting
Specifically,
- Update the configure_ddp function in DDPPlugin, DDPSpawnPlugin, DDPShardedPlugin and DDPSpawnShardedPlugin by checking the LightningModule's trainer state and avoiding wrapping the LightningModule as *DataParallel when the state is not TrainerFn.FITTING (see the sketch after this list).
- Update the validation_step function in DDPPlugin and DDPSpawnPlugin to use the LightningModule's validation_step function if self.model is not a DistributedDataParallel instance.
- Update the test_step and prediction_step functions in DDPPlugin and DDPSpawnPlugin to use the LightningModule's *_step functions directly.

Fixes #6977
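As referenced in the first bullet, a minimal sketch of the configure_ddp guard, assuming the trainer state is reachable via self.lightning_module.trainer.state.fn; the real plugin also wraps the module in Lightning's forward-dispatch wrapper and passes device ids and DDP kwargs, which are omitted here, and the import paths are stated from memory rather than copied from the PR.

```python
# Minimal sketch of the guard described above (not the PR's exact code).
from pytorch_lightning.trainer.states import TrainerFn
from torch.nn.parallel import DistributedDataParallel


def configure_ddp(self):
    # Assumption: the trainer state is reachable via the attached LightningModule.
    if self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
        # validate / test / predict: keep the bare LightningModule, no DDP wrapper.
        self.model = self.lightning_module
        return
    # Fitting: wrap as before so gradients are synchronized during training
    # (the real code wraps a Lightning dispatch module and passes device ids / kwargs).
    self.model = DistributedDataParallel(self.lightning_module)
```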
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃