
State maintenance in DP #565

Closed
S-aiueo32 opened this issue Dec 2, 2019 · 8 comments
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on), strategy: dp (removed in pl) (DataParallel)
Milestone: 1.3

Comments

S-aiueo32 (Contributor) commented Dec 2, 2019

In many image generation tasks with GANs, the generator and the discriminator are trained on the same generated image within a single iteration.
In PyTorch Lightning, the procedure looks like this:

def training_step(self, batch, batch_nb, optimizer_i):
    foo = batch['foo']
    bar = batch['bar']

    if optimizer_i == 0:  # train discriminator
        self.foo_out = self.netG(foo)  # register as an instance variable

        # calc d_loss
        d_loss = ...

        return {'loss': d_loss}

    elif optimizer_i == 1:  # train generator
        # common reconstruction error
        g_loss = F.l1_loss(self.foo_out, bar)
        # other losses
        ...

        return {'loss': g_loss}

It works well on a single GPU; however, self.foo_out has already been flushed by the time the optimizer_i == 1 branch runs when DP is set.

I think this is undesired behavior. Any help or fix?

S-aiueo32 added the bug label Dec 2, 2019
williamFalcon (Contributor) commented

@S-aiueo32 yeah, this is a limitation of PyTorch. I've been looking at how to maintain state when using DP but there seems to be no clear way...

@pietern I think we talked about this a few months ago. Any suggestions on how to maintain state when using DP?

pietern commented Dec 4, 2019

DP replicates the source module for every call to forward. If you want to maintain state, you can't do this; instead, you should replicate once and then broadcast parameters and buffers from module[0] to the other replicas on every iteration. See torch/nn/parallel/{data_parallel,replicate}.py for more details. You'll see a section that broadcasts and sets the parameters/buffers; that is what still needs to be done every iteration. The part that runs _replicate_for_data_parallel is what you'd want to skip.
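
A minimal repro of the behavior described above (a sketch, assuming two visible GPUs; Stateful and cached are made-up names for illustration):

import torch
import torch.nn as nn

class Stateful(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)
        self.cached = None  # state we hope survives across calls

    def forward(self, x):
        out = self.lin(x)
        self.cached = out  # under DP this lands on a throwaway replica
        return out

net = Stateful().cuda()
dp = nn.DataParallel(net, device_ids=[0, 1])
dp(torch.randn(4, 2, device="cuda"))
print(net.cached)  # None: forward ran on fresh replicas, which were discarded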

S-aiueo32 (Contributor, Author) commented

@williamFalcon @pietern
Thank you for the detailed explanation.
I understand the limitation now, and that it is unavoidable as long as LightningModule inherits from nn.Module.

williamFalcon (Contributor) commented

Actually, it should be avoidable given the explanation above. We just need to make the appropriate changes to the DP subclass.

williamFalcon reopened this Dec 7, 2019
williamFalcon added the feature and help wanted labels and removed the bug label Dec 7, 2019
pietern commented Dec 19, 2019

This should be a companion class to nn.DataParallel. I don't want to change the behavior of the existing wrapper, because I'm sure folks depend on replicating the model on every call to forward. It shouldn't be too hard, though, and it can use nn.DataParallel as a starting point.
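
A rough sketch of what such a companion class could look like (not an existing PyTorch or Lightning class; it relies on the internal Broadcast function from torch.nn.parallel._functions and on how replicas store their parameters, both of which have changed across PyTorch versions, so treat it as an outline of the approach rather than a drop-in implementation):

import torch
from torch.nn.parallel import DataParallel
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel._functions import Broadcast  # internal API


class StatefulDataParallel(DataParallel):
    """Hypothetical companion to nn.DataParallel: replicate once, keep the
    replicas (and any attributes set on them), and only re-broadcast the
    parameters on later calls to forward."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._replicas = None

    def replicate(self, module, device_ids):
        if self._replicas is None:
            # First call: full replication, exactly as nn.DataParallel does.
            self._replicas = replicate(module, device_ids)
            return self._replicas
        # Later calls: skip _replicate_for_data_parallel and redo only the
        # broadcast-and-set step, so gradients still flow back to module.
        names = [n for n, _ in module.named_parameters()]
        params = list(module.parameters())
        copies = Broadcast.apply(device_ids, *params)  # grouped per device
        n = len(params)
        for i, replica in enumerate(self._replicas):
            for name, copy in zip(names, copies[i * n:(i + 1) * n]):
                # walk to the submodule owning this parameter and swap it in
                owner = replica
                *path, leaf = name.split(".")
                for part in path:
                    owner = getattr(owner, part)
                setattr(owner, leaf, copy)
        return self._replicas

Usage would mirror nn.DataParallel, e.g. model = StatefulDataParallel(net, device_ids=[0, 1]). Buffers are omitted for brevity; they would need the same per-iteration treatment (with detached copies, as replicate.py does).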

BradSegal commented

Just wanted to check if there was any update/advice on this type of issue? I've got a similar situation with a GAN producing images in the first optimizer iteration then using them to update the discriminator in the second. It works well on a single GPU, but when distributing I run into the same issue. I initially thought adding the property as a buffer would maintain it, but it seems to be flushed when using DP in the same way. Is the only solution to run the generator in the discriminator's optimizer iteration?
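
Until something like the companion class above exists, the usual workaround is indeed to regenerate the images inside each optimizer step, so that no state has to survive across replicas. A sketch based on the snippet from the first post (same hypothetical names; the cost is one extra generator forward pass per iteration):

def training_step(self, batch, batch_nb, optimizer_i):
    foo = batch['foo']
    bar = batch['bar']
    foo_out = self.netG(foo)  # recomputed every step; nothing stored on self

    if optimizer_i == 0:  # train discriminator
        # use foo_out.detach() here so d_loss does not backprop into the generator
        d_loss = ...
        return {'loss': d_loss}

    elif optimizer_i == 1:  # train generator
        g_loss = F.l1_loss(foo_out, bar)
        return {'loss': g_loss}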

stale bot commented May 9, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the won't fix label May 9, 2020
stale bot closed this as completed May 18, 2020
edenlightning changed the title from "DP flushes instance variables of LightningModules" to "State maintenance in DP" Nov 5, 2020
edenlightning added the strategy: dp (removed in pl) label and removed the won't fix label Nov 5, 2020
edenlightning reopened this Nov 5, 2020
stale bot commented Dec 5, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label Dec 5, 2020
stale bot closed this as completed Dec 12, 2020
edenlightning removed the won't fix label Feb 22, 2021
edenlightning added this to the 1.3 milestone Feb 22, 2021