Skip to content

Commit

Permalink
Move loading of the sharded state dict to after parallel state initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan Li committed Jun 7, 2024
1 parent f1062b7 commit 791a037
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/nlp/models/nlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,6 @@ def load_from_checkpoint(

# if the checkpoint is distributed, we deferred loading the state_dict until now
if checkpoint_dir is not None:
sharded_state_dict = model.sharded_state_dict()
checkpoint['state_dict'] = sharded_state_dict
# dist checkpointing needs torch.distributed to load the checkpoint
if not parallel_state.is_initialized():

Expand All @@ -398,6 +396,8 @@ def dummy():
if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
sharded_state_dict = model.sharded_state_dict()
checkpoint['state_dict'] = sharded_state_dict
# load the checkpoint from disk
checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
# restore the weights
Expand Down

0 comments on commit 791a037

Please sign in to comment.