How to specify which particular networks/weights to save when training GANs #3022

Closed
kushalchordiya216 opened this issue Aug 17, 2020 · 6 comments
Labels
question Further information is requested

Comments

@kushalchordiya216

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

If my model consists of several independent modules, how do I save only specific modules during training? For instance, I have the following two models:

Code

import pytorch_lightning as pl
import torch.nn as nn


class PreTrainGenModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.netG = Generator()       # the network being pre-trained
        self.VGG = PerceptionNet()    # frozen VGG-19 used for the content loss


class GAN(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.netG: nn.Module = Generator()
        self.netD: nn.Module = Discriminator()
        self.perceptual = PerceptionNet()

Generator, Discriminator, and PerceptionNet are ordinary PyTorch nn.Module classes where I've defined my network architectures.
The PreTrainGenModel class pre-trains my generator (it's just an experimental approach I'm trying). The PerceptionNet model is basically just a frozen VGG-19 graph that I'm using to calculate a content loss between the real and fake images.

I first want to train PreTrainGenModel, which only updates the generator's weights, since the PerceptionNet (VGG) graph is frozen. However, when I save it, the checkpoint state_dict also contains the weights of the frozen VGG model, so I cannot load it directly into the Generator class while training the GAN: the model state_dict does not match the checkpoint state_dict.
I'm assuming that even for the GAN training the checkpoint will save all the weights, including the discriminator's, which is not what I want for inference.

What have you tried?

I have considered the obvious approach of not including the PerceptionNet inside the LightningModule definition, since it's not part of the computation graph anyway. While this admittedly solves the issue for pretraining, the saved checkpoint for the GAN will still have the discriminator weights coupled with the generator weights.
I know inference can still be done if the forward method of the LightningModule is written correctly, but I would much prefer a cleaner way to specify exactly which of the models I want to save during training of GANs, or of any other model that has multiple independent modules inside.

I have also considered filtering out the parts of the checkpoint state_dict that correspond to the VGG net and loading only the netG weights, but that approach seems even clunkier (see the sketch below).
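For concreteness, a rough sketch of that load-side filtering, assuming the attribute names from the code above (the checkpoint filename here is just a placeholder):

import torch

# Load the pretraining checkpoint (placeholder path) onto the CPU.
ckpt = torch.load('pretrain_gen.ckpt', map_location='cpu')

# Keep only the generator weights and strip the 'netG.' prefix
# so the keys match a bare Generator's state_dict.
gen_state = {
    k[len('netG.'):]: v
    for k, v in ckpt['state_dict'].items()
    if k.startswith('netG.')
}

generator = Generator()
generator.load_state_dict(gen_state)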
Any help would be greatly appreciated. I'm still new to PyTorch Lightning, but I've scoured the docs and examples and haven't found anything that answers my question yet.

What's your environment?

  • OS: Linux
  • Packaging: pip
  • Version: 0.9.0.rc15
kushalchordiya216 added the question label Aug 17, 2020
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@rohitgr7
Contributor

You can try on_save_checkpoint. The checkpoint is a dict that contains the model state_dict under the 'state_dict' key. You can alter it there and reassign it to the checkpoint.
https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#on-save-checkpoint
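A minimal sketch of that hook in the pretraining module, assuming the attribute names from the code above (the frozen VGG lives under self.VGG, so its parameter keys start with 'VGG.'):

class PreTrainGenModel(pl.LightningModule):
    ...
    def on_save_checkpoint(self, checkpoint):
        # Drop the frozen VGG weights so the checkpoint only holds netG.
        checkpoint['state_dict'] = {
            k: v for k, v in checkpoint['state_dict'].items()
            if not k.startswith('VGG.')
        }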

@kushalchordiya216
Author

kushalchordiya216 commented Aug 18, 2020

> You can try on_save_checkpoint. The checkpoint is a dict that contains the model state_dict under the 'state_dict' key. You can alter it there and reassign it to the checkpoint.
> https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#on-save-checkpoint

Hey, thanks for the quick answer. I checked it out, and while it is helpful, if I understand correctly it only lets me add things while saving a checkpoint, whereas I'm wondering if I can remove certain parts of it, for example not save the discriminator while training a GAN. Helpful hint either way, so thanks.

@rohitgr7
Contributor

The checkpoint is a dict with a state_dict. You just rip the discriminator out of the state_dict and assign the result back with checkpoint['state_dict'] = new_weights.
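The same trick sketched for the GAN module, under the same naming assumptions as before, this time keeping only the generator keys:

class GAN(pl.LightningModule):
    ...
    def on_save_checkpoint(self, checkpoint):
        # Keep only netG; the discriminator and the frozen perceptual
        # net are dropped from the saved checkpoint.
        checkpoint['state_dict'] = {
            k: v for k, v in checkpoint['state_dict'].items()
            if k.startswith('netG.')
        }

Note that loading such a checkpoint back into a full GAN module would then need load_state_dict with strict=False (or loading into a bare Generator), since the netD and perceptual keys are no longer present.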

@kushalchordiya216
Author

Oh damn, never thought of that. That would work, thanks! I'll be closing the issue now. Thank you for your help.

@etienne87

This link points to nowhere, any update?
