
Refactor LightningDistributedDataParallel #5185

Merged (22 commits into release/1.2-dev on Jan 13, 2021)

Conversation

@awaelchli (Contributor) commented Dec 18, 2020

What does this PR do?

Fixes #4630
cc @ananthsub

The motivation behind this refactor of the DDP wrapper is that we automatically inherit all (future) improvements from upstream PyTorch DDP, and users can comfortably subclass the PyTorch wrapper directly if they want to.
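(A rough sketch of the pattern; class and method names follow this PR's diff, but the exact Lightning internals differ slightly. The LightningModule is wrapped in a thin nn.Module whose forward() dispatches to training_step / validation_step / test_step, and that wrapper is handed to the unmodified torch DistributedDataParallel.)

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class LightningDistributedWrapper(nn.Module):
    """Sketch: route forward() to the LightningModule's *_step hooks so the
    plain torch DDP wrapper can be used unchanged."""

    def __init__(self, pl_module):
        super().__init__()
        self.module = pl_module

    def forward(self, *args, **kwargs):
        if self.module.training:
            return self.module.training_step(*args, **kwargs)
        if getattr(self.module, "testing", False):
            return self.module.test_step(*args, **kwargs)
        return self.module.validation_step(*args, **kwargs)


# configure_ddp then reduces to a call into the native torch wrapper:
# model = DistributedDataParallel(
#     module=LightningDistributedWrapper(lightning_module),
#     device_ids=device_ids,
#     **ddp_kwargs,
# )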

TODO:

  • Docs
  • add unit tests for new class to keep coverage up
  • add deprecation for old class?

@awaelchli awaelchli added refactor design Includes a design discussion labels Dec 18, 2020
@awaelchli awaelchli added this to the 1.2 milestone Dec 18, 2020
@awaelchli awaelchli changed the base branch from master to release/1.2-dev December 18, 2020 19:20
@awaelchli awaelchli changed the base branch from release/1.2-dev to master December 18, 2020 19:21
@awaelchli awaelchli force-pushed the refactor/distrib-wrapper branch from 57fb035 to 7655c37 on December 18, 2020 19:24
@awaelchli awaelchli changed the base branch from master to release/1.2-dev December 18, 2020 19:24
@awaelchli awaelchli force-pushed the refactor/distrib-wrapper branch from 7655c37 to ef34dc1 on December 18, 2020 19:37
@awaelchli awaelchli changed the title from "Refactor LightningDistributedDataParallel [skip-ci]" to "Refactor LightningDistributedDataParallel [skip ci]" Dec 18, 2020
@awaelchli awaelchli added the distributed Generic distributed-related topic label Dec 18, 2020

def reducer_prepare_for_backwards(self, output):
Contributor Author:

I'm not sure where this should go; it requires the reducer from DDP.

Contributor Author:

@SeanNaren do you remember the conversation we started in #4976 about this? You had an idea there, I'm trying to understand it, maybe you can explain again? :)

Contributor:

I have been thinking about this one maybe too much, haha, but I didn't find a better way to do it: backward and optimizer.step are called inside training_step, and the DDP reducer is invoked on the training_step output.

Contributor:

We could open a PR in PyTorch to at least move this into a function we can reuse.

@ananthsub (Contributor):

fyi @pritamdamania87

@ananthsub (Contributor) left a comment:

Exciting progress! Are there other spots in Lightning where LightningDistributedDataParallel is referenced? I imagine this will also provide a perf benefit, since we can now automatically take advantage of improvements in DDP's forward.

  • Does the same pattern apply for data parallel?
  • (Down the line) What if someone wants to wrap only some modules in their LightningModule with DDP? What further changes would that require here?

Comment on lines 146 to 171
 @contextmanager
-def block_backward_sync(self, model: LightningDistributedDataParallel):
+def block_backward_sync(self, model: DistributedDataParallel):
     """
     Blocks ddp sync gradients behaviour on backwards pass.
     This is useful for skipping sync when accumulating gradients, reducing communication overhead
     Returns: context manager with sync behaviour off
     """
     yield model.no_sync()
Contributor:

Do we still need block_backward_sync? Can we call model.no_sync() directly?
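(For reference, a minimal sketch of what calling the native no_sync() directly looks like in a plain torch gradient-accumulation loop; loader, optimizer and accumulate_grad_batches are placeholder names, not Lightning API.)

from torch.nn.parallel import DistributedDataParallel


def accumulation_loop(ddp_model: DistributedDataParallel, loader, optimizer,
                      accumulate_grad_batches: int = 2):
    for i, batch in enumerate(loader):
        if (i + 1) % accumulate_grad_batches != 0:
            # skip the gradient all-reduce on intermediate micro-batches
            with ddp_model.no_sync():
                ddp_model(batch).sum().backward()
        else:
            # this backward() performs the all-reduce of the accumulated grads
            ddp_model(batch).sum().backward()
            optimizer.step()
            optimizer.zero_grad()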

Contributor Author:

The only reason I can think of for keeping it is so that users can override this method in their own plugin, though there is not much they can customize in this context manager :)

Member:

We basically added this because we did not want anything DDP-specific (i.e. type checks against the prior LightningDistributedDataParallel) inside the trainer/training loop, since that part should stay backend agnostic.


-def parallel_apply(self, replicas, inputs, kwargs):
-    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+class LightningDistributedWrapper(torch.nn.Module):
Contributor:

for later: we should add comments for how this class is to be used

Contributor Author:

done, with example.

-model = LightningDistributedDataParallel(
-    model,
+model = DistributedDataParallel(
+    module=LightningDistributedWrapper(model),
Contributor:

the docstring above needs to be updated for the return type

Contributor Author:

done

@@ -63,8 +64,8 @@ def configure_ddp(self, model, device_ids):
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
"find_unused_parameters", True
Contributor:

I'm not sure why this is the default. It incurs a performance hit and differs from the DDP default.

@awaelchli (Contributor Author), Dec 25, 2020:

It was added by you in #4382; I'm not sure it's necessary.

Contributor:

Hey @awaelchli. Not sure Will looked into the default for find_unused_parameters. Let's stick to the PyTorch default, which is False, right?

Contributor:

It should be False; enabling it is only recommended when necessary: https://pytorch.org/docs/stable/notes/ddp.html#internal-design

@awaelchli (Contributor Author), Jan 8, 2021:

OK, made the default False to be in line with PyTorch DDP: #5435
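(With the default flipped to False, a user whose model really does have unused parameters would opt back in explicitly, e.g. through the DDP plugin. A sketch assuming the Lightning ~1.1/1.2 API; the exact import path may differ by version.)

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin  # import path varies across versions

# extra kwargs on DDPPlugin are forwarded to torch.nn.parallel.DistributedDataParallel
trainer = Trainer(
    gpus=2,
    accelerator="ddp",
    plugins=[DDPPlugin(find_unused_parameters=True)],
)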

@ananthsub (Contributor), Jan 9, 2021:

in #4382 I was preserving the prior behavior without digging into the full history behind the setting :/

This could be a nice speedup for distributed training jobs. @SeanNaren, n00b question: is there a way to estimate the possible gains using the Lightning benchmarks?

@awaelchli (Contributor Author):

> are there other spots in lightning where LightningDistributedDataParallel is referenced?

In the end I will do a search across the codebase, but I think it's more or less isolated to the plugins and accelerators; in most places Lightning operates on model.module.

> Does the same pattern apply for data parallel?

Yes, I believe so. Currently the biggest difference I see is the custom gather function for Results objects, but since we dropped support for those we should also be able to get rid of it and just use the native torch implementation for DP.

> (down the line) what if someone wants to wrap only some modules in their lightning module with DDP? what further changes would that require here?

Oh, that sounds adventurous! I'm not sure whether the wrapper in this PR would stand in the way in that scenario, because of the order of nesting it implies. I need to think about it more.

At the moment the PR is blocked by the problem of manual optimization. It seems the custom code for calling the reducer is necessary for Lightning's manual backward. I don't fully understand it; I hope @SeanNaren can give me some hints.

@SeanNaren (Contributor) commented Dec 27, 2020

Nice! I like the overall flow of this. Just to make sure I understand: all we're doing is adding a higher-level torch module on top of the DDP/DP modules? I have to ask: was there any reason this wasn't done initially? I may have missed something.

Regarding manual optimization, the additional logic is due to the backward all-reduce hooks being added after forward is finished. Right now we assume that if you override training_step in manual optimization, you can do everything you like within the function. Before the fix, this meant that if you called .backward() in training_step, gradients were not synced because no autograd reduce hooks had been added.

We added the extra logic as a temporary fix, but I really don't like it. I haven't had time to figure out a cleaner solution, but our current one boils down to calling reducer.prepare_for_backwards(...) manually, which adds the hooks (optionally using the loss to add hooks only to the parameters that actually require it, iirc).
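(For context, the temporary fix essentially replays the tail end of torch DDP's own forward() before a manual backward, re-arming the reducer hooks. A rough sketch below; it relies on private torch internals such as _find_tensors and the Reducer, so treat it as illustrative only, not the exact Lightning code.)

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors  # private torch helper


def prepare_for_backward(ddp_model: DistributedDataParallel, output):
    """Re-register DDP's all-reduce autograd hooks before a manual .backward()."""
    if torch.is_grad_enabled() and ddp_model.require_backward_grad_sync:
        ddp_model.require_forward_param_sync = True
        if ddp_model.find_unused_parameters:
            # hook only the parameters reachable from the training_step output
            ddp_model.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            ddp_model.reducer.prepare_for_backward([])
    else:
        ddp_model.require_forward_param_sync = False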

@tchaton (Contributor) commented Jan 4, 2021

> (quoting @SeanNaren's comment above)

Hey @awaelchli, I can help you out with manual_optimization and Pipe. It would be great to finalise this work.

For PipeRPCPlugin, we just need to replace PREPARE_FOR_BACKWARD with require_backward_grad_sync.

For manual_optimization, we need to set require_backward_grad_sync=False by default.

And update the Accelerator.backward function. Still hacky...

def backward(self, closure_loss, optimizer, opt_idx, *args, should_sync = True, **kwargs):
    automatic_optimization = self.trainer.train_loop.automatic_optimization

    if not automatic_optimization and self.ddp_plugin is not None:
        if should_sync:
            if torch.is_grad_enabled():
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(closure_loss)))
                else:
                    self.reducer.prepare_for_backward([])
            self.require_backward_grad_sync = True
            self.require_forward_param_sync = True

        else:
            self.require_backward_grad_sync = False
            self.require_forward_param_sync = False

And it could be used this way.

def training_step(....):
    ...
    make_optimizer_step = (batch_idx % 2 == 0)
    self.manual_backward(loss, opt, should_sync=make_optimizer_step)
    opt.step(make_optimizer_step=make_optimizer_step)

However, we will have to be careful when the user provides a closure, as the accumulated-gradient logic already lives inside the LightningOptimizer.

def training_step(....):
    ...
    def closure():
        loss = ...
        # should_sync should be inferred from LightningOptimizer
        self.manual_backward(loss, opt)
    opt.step(closure=closure, make_optimizer_step=(batch_idx % 2 == 0))

NB: require_backward_grad_sync and require_forward_param_sync were already present in July 2019 (v1.3.0), so it might work.

Best,
T.C

@Borda Borda added the priority: 1 Medium priority task label Jan 5, 2021
@Borda (Member) commented Jan 5, 2021

@awaelchli is this ready to land/review? 🐰

@tchaton (Contributor) commented Jan 8, 2021

Hey @awaelchli,

Resolved in #5415, which is checked out from this branch.

Best,
T.C

@tchaton tchaton mentioned this pull request Jan 8, 2021
@awaelchli (Contributor Author) commented Jan 8, 2021

@tchaton amazing, thanks!

Any preferences/suggestions for a better name?

  • LightningDistributedWrapper (current)
  • LightningDistributedModule? (maybe more in line with LightningModule)

@awaelchli awaelchli changed the title from "Refactor LightningDistributedDataParallel [skip ci]" to "Refactor LightningDistributedDataParallel" Jan 8, 2021
@awaelchli awaelchli marked this pull request as ready for review January 8, 2021 13:08
codecov bot commented Jan 8, 2021

Codecov Report

Merging #5185 (1c90586) into release/1.2-dev (61f415f) will decrease coverage by 0%.
The diff coverage is 100%.

@@               Coverage Diff                @@
##           release/1.2-dev   #5185    +/-   ##
================================================
- Coverage               93%     93%    -0%     
================================================
  Files                  152     151     -1     
  Lines                10737   10616   -121     
================================================
- Hits                  9950    9828   -122     
- Misses                 787     788     +1     

pytorch_lightning/overrides/data_parallel.py (outdated, resolved)
Override the forward call in lightning so it goes to training and validation step respectively
"""
PREPARE_FOR_BACKWARDS = True
class LightningDistributedWrapper(torch.nn.Module):
Member:

Btw, this smells like an API change ;]

Contributor Author:

OK, add a deprecation warning now and remove it in 1.4?

Contributor Author:

@Borda I also added a deprecation test, but it is not so simple because it needs torch.distributed to be initialized, so it adds about 3-4 seconds just for that test. I don't know how to do it more simply.
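(For reference, a minimal sketch of such a test: it spins up a single-process gloo group, which is where the extra seconds go, and asserts the deprecation warning. The class name, import path and warning text follow this PR and may differ in the final code.)

import pytest
import torch
import torch.distributed as dist

from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel


def test_lightning_distributed_data_parallel_is_deprecated(tmp_path):
    # single-process "gloo" group so the DDP constructor has a process group to attach to
    dist.init_process_group(
        "gloo", init_method=f"file://{tmp_path / 'store'}", rank=0, world_size=1
    )
    try:
        with pytest.deprecated_call(match="LightningDistributedDataParallel"):
            LightningDistributedDataParallel(torch.nn.Linear(2, 2))
    finally:
        dist.destroy_process_group()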

@SeanNaren (Contributor) left a comment:

Love this, great cleanup @awaelchli. Something similar could potentially happen to the ShardedDataParallel class in overrides/fairscale.py, but that can be a separate PR.

I don't mind the current name, but I think it's more PyTorch-esque if we go for LightningDistributedModule. Either way is clear :)



@Borda Borda force-pushed the refactor/distrib-wrapper branch from 4eadb89 to be8e11e on January 13, 2021 13:15
Comment on lines +17 to +18
pl_module.training = True
pl_module.testing = False
Member:

Curious: how does the module have these attributes? Isn't that on the Trainer?

Contributor Author:

See my answer here; the property is copied to the model here.
Happy to follow up on this in a next step. The test that I wrote there is just to make sure it works the same as before the refactor.

@tchaton tchaton merged commit e806bb7 into release/1.2-dev Jan 13, 2021
@tchaton tchaton deleted the refactor/distrib-wrapper branch January 13, 2021 19:35
Labels
design (Includes a design discussion), distributed (Generic distributed-related topic), has conflicts, priority: 1 (Medium priority task), ready (PRs ready to be merged), refactor
Development

Successfully merging this pull request may close these issues.

Keeping DDP override in sync with upstream torch
9 participants