
Incorrect learning rate when loading a checkpoint to continue training if check_val_every_n_epoch > 1 #20495

Open
razgzy opened this issue Dec 13, 2024 · 3 comments
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.4.x


razgzy commented Dec 13, 2024

### Bug description

I run training with LightningCLI and set check_val_every_n_epoch > 1 (e.g. 2) for an experiment with max_epochs=20. Checkpoints are saved by lightning.pytorch.callbacks.ModelCheckpoint. The learning-rate scheduler is torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=self.trainer.max_epochs, T_mult=1, eta_min=self.eta_min), stepped once per epoch. When I load a checkpoint (e.g. one saved at epoch 3) to continue training, the learning rate is updated one epoch earlier than expected, so the resumed schedule runs one epoch ahead of the original.
[Image: learning-rate curves of the original and resumed runs]
In the image, red is the original learning-rate curve and yellow is the resumed run. The learning rate is logged by lightning.pytorch.callbacks.LearningRateMonitor.
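For reference, PyTorch documents the CosineAnnealingWarmRestarts learning rate as below. With T_mult=1 and one scheduler step per epoch, the rate for 0-indexed epoch e is given by the first line, where η_0 is the initial learning rate (self.lr), η_min is self.eta_min and T_0 = max_epochs = 20. Assuming the value logged for epoch e is taken after e scheduler steps, "one epoch ahead" means the resumed run uses η(e+1) where η(e) is expected. This only restates the symptom; it is not a diagnosis.

```math
\begin{aligned}
\eta(e) &= \eta_{\min} + \tfrac{1}{2}\,(\eta_0 - \eta_{\min})\left(1 + \cos\left(\pi\,\frac{e \bmod T_0}{T_0}\right)\right), \qquad T_0 = 20,\\
\eta_{\text{resumed}}(e) &= \eta(e+1) \quad \text{for epochs } e \text{ after the resume point.}
\end{aligned}
```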

### What version are you seeing the problem on?

v2.4

### How to reproduce the bug

Trainer configuration:

```yaml
seed_everything: 0
trainer:
  accelerator: gpu
  strategy: auto
  devices:
  - 0
  num_nodes: 1
  precision: 32-true
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: runs
      name: name
      version: null
      log_graph: false
      default_hp_metric: true
      prefix: ''
      sub_dir: test
  callbacks:
  - class_path: utils.log_manager.LogManager
  - class_path: lightning.pytorch.callbacks.LearningRateMonitor
    init_args:
      logging_interval: epoch
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      dirpath: null
      filename: 'epoch={epoch}-psnr={hp_metric:.4f}'
      monitor: hp_metric
      verbose: false
      save_last: true
      save_top_k: 5
      save_weights_only: false
      mode: max
      auto_insert_metric_name: false
      every_n_train_steps: null
      train_time_interval: null
      every_n_epochs: 1
      save_on_train_epoch_end: null
  fast_dev_run: false
  max_epochs: 20
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 2
  num_sanity_val_steps: null
  log_every_n_steps: 20
  enable_checkpointing: null
  enable_progress_bar: null
  enable_model_summary: null
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  deterministic: null
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
```
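With max_epochs: 20 and check_val_every_n_epoch: 2 in this configuration, validation runs only on every second epoch, so a checkpoint that depends on hp_metric (presumably logged during validation) can only be written on those epochs. The sketch below lists them, assuming the rule is to validate at the end of 0-indexed epoch e when (e + 1) is a multiple of check_val_every_n_epoch. validation_epochs is a hypothetical helper, not a Lightning API.

```python
# Hypothetical helper (not part of Lightning). Assumes validation runs at the end
# of 0-indexed epoch e when (e + 1) is a multiple of check_val_every_n_epoch.
def validation_epochs(max_epochs: int, check_val_every_n_epoch: int) -> list[int]:
    """Return the 0-indexed epochs after which validation would run."""
    return [
        epoch
        for epoch in range(max_epochs)
        if (epoch + 1) % check_val_every_n_epoch == 0
    ]


if __name__ == "__main__":
    # Values taken from the trainer configuration.
    print(validation_epochs(max_epochs=20, check_val_every_n_epoch=2))
    # → [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
```

Under that assumption, the epoch-3 checkpoint from the bug description falls on a validation epoch, whereas with check_val_every_n_epoch=1 every epoch would be one.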


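```python
import torch


# Method of the LightningModule; self.network, self.lr, self.weight_decay
# and self.eta_min are defined elsewhere in the module.
def configure_optimizers(self):
    params = []
    params.append({'params': self.network.parameters(), 'lr': self.lr, 'weight_decay': self.weight_decay, 'name': 'network'})
    optimizer = torch.optim.AdamW(params)
    # One cosine cycle spanning all max_epochs (T_mult=1 keeps every period at T_0).
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=self.trainer.max_epochs, T_mult=1, eta_min=self.eta_min)
    lr_scheduler_config = {
        "scheduler": lr_scheduler,
        "interval": 'epoch',  # Lightning steps the scheduler once per training epoch
        "name": 'AdamW'       # name LearningRateMonitor uses when logging
    }
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
```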
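To make the reported shift concrete, here is a minimal stdlib-only sketch of per-epoch stepping with a save/restore in the middle. It is not PyTorch's or Lightning's code: ToyCosineWarmRestarts and lr_curve are hypothetical names, base_lr=1e-3 and eta_min=1e-6 are made-up values, and extra_steps models only the hypothesis that one additional step() happens on resume. A clean state round-trip reproduces the uninterrupted curve, while one extra step moves every later epoch one epoch ahead, the kind of shift described in the bug report.

```python
import math


class ToyCosineWarmRestarts:
    """Stdlib stand-in for CosineAnnealingWarmRestarts with T_mult=1.

    Models only per-epoch stepping; this is not PyTorch's implementation.
    """

    def __init__(self, base_lr: float, t_0: int, eta_min: float = 0.0):
        self.base_lr = base_lr
        self.t_0 = t_0
        self.eta_min = eta_min
        self.t_cur = 0  # epochs since the last (warm) restart

    def get_lr(self) -> float:
        cos_term = 1 + math.cos(math.pi * self.t_cur / self.t_0)
        return self.eta_min + (self.base_lr - self.eta_min) * cos_term / 2

    def step(self) -> None:
        # One call per epoch; wraps back to 0 after t_0 epochs (T_mult=1).
        self.t_cur = (self.t_cur + 1) % self.t_0

    def state_dict(self) -> dict:
        return {"t_cur": self.t_cur}

    def load_state_dict(self, state: dict) -> None:
        self.t_cur = state["t_cur"]


def lr_curve(max_epochs, base_lr, eta_min, resume_at=None, extra_steps=0):
    """Learning rate used in each epoch.

    If resume_at is given, the scheduler state is saved before that epoch and
    restored into a fresh scheduler, as when continuing from a checkpoint.
    extra_steps models a hypothetical spurious step() right after restoring.
    """
    sched = ToyCosineWarmRestarts(base_lr, t_0=max_epochs, eta_min=eta_min)
    lrs = []
    for epoch in range(max_epochs):
        if epoch == resume_at:
            saved = sched.state_dict()
            sched = ToyCosineWarmRestarts(base_lr, t_0=max_epochs, eta_min=eta_min)
            sched.load_state_dict(saved)
            for _ in range(extra_steps):
                sched.step()
        lrs.append(sched.get_lr())
        sched.step()  # end-of-epoch scheduler step
    return lrs


if __name__ == "__main__":
    # resume_at=4 corresponds to continuing after a checkpoint saved at epoch 3.
    baseline = lr_curve(20, base_lr=1e-3, eta_min=1e-6)
    resumed = lr_curve(20, base_lr=1e-3, eta_min=1e-6, resume_at=4)
    shifted = lr_curve(20, base_lr=1e-3, eta_min=1e-6, resume_at=4, extra_steps=1)
    print(resumed == baseline)            # → True
    print(shifted[4:-1] == baseline[5:])  # → True: one epoch ahead after resuming
```

Comparing a logged LearningRateMonitor curve against lr_curve's baseline (with the run's real base_lr and eta_min) is one way to check whether a resumed run is off by exactly one step.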


### Error messages and logs




### Environment

<details>
  <summary>Current environment</summary>

#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(conda, pip, source):


</details>


### More info

_No response_
razgzy added the bug and needs triage labels on Dec 13, 2024
razgzy (Author) commented Dec 13, 2024

I found that, when loading from a checkpoint, self._restarting in self.fit_loop.run() is False if check_val_every_n_epoch=1, but True if check_val_every_n_epoch>1.

lantiga (Collaborator) commented Dec 17, 2024

Thanks for filing the issue, @razgzy.

Can you verify whether this also happens on the latest master?

Also, it would be great if you could share a small repro that works end to end, so it's quicker for me to run and debug.

razgzy (Author) commented Dec 19, 2024
