How to specify which particular networks/weights to save when training GANs #3022
Hi! Thanks for your contribution, great first issue!
You can try
Hey, thanks for the quick answer. I checked it out, and while that is helpful, if I understand correctly it only helps me add things while saving a checkpoint, whereas I'm wondering if I can remove certain parts of it, like, for example, not saving the discriminator while training a GAN. Helpful hint either way, so thanks!
the checkpoint is a plain dict, so in that same hook you can also remove the entries you don't want, not just add new ones
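For example, something like this (a rough sketch, assuming the discriminator is stored as `self.discriminator` on the LightningModule):

```python
import pytorch_lightning as pl


class GAN(pl.LightningModule):
    # hypothetical module with self.generator / self.discriminator submodules
    def on_save_checkpoint(self, checkpoint):
        # `checkpoint` is a plain dict; checkpoint["state_dict"] maps
        # parameter names (e.g. "discriminator.net.weight") to tensors,
        # so unwanted entries can simply be deleted before it is written.
        state = checkpoint["state_dict"]
        for key in list(state):
            if key.startswith("discriminator."):
                del state[key]
```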
Oh damn, never thought of that. That would work, thanks! I'll be closing the issue now. Thank you for your help!
this link points to nowhere, any update?
❓ Questions and Help
What is your question?
If my model consists of several independent modules, how do I save only specific modules during training? For instance, I have the following two models:
Code
Generator, Discriminator and PerceptionNet are ordinary PyTorch nn.Module classes where I've defined my network architectures.
The PreTrainGen class pre-trains my generator (it's just an experimental approach I'm trying). The PerceptionNet model is basically just a frozen VGG_19 graph that I'm using to calculate a content loss between the real and fake images.
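Roughly, the setup looks like this (a sketch only; the real architectures are replaced with placeholder stand-ins):

```python
import torch.nn as nn
import pytorch_lightning as pl


class Generator(nn.Module):
    # placeholder stand-in for the real generator architecture
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 3, 3, padding=1)


class Discriminator(nn.Module):
    # placeholder stand-in for the real discriminator architecture
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 1, 3, padding=1)


class PerceptionNet(nn.Module):
    # a frozen VGG_19 feature extractor in the real code
    def __init__(self):
        super().__init__()
        self.features = nn.Conv2d(3, 64, 3, padding=1)


class PreTrainGen(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = Generator()
        self.perception_net = PerceptionNet().eval()  # used only for the content loss
        for p in self.perception_net.parameters():
            p.requires_grad = False


class GAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = Generator()
        self.discriminator = Discriminator()
```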
I first want to train PreTrainGen, which only updates the weights of the generator, since the PerceptionNet (VGG) graph is frozen. However, when I save it, the checkpoint state_dict also contains the weights of the frozen VGG model, and hence I cannot load it directly into the Generator class while training the GAN, since the model state_dict does not match the checkpoint state_dict.
I'm assuming that even for the GAN training, the checkpoint will save all the weights, including the discriminator's, which is not what I want for inference.
What have you tried?
I have considered the obvious approach of not including PerceptionNet inside the LightningModule definition, since it's not part of the computation graph anyway. While this admittedly solves the issue for pretraining, the checkpoint saved for the GAN will still have the discriminator weights coupled with the generator weights.
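Concretely, that might look something like this (a sketch reusing the placeholder classes above; it relies on the fact that nn.Module only registers submodules assigned as direct attributes, so a net held in a plain Python list never shows up in state_dict()):

```python
import torch.nn.functional as F
import pytorch_lightning as pl


class PreTrainGen(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = Generator()
        # A plain list is invisible to nn.Module registration, so the
        # frozen net never appears in state_dict() or in the checkpoint.
        # Caveat: .to(device) won't move it automatically either.
        self._perception = [PerceptionNet().eval()]

    def content_loss(self, fake, real):
        vgg = self._perception[0]
        return F.l1_loss(vgg(fake), vgg(real))
```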
I know inference can still be done if the forward method of the LightningModule is written correctly, but I would much prefer a cleaner way to specify exactly which of the models I want to save when training GANs, or any other model that has multiple independent modules inside.
I have also considered filtering out the parts of the checkpoint state_dict that correspond to the VGG net and only loading the netG weights, but that approach seems even clunkier.
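For reference, that filtering would look roughly like this (assuming the generator's parameters were saved under a `generator.` prefix in the checkpoint; the path and class are placeholders):

```python
import torch

ckpt = torch.load("pretrain.ckpt", map_location="cpu")

# Keep only the generator entries and strip the module prefix so the
# keys line up with a bare Generator's own state_dict.
prefix = "generator."
gen_state = {
    k[len(prefix):]: v
    for k, v in ckpt["state_dict"].items()
    if k.startswith(prefix)
}

net_g = Generator()  # placeholder class from the sketch above
net_g.load_state_dict(gen_state)
```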
Any help would be greatly appreciated. I'm still new to PyTorch Lightning, but I've scoured the docs and examples and haven't found anything that answers my question yet.