Identical best monitored metric values across different depths can result in depth-aligned checkpoint metadata corruption #15

Closed
CyprienRicque opened this issue Aug 27, 2024 · 3 comments
Labels
bug Something isn't working

CyprienRicque commented Aug 27, 2024

🐛 Bug

While fitting, if the current level l does not generate any new checkpoints, the trainer moves on to the next level l+1 and reloads the last best model saved, likely from level l-1.

Upon further testing, it appears to be caused by save_last being set to True.

This loading fails with the error:

File /home/sharing/.../site-packages/lightning/pytorch/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File /home/sharing/.../site-packages/lightning/pytorch/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45     if trainer.strategy.launcher is not None:
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:
     50     _call_teardown_hook(trainer)

File /home/sharing/.../site-packages/lightning/pytorch/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    567 assert self.state.fn is not None
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped
    577 self.training = False

File /home/sharing/.../site-packages/lightning/pytorch/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
    976 self._signal_connector.register_signal_handlers()
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------
    986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /home/sharing/.../site-packages/lightning/pytorch/trainer/trainer.py:1025, in Trainer._run_stage(self)
   1023         self._run_sanity_check()
   1024     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025         self.fit_loop.run()
   1026     return None
   1027 raise RuntimeError(f"Unexpected state {self.state}")

File /home/sharing/.../site-packages/lightning/pytorch/loops/fit_loop.py:204, in _FitLoop.run(self)
    202 while not self.done:
    203     try:
--> 204         self.on_advance_start()
    205         self.advance()
    206         self.on_advance_end()

File /home/sharing/.../site-packages/lightning/pytorch/loops/fit_loop.py:345, in _FitLoop.on_advance_start(self)
    341     _set_sampler_epoch(dl, self.epoch_progress.current.processed)
    343 self.epoch_progress.increment_ready()
--> 345 call._call_callback_hooks(trainer, "on_train_epoch_start")
    346 call._call_lightning_module_hook(trainer, "on_train_epoch_start")
    348 self.epoch_progress.increment_started()

File /home/sharing/.../site-packages/lightning/pytorch/trainer/call.py:218, in _call_callback_hooks(trainer, hook_name, monitoring_callbacks, *args, **kwargs)
    216     if callable(fn):
    217         with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
--> 218             fn(trainer, trainer.lightning_module, *args, **kwargs)
    220 if pl_module:
    221     # restore current_fx when nested context
    222     pl_module._current_fx_name = prev_fx_name

File /home/sharing/.../site-packages/finetuning_scheduler/fts.py:815, in FinetuningScheduler.on_train_epoch_start(self, trainer, pl_module)
    813 if self.should_transition(trainer):
    814     self._fts_state._curr_depth += 1  # increment depth
--> 815     self.step()
    816     rank_zero_debug(
    817         f"Current depth is {self.curr_depth}."
    818         "\nCurrent logical parameters thawed by Fine-Tuning Scheduler:\n "
   (...)
    821         f"{pformat(self._fts_state._curr_thawed_params)}. "
    822     )
    823     if not self.epoch_transitions_only:

File /home/sharing/.../site-packages/finetuning_scheduler/fts.py:340, in FinetuningScheduler.step(self)
    338 if not self._fts_state._resume_fit_from_ckpt:
    339     if self.restore_best:
--> 340         self.restore_best_ckpt()
    341         self.step_pg(
    342             depth=self.curr_depth,
    343             optimizer=self.pl_module.trainer.optimizers[0],  # type: ignore[arg-type]
    344             pre_reinit_state=pre_reinit_state,
    345         )
    346     else:

File /home/sharing/.../site-packages/finetuning_scheduler/fts.py:510, in FinetuningScheduler.restore_best_ckpt(self)
    507 self.pl_module.trainer._checkpoint_connector.restore_model()
    508 # we need to override checkpoint_connector.restore_training_state() to bypass loop restoration
    509 # if additional customizations are required, may make sense to subclass _CheckpointConnector at some point
--> 510 self._restore_training_state()
    511 self.pl_module.trainer._checkpoint_connector.resume_end()

File /home/sharing/.../site-packages/finetuning_scheduler/fts.py:533, in FinetuningScheduler._restore_training_state(self)
    531     self.strategy_adapter.on_before_restore_optimizers_and_lrs()
    532     # restore optimizers and schedulers state
--> 533     checkpoint_connector.restore_optimizers_and_schedulers()
    534 except KeyError as ke:
    535     self._maybe_allow_incompatible_reinit_ckpt(ke)

File /home/sharing/.../site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:368, in _CheckpointConnector.restore_optimizers_and_schedulers(self)
    363     if "optimizer_states" not in self._loaded_checkpoint:
    364         raise KeyError(
    365             "Trying to restore optimizer state but checkpoint contains only the model."
    366             " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
    367         )
--> 368     self.restore_optimizers()
    370 if "lr_schedulers" not in self._loaded_checkpoint:
    371     raise KeyError(
    372         "Trying to restore learning rate scheduler state but checkpoint contains only the model."
    373         " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
    374     )

File /home/sharing/.../site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:383, in _CheckpointConnector.restore_optimizers(self)
    380     return
    382 # restore the optimizers
--> 383 self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)

File /home/sharing/.../site-packages/lightning/pytorch/strategies/strategy.py:376, in Strategy.load_optimizer_state_dict(self, checkpoint)
    374 optimizer_states = checkpoint["optimizer_states"]
    375 for optimizer, opt_state in zip(self.optimizers, optimizer_states):
--> 376     optimizer.load_state_dict(opt_state)
    377     _optimizer_to_device(optimizer, self.root_device)

File /home/sharing/.../site-packages/torch/_compile.py:31, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     28     disable_fn = torch._dynamo.disable(fn, recursive)
     29     fn.__dynamo_disable = disable_fn
---> 31 return disable_fn(*args, **kwargs)

File /home/sharing/.../site-packages/torch/_dynamo/eval_frame.py:600, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    598 prior = set_eval_frame(callback)
    599 try:
--> 600     return fn(*args, **kwargs)
    601 finally:
    602     set_eval_frame(prior)

File /home/sharing/.../site-packages/torch/optim/optimizer.py:848, in Optimizer.load_state_dict(self, state_dict)
    845 saved_groups = deepcopy(state_dict["param_groups"])
    847 if len(groups) != len(saved_groups):
--> 848     raise ValueError(
    849         f"loaded state dict has a different number of " "parameter groups"
    850     )
    851 param_lens = (len(g["params"]) for g in groups)
    852 saved_lens = (len(g["params"]) for g in saved_groups)

ValueError: loaded state dict has a different number of parameter groups

To Reproduce

import uuid

import lightning as L
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from finetuning_scheduler import FinetuningScheduler, FTSCheckpoint, FTSEarlyStopping


class BoringModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # 22 linear layers: 32 -> 64, twenty 64 -> 64 layers, then 64 -> 10
        layers = [nn.Linear(32, 64)]
        for _ in range(20):
            layers.append(nn.Linear(64, 64))
        layers.append(nn.Linear(64, 10))
        self.layers = nn.ModuleList(layers)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        # deliberately constant: every depth reports an identical best val_loss
        self.log('val_loss', 1)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]


# random classification data: 800 training and 200 validation samples
x = torch.randn(800, 32)
y = torch.randint(0, 10, (800,))
train_dataset = TensorDataset(x, y)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

x_val = torch.randn(200, 32)
y_val = torch.randint(0, 10, (200,))
val_dataset = TensorDataset(x_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

model = BoringModel()

# unique checkpoint directory per run (avoids shadowing the builtin `id`)
run_id = str(uuid.uuid4())

checkpoint = FTSCheckpoint(monitor="val_loss", mode="min", save_last=True, save_top_k=1, verbose=False, dirpath=f"/tmp/{run_id}")
early_stop_callback = FTSEarlyStopping(monitor="val_loss", mode="min", patience=0, verbose=False)
ft_scheduler = FinetuningScheduler(max_depth=-1)

trainer = L.Trainer(max_epochs=30, callbacks=[early_stop_callback, checkpoint, ft_scheduler], accelerator="cpu")
trainer.fit(model, train_loader, val_loader)

Expected behavior

The checkpoint should simply reload without error.

Environment

  • Fine-Tuning Scheduler version: 2.4.0
  • Lightning version: 2.4.0
  • PyTorch version: 2.4.0
  • Python version: 3.12
  • OS: Ubuntu
  • CUDA/cuDNN version: 12.0
  • How PyTorch was installed (conda, pip, source): conda
CyprienRicque added the bug label on Aug 27, 2024
CyprienRicque changed the title from "Optimizer state reload fails if loads is from more than 1 level before" to "Optimizer state reload fails if load is from more than 1 level before" on Aug 27, 2024
@speediedan (Owner)

Thanks for noticing the issue and providing the repro @CyprienRicque!

Root Cause

To minimize the amount of parent callback code (ModelCheckpoint in this case) that has to be overridden, Fine-Tuning Scheduler (FTS) and FTSCheckpoint previously checked whether current_score == best_model_score and, if so, updated their state just-in-time before returning from state_dict calls. In the edge case where current_score precisely equals best_model_score at different depths, best_ckpt_depth and best_ckpt_pgs were set to the latest (deepest) depth with that score.

Unfortunately, in this edge case (and irrespective of save_last=True), the depth-aligned checkpoint metadata (best_ckpt_depth and best_ckpt_pgs) can be updated even when best_model_path is not. ModelCheckpoint does not explicitly resolve which best_model_path wins when several identical best scores exist; I believe it is determined by insertion order.
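
The mismatch can be illustrated with a small, self-contained simulation. This is a hypothetical sketch, not FTS or Lightning code: ToyCheckpoint stands in for a save_top_k=1, mode="min" checkpoint callback and assumes the best path only changes on a strictly better score, while best_ckpt_depth is updated via the old current_score == best_model_score proxy.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyCheckpoint:
    """Hypothetical stand-in for a save_top_k=1, mode="min" checkpoint callback."""

    best_model_score: Optional[float] = None
    best_model_path: Optional[str] = None
    # depth-aligned metadata that should always describe best_model_path
    best_ckpt_depth: Optional[int] = None

    def save(self, depth: int, epoch: int, current_score: float) -> None:
        path = f"depth{depth}-epoch{epoch}.ckpt"
        # assumption: the best path only changes on a strictly better score
        if self.best_model_score is None or current_score < self.best_model_score:
            self.best_model_score = current_score
            self.best_model_path = path
        # flawed proxy: score equality is treated as "this is the best checkpoint"
        if current_score == self.best_model_score:
            self.best_ckpt_depth = depth


ckpt = ToyCheckpoint()
# a constant val_loss yields an identical score at every depth
for depth, epoch in [(0, 0), (1, 1), (2, 2)]:
    ckpt.save(depth, epoch, current_score=1.0)

print(ckpt.best_model_path)  # depth0-epoch0.ckpt
print(ckpt.best_ckpt_depth)  # 2: the metadata now describes a different checkpoint
```

Restoring best_model_path while configuring parameter groups for best_ckpt_depth would combine a checkpoint from one depth with metadata from another, which is consistent with the parameter-group count mismatch in the traceback above.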

Resolution

To address these edge cases, I've just pushed a commit that:

  1. Adds a context manager (_depth_metadata_lock) that conditions just-in-time mutability of depth-aligned checkpoint metadata
  2. Tracks best_model_path changes rather than relying upon the imperfect current_score == best_model_score proxy.
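
Both ideas can be sketched roughly as follows. This is a hypothetical illustration only: apart from the _depth_metadata_lock name mentioned in the list, the class, attributes and logic are assumptions and are not the code in the actual commit. Depth metadata may only be mutated inside the lock context, and it is updated only when best_model_path actually changes.

```python
from contextlib import contextmanager
from typing import Iterator, Optional


class ToyFTSCheckpoint:
    """Hypothetical sketch: depth metadata follows best_model_path, not score equality."""

    def __init__(self) -> None:
        self.best_model_score: Optional[float] = None
        self.best_model_path: Optional[str] = None
        self.best_ckpt_depth: Optional[int] = None
        self._metadata_mutable = False

    @contextmanager
    def _depth_metadata_lock(self) -> Iterator[None]:
        # temporarily allow just-in-time updates of depth-aligned metadata
        self._metadata_mutable = True
        try:
            yield
        finally:
            self._metadata_mutable = False

    def _maybe_update_depth_metadata(self, prev_path: Optional[str], depth: int) -> None:
        # update only when mutation is allowed AND the best checkpoint really changed
        if self._metadata_mutable and self.best_model_path != prev_path:
            self.best_ckpt_depth = depth

    def save(self, depth: int, epoch: int, current_score: float) -> None:
        prev_path = self.best_model_path
        # assumption: the best path only changes on a strictly better score
        if self.best_model_score is None or current_score < self.best_model_score:
            self.best_model_score = current_score
            self.best_model_path = f"depth{depth}-epoch{epoch}.ckpt"
        with self._depth_metadata_lock():
            self._maybe_update_depth_metadata(prev_path, depth)


ckpt = ToyFTSCheckpoint()
for depth, epoch in [(0, 0), (1, 1), (2, 2)]:
    ckpt.save(depth, epoch, current_score=1.0)
print(ckpt.best_model_path, ckpt.best_ckpt_depth)  # depth0-epoch0.ckpt 0
```

Because identical scores no longer move best_ckpt_depth, the metadata stays aligned with whichever checkpoint best_model_path actually points to.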

I've also added a few tests to avoid similar issues/regressions in the future.

I hope it's okay that I've slightly renamed the issue title to reflect the identified bug scope.

This fix should be available with the next FTS patch release (2.4.1).

Thanks again for identifying this issue and taking the time to report it with a good repro. You've helped improve FTS for everyone!

Feel free to reach out anytime if you have other issues or want to share more about your use case. Best of luck with your work!

speediedan changed the title from "Optimizer state reload fails if load is from more than 1 level before" to "Identical best monitored metric values across different depths can result in depth-aligned checkpoint metadata corruption" on Aug 28, 2024
@speediedan (Owner)

Fixed with 0ef51de.

@CyprienRicque (Author)

Great, thank you! 😊
