diff --git a/examples/nlp/language_modeling/megatron_change_num_partitions.py b/examples/nlp/language_modeling/megatron_change_num_partitions.py index a4b28fa4d761..558986e3da36 100644 --- a/examples/nlp/language_modeling/megatron_change_num_partitions.py +++ b/examples/nlp/language_modeling/megatron_change_num_partitions.py @@ -125,7 +125,7 @@ def set_virtual_parallel_rank_safely(rank: int): parallel_state.set_virtual_pipeline_model_parallel_rank(rank) if rank is None: - parallel_state.set_virtual_pipeline_model_parallel_world_size(0) + parallel_state.set_virtual_pipeline_model_parallel_world_size(None) except (ImportError, ModuleNotFoundError): logging.warning("`megatron-core` not installed, cannot set virtual parallel rank !") @@ -861,6 +861,10 @@ def main(): convert_vp = vp_size > 1 if convert_vp: + from megatron.core import parallel_state + + parallel_state.set_virtual_pipeline_model_parallel_world_size(vp_size) + hparams_filepath = args.hparams_file if hparams_filepath is None: logging.warning(