Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move load state dict after initialize parallel state in nlp_model #9382

Merged
merged 4 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo/collections/nlp/models/nlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,6 @@ def load_from_checkpoint(

# if the checkpoint is distributed, we deferred loading the state_dict until now
if checkpoint_dir is not None:
sharded_state_dict = model.sharded_state_dict()
checkpoint['state_dict'] = sharded_state_dict
# dist checkpointing needs torch.distributed to load the checkpoint
if not parallel_state.is_initialized():

Expand All @@ -398,6 +396,8 @@ def dummy():
if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
sharded_state_dict = model.sharded_state_dict()
checkpoint['state_dict'] = sharded_state_dict
# load the checkpoint from disk
checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
# restore the weights
Expand Down
3 changes: 1 addition & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,6 @@ def save_to(self, model, save_path: str):
if dist_ckpt:
# model weights is a directory
dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))

sharded_state_dict = model.sharded_state_dict()
# dist checkpoint needs torch.distributed to save the checkpoint
if not parallel_state.is_initialized():

Expand All @@ -954,6 +952,7 @@ def dummy():
if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
sharded_state_dict = model.sharded_state_dict()
checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)

Expand Down
Loading