-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into carmocca/mypy-1.0
- Loading branch information
Showing
62 changed files
with
654 additions
and
259 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
:orphan: | ||
|
||
############################## | ||
Multiple Models and Optimizers | ||
############################## | ||
|
||
Fabric makes it very easy to work with multiple models and/or optimizers at once in your training workflow. | ||
Examples of where this comes in handy are Generative Adversarial Networks (GANs), Auto-encoders, meta-learning and more. | ||
|
||
|
||
---- | ||
|
||
************************ | ||
One model, one optimizer | ||
************************ | ||
|
||
Fabric has a simple guideline you should follow: | ||
If you have an optimizer, you should set it up together with the model to make your code truly strategy-agnostic. | ||
|
||
.. code-block:: python | ||
import torch | ||
from lightning.fabric import Fabric | ||
fabric = Fabric() | ||
# Instantiate model and optimizer | ||
model = LitModel() | ||
optimizer = torch.optim.Adam(model.parameters()) | ||
# Set up the model and optimizer together | ||
model, optimizer = fabric.setup(model, optimizer) | ||
Depending on the selected strategy, the :meth:`~lightning.fabric.fabric.Fabric.setup` method will wrap and link the model with the optimizer. | ||
|
||
|
||
---- | ||
|
||
|
||
****************************** | ||
One model, multiple optimizers | ||
****************************** | ||
|
||
You can also have multiple optimizers over a single model. | ||
This is useful if you need specific optimizers or learning rates for parts of the model. | ||
|
||
.. code-block:: python | ||
# Instantiate model and optimizers | ||
model = LitModel() | ||
optimizer1 = torch.optim.SGD(model.layer1.parameters(), lr=0.003) | ||
optimizer2 = torch.optim.SGD(model.layer2.parameters(), lr=0.01) | ||
# Set up the model and optimizers together | ||
model, optimizer1, optimizer2 = fabric.setup(model, optimizer1, optimizer2) | ||
---- | ||
|
||
|
||
****************************** | ||
Multiple models, one optimizer | ||
****************************** | ||
|
||
Using a single optimizer to update multiple models is possible too. | ||
The best way to do this is to group all your individual models under one top level ``nn.Module``: | ||
|
||
.. code-block:: python | ||
class AutoEncoder(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
# Group all models under a common nn.Module | ||
self.encoder = Encoder() | ||
self.decoder = Decoder() | ||
Now all of these models can be treated as a single one: | ||
|
||
.. code-block:: python | ||
# Instantiate the big model | ||
autoencoder = AutoEncoder() | ||
optimizer = ... | ||
# Set up the model(s) and optimizer together | ||
autoencoder, optimizer = fabric.setup(autoencoder, optimizer) | ||
---- | ||
|
||
|
||
************************************ | ||
Multiple models, multiple optimizers | ||
************************************ | ||
|
||
You can pair up as many models and optimizers as you want. For example, two models with one optimizer each: | ||
|
||
.. code-block:: python | ||
# Two models | ||
generator = Generator() | ||
discriminator = Discriminator() | ||
# Two optimizers | ||
optimizer_gen = torch.optim.SGD(generator.parameters(), lr=0.01) | ||
optimizer_dis = torch.optim.SGD(discriminator.parameters(), lr=0.001) | ||
# Set up generator | ||
generator, optimizer_gen = fabric.setup(generator, optimizer_gen) | ||
# Set up discriminator | ||
discriminator, optimizer_dis = fabric.setup(discriminator, optimizer_dis) | ||
For a full example of this use case, see our `GAN example <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/dcgan>`_. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.