move load state dict after initialize parallel state in nlp_model (NVIDIA#9382)

* move load state dict after initialize parallel state

Signed-off-by: Ryan Li <rynli@amazon.com>

* delay sharded_state_dict in save_to

Signed-off-by: Ryan Li <rynli@amazon.com>

---------

Signed-off-by: Ryan Li <rynli@amazon.com>
Co-authored-by: Ryan Li <rynli@amazon.com>
2 people authored and JesusPaz committed Jun 18, 2024
1 parent ff1e849 commit 61fbec4
Showing 2 changed files with 3 additions and 4 deletions.
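
Both files make the same change: build the sharded state dict only after megatron parallel state (and with it torch.distributed) has been initialized, rather than before. A minimal sketch of that ordering, not NeMo's actual code, assuming a NeMo model with a Lightning trainer attached and megatron.core available (the same parallel_state module the hunks below use):

from megatron.core import parallel_state


def build_sharded_state_dict(model):
    # 1) make sure torch.distributed / megatron parallel state is up first
    if not parallel_state.is_initialized():

        def dummy():
            return

        # the Lightning launcher brings up the worker processes before setup
        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
        model.trainer.strategy.setup_environment()

    # 2) only now build the sharded view of the weights, since its layout
    #    presumably depends on the process-group configuration
    return model.sharded_state_dict()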
nemo/collections/nlp/models/nlp_model.py (2 additions, 2 deletions)
@@ -387,8 +387,6 @@ def load_from_checkpoint(
 
         # if the checkpoint is distributed, we deferred loading the state_dict until now
         if checkpoint_dir is not None:
-            sharded_state_dict = model.sharded_state_dict()
-            checkpoint['state_dict'] = sharded_state_dict
             # dist checkpointing needs torch.distributed to load the checkpoint
             if not parallel_state.is_initialized():
 
@@ -398,6 +396,8 @@ def dummy():
                 if model.trainer.strategy.launcher is not None:
                     model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                 model.trainer.strategy.setup_environment()
+            sharded_state_dict = model.sharded_state_dict()
+            checkpoint['state_dict'] = sharded_state_dict
             # load the checkpoint from disk
             checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
             # restore the weights
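
For context, the restore flow after this hunk boils down to the following sketch (simplified, not NeMo's exact code; model, checkpoint, and checkpoint_dir are the objects already in scope in load_from_checkpoint, and parallel state is assumed to be initialized as shown above):

from megatron.core import dist_checkpointing


def restore_from_dist_checkpoint(model, checkpoint, checkpoint_dir):
    # the sharded state dict serves as a template that tells dist_checkpointing.load
    # how each tensor is partitioned across the current ranks
    checkpoint['state_dict'] = model.sharded_state_dict()
    # load returns the checkpoint with the sharded entries filled in from disk
    return dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)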
nemo/collections/nlp/parts/nlp_overrides.py (1 addition, 2 deletions)
@@ -948,8 +948,6 @@ def save_to(self, model, save_path: str):
         if dist_ckpt:
             # model weights is a directory
             dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
-
-            sharded_state_dict = model.sharded_state_dict()
             # dist checkpoint needs torch.distributed to save the checkpoint
             if not parallel_state.is_initialized():
 
@@ -959,6 +957,7 @@ def dummy():
                 if model.trainer.strategy.launcher is not None:
                     model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                 model.trainer.strategy.setup_environment()
+            sharded_state_dict = model.sharded_state_dict()
             checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
             checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)
 
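
The save side follows the same pattern; a sketch under the same assumptions (DistributedCheckpointIO and ckpt_to_dir are the NeMo helpers already used in nlp_overrides.py and are assumed to be imported; model_weights_ckpt mirrors self.model_weights_ckpt in the hunk above):

import os


def save_dist_checkpoint(model, dir_name, model_weights_ckpt):
    # distributed checkpoints are written to a directory rather than a single file
    dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, model_weights_ckpt))
    # build the sharded view only after parallel state exists, which is the point of this commit
    sharded_state_dict = model.sharded_state_dict()
    checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
    checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)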