From 61fbec4235e1f5ab1e8149539aa6277a2b9776ac Mon Sep 17 00:00:00 2001
From: Ryan
Date: Fri, 14 Jun 2024 18:08:23 -0700
Subject: [PATCH] move load state dict after initialize parallel state in
 nlp_model (#9382)

* move load state dict after initialize parallel state

Signed-off-by: Ryan Li

* delay sharded_state_dict in save_to

Signed-off-by: Ryan Li

---------

Signed-off-by: Ryan Li
Co-authored-by: Ryan Li
---
 nemo/collections/nlp/models/nlp_model.py    | 4 ++--
 nemo/collections/nlp/parts/nlp_overrides.py | 3 +--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py
index 37195f1df142..2380ed15cc45 100644
--- a/nemo/collections/nlp/models/nlp_model.py
+++ b/nemo/collections/nlp/models/nlp_model.py
@@ -387,8 +387,6 @@ def load_from_checkpoint(
 
             # if the checkpoint is distributed, we deferred loading the state_dict until now
             if checkpoint_dir is not None:
-                sharded_state_dict = model.sharded_state_dict()
-                checkpoint['state_dict'] = sharded_state_dict
                 # dist checkpointing needs torch.distributed to load the checkpoint
                 if not parallel_state.is_initialized():
 
@@ -398,6 +396,8 @@ def dummy():
                     if model.trainer.strategy.launcher is not None:
                         model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                     model.trainer.strategy.setup_environment()
+                sharded_state_dict = model.sharded_state_dict()
+                checkpoint['state_dict'] = sharded_state_dict
                 # load the checkpoint from disk
                 checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
                 # restore the weights
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 6b356539aba9..0555776457a5 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -948,8 +948,6 @@ def save_to(self, model, save_path: str):
             if dist_ckpt:
                 # model weights is a directory
                 dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
-
-                sharded_state_dict = model.sharded_state_dict()
                 # dist checkpoint needs torch.distributed to save the checkpoint
                 if not parallel_state.is_initialized():
 
@@ -959,6 +957,7 @@ def dummy():
                     if model.trainer.strategy.launcher is not None:
                         model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                     model.trainer.strategy.setup_environment()
 
+                sharded_state_dict = model.sharded_state_dict()
                 checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
                 checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)
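
The ordering above matters because model.sharded_state_dict() generally needs megatron-core's
parallel state (and therefore torch.distributed) to be initialized before it can describe how
tensors are sharded across ranks. Below is a minimal sketch of that call order, reusing the
calls that appear in the diff inside a hypothetical load_dist_checkpoint helper; it is not the
actual NLPModel.load_from_checkpoint implementation, and model, checkpoint and checkpoint_dir
are assumed to come from the surrounding restore logic.

    # Sketch only: shows the call ordering the patch relies on; not the actual
    # NLPModel.load_from_checkpoint implementation. `model`, `checkpoint` and
    # `checkpoint_dir` are assumed to come from the surrounding restore logic.
    from megatron.core import dist_checkpointing, parallel_state


    def load_dist_checkpoint(model, checkpoint, checkpoint_dir):
        # 1) Bring up torch.distributed / megatron parallel state first.
        if not parallel_state.is_initialized():

            def dummy():
                return

            if model.trainer.strategy.launcher is not None:
                model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
            model.trainer.strategy.setup_environment()

        # 2) Only now build the sharded state dict; producing it may query
        #    ranks/process groups, which requires the state set up in step 1.
        checkpoint['state_dict'] = model.sharded_state_dict()

        # 3) Load the distributed checkpoint against that sharded description.
        return dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)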