From 791a03733119170032c43ae2c57b88d5ca35176e Mon Sep 17 00:00:00 2001 From: Ryan Li Date: Wed, 5 Jun 2024 01:00:16 +0000 Subject: [PATCH] move load state dict after initialize parallel state --- nemo/collections/nlp/models/nlp_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index 37195f1df142f..2380ed15cc45c 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -387,8 +387,6 @@ def load_from_checkpoint( # if the checkpoint is distributed, we deferred loading the state_dict until now if checkpoint_dir is not None: - sharded_state_dict = model.sharded_state_dict() - checkpoint['state_dict'] = sharded_state_dict # dist checkpointing needs torch.distributed to load the checkpoint if not parallel_state.is_initialized(): @@ -398,6 +396,8 @@ def dummy(): if model.trainer.strategy.launcher is not None: model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) model.trainer.strategy.setup_environment() + sharded_state_dict = model.sharded_state_dict() + checkpoint['state_dict'] = sharded_state_dict # load the checkpoint from disk checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir) # restore the weights