
Unable to resume from checkpoint when using apex #11488

Closed
dtmoodie opened this issue Jan 15, 2022 · 10 comments · Fixed by #14341

@dtmoodie

dtmoodie commented Jan 15, 2022

🐛 Bug

When trying to resume a model that was trained with apex, I cannot load the checkpoint.

To Reproduce

Train a model with trainer.fit using the following params:

precision: 16
amp_level: O2
amp_backend: apex

Then attempt to continue training using the checkpoint and ckpt_path.
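A minimal sketch of this setup (the LightningModule, DataModule, and checkpoint path below are placeholders; only the precision settings are the ones from this report):

```python
import pytorch_lightning as pl

# Placeholders for the actual model and data used in the report.
model = MyLightningModule()
dataset = MyDataModule()

trainer = pl.Trainer(
    gpus=1,
    precision=16,
    amp_backend="apex",  # NVIDIA/apex instead of native torch.cuda.amp
    amp_level="O2",      # opt_level passed through to amp.initialize
)

# First run: train and save checkpoints as usual.
trainer.fit(model, dataset)

# Second run: resuming from the saved checkpoint fails with the error below.
trainer.fit(model, dataset, ckpt_path="path/to/last.ckpt")
```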
The error that I get is:

File "/home/dan/code/ml/yolo/train_lightning.py", line 299, in <module>
    trainer.fit(model, dataset, ckpt_path=manager.args.resume)
  File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
    self._call_and_handle_interrupt(
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
   return trainer_fn(*args, **kwargs)
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
   self._run(model, ckpt_path=ckpt_path)
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1189, in _run
   self.checkpoint_connector.restore_training_state()
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 185, in restore_training_state
   self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/apex_amp.py", line 97, in on_load_checkpoint
   amp.load_state_dict(checkpoint["amp_scaling_state"])
 File "/opt/conda/lib/python3.8/site-packages/apex/amp/frontend.py", line 375, in load_state_dict
   if len(state_dict) != len(_amp_state.loss_scalers):

AttributeError: 'AmpState' object has no attribute 'loss_scalers'

Expected behavior

Training resumes as it would without apex.

Environment

Please copy and paste the output from our environment collection script:

  • CUDA:
    - GPU:
      - NVIDIA GeForce RTX 3090
    - available: True
    - version: 11.4
  • Packages:
    - numpy: 1.21.2
    - pyTorch_debug: False
    - pyTorch_version: 1.10.0a0+0aef44c
    - pytorch-lightning: 1.5.6
    - tqdm: 4.62.3
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.8.12
    - version: #44~20.04.2-Ubuntu SMP Tue Oct 26 18:07:44 UTC 2021
  • How you installed PyTorch (conda, pip, source): pip

cc @carmocca @justusschock @awaelchli @akihironitta @rohitgr7

@dtmoodie dtmoodie added the bug Something isn't working label Jan 15, 2022
@ananthsub ananthsub added the precision: apex (removed) NVIDIA/apex precision label Jan 15, 2022
@awaelchli
Contributor

awaelchli commented Jan 16, 2022

@dtmoodie thanks for the report. Would you be interested in investigating this issue?
The code where we dump the amp state to the checkpoint is here: https://github.com/PyTorchLightning/pytorch-lightning/blob/18bbb39eef77c8c3d19789b1067164e962640c93/pytorch_lightning/plugins/precision/apex_amp.py#L99-L100
and the loading part is right there too. I would start investigating the missing key there.
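Paraphrasing the linked lines (not a verbatim copy), the two checkpoint hooks in the apex precision plugin look roughly like this:

```python
from apex import amp

# Sketch of pytorch_lightning/plugins/precision/apex_amp.py (1.5.x), paraphrased.

def on_save_checkpoint(self, checkpoint):
    # Dump apex's loss-scaler state into the Lightning checkpoint dict.
    checkpoint["amp_scaling_state"] = amp.state_dict()

def on_load_checkpoint(self, checkpoint):
    # On resume this is the call that fails: amp.initialize() has not run
    # yet, so apex's _amp_state has no loss_scalers to load into.
    amp.load_state_dict(checkpoint["amp_scaling_state"])
```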

@ananthsub
Contributor

@awaelchli I believe it's that amp.initialize isn't called on the module before loading.

@awaelchli
Contributor

That's a good call. As part of #10416 we should look into this call here anyway, where amp.initialize happens:
https://github.com/PyTorchLightning/pytorch-lightning/blob/f9c3619eeb435ccaf4eca4f94130c795b30ed222/pytorch_lightning/trainer/trainer.py#L1266

We should check whether we can move the amp.initialize call from here
https://github.com/PyTorchLightning/pytorch-lightning/blob/f9c3619eeb435ccaf4eca4f94130c795b30ed222/pytorch_lightning/plugins/precision/apex_amp.py#L48-L55

to earlier, e.g., setup(), and then get rid of dispatch(), which is currently only implemented by apex amp.
Moving it to setup() would be early enough, before the checkpoint gets loaded.
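Schematically, the idea is something like the following. This is a sketch only, not the real plugin code; `self.amp_level` and the trainer attributes used here are assumptions:

```python
from apex import amp

# Proposed ApexMixedPrecisionPlugin.setup() (sketch). Previously this body
# ran in dispatch(), i.e. right before the loops start and therefore *after*
# restore_training_state() had already tried to load the amp state.
def setup(self, trainer):
    model, optimizers = amp.initialize(
        trainer.lightning_module, trainer.optimizers, opt_level=self.amp_level
    )
    trainer.optimizers = optimizers
```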

cc @four4fish

@awaelchli awaelchli added this to the 1.5.x milestone Feb 13, 2022
@carmocca carmocca moved this to Todo in Frameworks Planning Mar 1, 2022
@Borda Borda modified the milestones: 1.5.x, 1.6 Mar 21, 2022
@carmocca carmocca modified the milestones: 1.6, 1.6.x Mar 28, 2022
@carmocca carmocca modified the milestones: pl:1.6.x, pl:1.7.x Jul 28, 2022
@awaelchli
Contributor

I tried to pick this up again with no success.

After #11952, the optimizers in DDP get set up after the model has been wrapped (to understand why, read the description in #11886). With respect to the plugin setup, the order is the following:

  1. model_to_device()
  2. setup_precision_plugin() (this is where amp.initialize happens)
  3. wrap model in DistributedDataParallel
  4. call LightningModule.configure_optimizers and set up optimizers

Observation: in order to support reloading the amp state, we need to make sure we call amp.initialize before the state gets reloaded. Currently, we initialize too late (just before the loops start). Conclusion: amp.initialize needs to be called earlier.

In order to resolve this issue, we need to satisfy these requirements (dictated by apex):

Requirement 1: amp.initialize can't be called too early; it has to be called AFTER the model has been moved to the device.
Requirement 2: amp.initialize can't be called too late; it has to be called BEFORE the model gets wrapped in DistributedDataParallel.
Requirement 3: amp.initialize must be called after the optimizers have been set up.

These three requirements contradict the current order in which things are set up (1-4 above). The amp.initialize call can't be inserted anywhere between steps 1-4 without breaking one of these requirements.

I don't see how apex can be supported in our DDP strategies at the moment without changing the place where the optimizers get set up.
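For context, the order apex's own documentation prescribes for DDP (which is what the three requirements above encode) is roughly the following; MyModel and the optimizer settings are placeholders:

```python
import torch
from apex import amp
from torch.nn.parallel import DistributedDataParallel

# 1. Move the model to the device first (requirement 1).
model = MyModel().cuda()

# 2. Build the optimizer from the (now on-device) parameters (requirement 3).
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# 3. Call amp.initialize with both, before any DDP wrapping (requirement 2).
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

# 4. Only now wrap the model in DDP.
model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])

# In Lightning's current order (1-4 above), the optimizers are created only
# after the DDP wrap, so there is no slot where amp.initialize can be called
# without violating one of the three requirements.
```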

@awaelchli awaelchli moved this from Todo to Blocked in Frameworks Planning Aug 12, 2022
@carmocca
Contributor

(Comment from offline discussion)

Let's just be happy, for now, with allowing checkpoints trained with apex enabled to be loaded, but without reloading the apex state. We would print a warning in this case. This can still be useful for further training or inference.
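A minimal sketch of what that could look like in the plugin's load hook (the condition and the warning text are hypothetical, not the actual implementation):

```python
from pytorch_lightning.utilities import rank_zero_warn

def on_load_checkpoint(self, checkpoint):
    if "amp_scaling_state" in checkpoint:
        # Skip restoring apex's loss scalers (amp.initialize has not run yet)
        # and tell the user the scaler state from the checkpoint is dropped.
        rank_zero_warn(
            "This checkpoint contains apex AMP state, which cannot be restored."
            " Continuing with freshly initialized loss scalers."
        )
```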

@awaelchli
Contributor

Unfortunately, we can't simply skip reloading the amp state and still continue training with apex.

E                   AttributeError: 'SGD' object has no attribute '_amp_stash'

https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=91083&view=logs&j=3f274fac-2e11-54ca-487e-194c91f3ae9f&t=fd22e4d0-068c-5651-f467-789f79e7bb7b

The logic would either have to be more involved (in the reloading logic itself, outside the plugin), or we would have to raise a hard runtime error.

@carmocca
Contributor

IMO that's a bug in APEX. It's already been reported in NVIDIA/apex#1057 with no answer.

Repository owner moved this from Blocked to Done in Frameworks Planning Aug 26, 2022
@carmocca
Contributor

Technically, this issue isn't resolved, as users are still "unable to resume from checkpoint when using apex". We just improved the UX, but it's still blocked by Apex anyway.

@treacker

treacker commented Dec 6, 2022

Are there any updates on this issue?

@awaelchli
Contributor

@treacker Unfortunately, we can no longer support Apex and are going to deprecate it (#14416).
