diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py
index 37195f1df142..2380ed15cc45 100644
--- a/nemo/collections/nlp/models/nlp_model.py
+++ b/nemo/collections/nlp/models/nlp_model.py
@@ -387,8 +387,6 @@ def load_from_checkpoint(
 
         # if the checkpoint is distributed, we deferred loading the state_dict until now
         if checkpoint_dir is not None:
-            sharded_state_dict = model.sharded_state_dict()
-            checkpoint['state_dict'] = sharded_state_dict
             # dist checkpointing needs torch.distributed to load the checkpoint
             if not parallel_state.is_initialized():
 
@@ -398,6 +396,8 @@ def dummy():
                 if model.trainer.strategy.launcher is not None:
                     model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                 model.trainer.strategy.setup_environment()
+            sharded_state_dict = model.sharded_state_dict()
+            checkpoint['state_dict'] = sharded_state_dict
             # load the checkpoint from disk
             checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
             # restore the weights
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 8ca010e59f70..aa21f85c5b1e 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -943,8 +943,6 @@ def save_to(self, model, save_path: str):
         if dist_ckpt:
             # model weights is a directory
             dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
-
-            sharded_state_dict = model.sharded_state_dict()
             # dist checkpoint needs torch.distributed to save the checkpoint
             if not parallel_state.is_initialized():
 
@@ -954,6 +952,7 @@ def dummy():
                 if model.trainer.strategy.launcher is not None:
                     model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                 model.trainer.strategy.setup_environment()
 
+            sharded_state_dict = model.sharded_state_dict()
             checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
             checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)
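
Both hunks make the same reordering: model.sharded_state_dict() is now built only after
model.trainer.strategy.setup_environment() has run, since (per the diff's own comments)
distributed checkpointing requires torch.distributed to be initialized. A minimal sketch
of the ordering below; ToyModel and setup_environment are hypothetical stand-ins for
NeMo's model and trainer.strategy.setup_environment(), not the real API.

    import os
    import torch.distributed as dist


    class ToyModel:
        def sharded_state_dict(self):
            # Stand-in for NeMo's model.sharded_state_dict(); the real method builds
            # sharded tensors tied to the active process group, hence the precondition.
            assert dist.is_initialized(), "process group must be initialized first"
            return {"weight": 0}


    def setup_environment():
        # Stand-in for trainer.strategy.setup_environment(): single-process
        # torch.distributed bootstrap using the gloo backend.
        if not dist.is_initialized():
            os.environ.setdefault("MASTER_ADDR", "localhost")
            os.environ.setdefault("MASTER_PORT", "29500")
            dist.init_process_group(backend="gloo", rank=0, world_size=1)


    model = ToyModel()

    # Before the fix: the state dict was built first, so this raised because
    # torch.distributed was not yet initialized.
    # sharded = model.sharded_state_dict()  # AssertionError

    # After the fix: initialize the environment, then build the sharded state dict.
    setup_environment()
    sharded = model.sharded_state_dict()  # OK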