How do I load a model checkpointed with ddp_spawn? #5728
Unanswered
laughingrice asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 0 comments
I am using `pytorch_lightning.callbacks.ModelCheckpoint(save_weights_only=True)` to checkpoint my trained model.
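For reference, the setup is roughly the following (the model class and checkpoint directory are placeholders; the callback and the 4-GPU trainer are as described):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoint only the model weights, not optimizer/trainer state
checkpoint_cb = ModelCheckpoint(save_weights_only=True)

model = MyModel()  # placeholder for my LightningModule
trainer = pl.Trainer(gpus=4, callbacks=[checkpoint_cb])
trainer.fit(model)
```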
Setting aside for the moment that my checkpoint files are huge (6.2 GB for a model with 2.7 M trainable parameters), the model seems to be checkpointed in place, i.e. with the tensors still mapped to the devices they were trained on.
I am training on a system with 4 GPUs. In the simple case, if I train the model passing `accelerator='dp'` to the `Trainer`, the checkpoint is saved with `device='cuda:0'`. I can deal with that by passing `map_location='cpu'` to `torch.load`, which is inconvenient but manageable.
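That is, something like this works for the `dp` case (the checkpoint path is a placeholder):

```python
import torch

# Remap the cuda:0 tensors onto the CPU at load time
ckpt = torch.load("checkpoints/last.ckpt", map_location="cpu")
model = MyModel()
model.load_state_dict(ckpt["state_dict"])
```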
If I don't set `accelerator='dp'`, however, it looks like `ddp_spawn` is chosen, and I cannot figure out how to reload the checkpointed model: I get `AssertionError: Default process group is not initialized`, regardless of the `map_location` value. I found one post discussing how to set up the distributed environment before loading, which I couldn't figure out, but really I just want to save the model weights without any device/environment inference, so that they can be loaded generically on any device without much headache.
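For completeness, the failing reload attempt looks something like this (again with a placeholder path), and it fails the same way with or without `map_location`:

```python
# Fails whether map_location is 'cpu', a cuda device, or omitted
model = MyModel.load_from_checkpoint(
    "checkpoints/last.ckpt", map_location="cpu"
)
# AssertionError: Default process group is not initialized
```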
How do I tell PyTorch Lightning / `ModelCheckpoint` to transfer the model to the CPU before saving, to make loading easier? (And/or how do I use MLflow to checkpoint the model?)
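If there is no built-in option, would something like the following work? The `LightningModule.on_save_checkpoint` hook receives the checkpoint dict before it is written, so the weights could be moved to the CPU there. This is an untested sketch, and I don't know whether it also avoids the process-group assertion:

```python
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    # ... model definition as usual ...

    def on_save_checkpoint(self, checkpoint):
        # Move every saved tensor to the CPU so the checkpoint file can
        # be torch.load-ed on any machine without device remapping
        checkpoint["state_dict"] = {
            key: value.cpu() for key, value in checkpoint["state_dict"].items()
        }
```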