FSDP state_dict_type 'sharded' not working with ModelCheckpoint(save_weights_only=True) #19492

Closed
dimitri-voytan opened this issue Feb 16, 2024 · 2 comments · Fixed by #19524
Labels
bug, checkpointing, strategy: fsdp
Comments

@dimitri-voytan
Contributor

dimitri-voytan commented Feb 16, 2024

Bug description

When training with FSDPStrategy(state_dict_type='sharded') and a checkpoint callback that saves only the weights, e.g. checkpoint_callback = ModelCheckpoint(dirpath='.', save_weights_only=True), the run fails with

lightning/pytorch/strategies/fsdp.py", line 577, in save_checkpoint
{f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))}
KeyError: 'optimizer_states'

The fix could be very simple: in lightning/pytorch/strategies/fsdp.py, check for "optimizer_states" before popping it. With save_weights_only=True the checkpoint dict contains no optimizer states, so the unconditional pop raises the KeyError.

Currently we have (from line 574):

converted_state = {"model": checkpoint.pop("state_dict")}
converted_state.update(
    {f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))}
)

With the correction:

converted_state = {"model": checkpoint.pop("state_dict")}
if "optimizer_states" in checkpoint.keys: #check for optimizer_states
    converted_state.update(
          {f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))}
     )

Happy to submit a PR if there are no problems with this proposed solution, and I can add unit tests etc. (a rough sketch of such a test is below).
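
For reference, a regression test could look roughly like the following. This is only a sketch assuming a two-GPU run and the BoringModel demo module; the test name and exact Trainer settings are illustrative, not the ones that will land in the PR.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import FSDPStrategy


def test_fsdp_sharded_state_dict_save_weights_only(tmp_path):
    # Reproduces the failing combination: sharded state dict + weights-only checkpointing
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="gpu",
        devices=2,
        strategy=FSDPStrategy(state_dict_type="sharded"),
        callbacks=ModelCheckpoint(dirpath=tmp_path, save_weights_only=True),
        max_steps=2,
        logger=False,
        enable_progress_bar=False,
    )
    # Should complete without raising KeyError: 'optimizer_states'
    trainer.fit(model)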

What version are you seeing the problem on?

v2.2

How to reproduce the bug

from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.trainer.trainer import Trainer
from lightning import LightningModule
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch import nn
import torch
from torchvision.datasets import MNIST
import os
from torchvision import transforms
from lightning.pytorch.strategies import FSDPStrategy

# Dummy example from the docs
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)
    
    
class LitAutoEncoder(LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# define dataset
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

# Setup model checkpointing. Save only the weights
checkpoint_callback = ModelCheckpoint(dirpath='.', save_weights_only=True)

model = LitAutoEncoder(encoder=Encoder(), decoder=Decoder())

# Setup FSDP training with sharded state_dict_type
strategy = FSDPStrategy(state_dict_type='sharded')

# Create trainer. Only use a little bit of data
trainer = Trainer(strategy=strategy, callbacks=checkpoint_callback,
                  limit_train_batches=0.05)

trainer.fit(model, train_dataloaders=train_loader)

Error messages and logs

/lightning/pytorch/strategies/fsdp.py", line 577, in save_checkpoint
{f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))}
KeyError: 'optimizer_states'

Environment

Current environment
- Lightning Component: FSDPStrategy, ModelCheckpoint
- PyTorch Lightning Version: 2.2.0.post0
- PyTorch Version: 2.2.0+cu118
- Python version: 3.10.12
- OS: Linux
- CUDA/cuDNN version: 11.8
- GPU models and configuration: 3x NVIDIA A100-PCIE-40GB
- How you installed Lightning (conda, pip, source): pip
- Running environment of LightningApp (e.g. local, cloud): local

More info

No response

cc @awaelchli @carmocca

@dimitri-voytan added the bug and needs triage labels on Feb 16, 2024
@carmocca added the checkpointing and strategy: fsdp labels and removed the needs triage label on Feb 17, 2024
@awaelchli
Contributor

@dimitri-voytan Thanks. An edge case we haven't considered. Please feel free to submit a PR if you like. I think the fix could even be simpler by just doing checkpoint.pop("optimizer_states", []) instead of checkpoint.pop("optimizer_states").
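
For reference, that one-line change would make the block above look roughly like this (a sketch of the suggestion, not necessarily the final patch in #19524):

converted_state = {"model": checkpoint.pop("state_dict")}
converted_state.update(
    # Default to an empty list so weights-only checkpoints (which have no
    # "optimizer_states" key) simply add no optimizer entries instead of raising.
    {f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))}
)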

@awaelchli awaelchli added this to the 2.2.x milestone Feb 23, 2024
@dimitri-voytan
Contributor Author

Thanks, working on it now, will submit tomorrow.
