diff --git a/BingBertSquad/nvidia_run_squad_deepspeed.py b/BingBertSquad/nvidia_run_squad_deepspeed.py
index 454d2c7c9..8ab2949e3 100755
--- a/BingBertSquad/nvidia_run_squad_deepspeed.py
+++ b/BingBertSquad/nvidia_run_squad_deepspeed.py
@@ -741,7 +741,7 @@ def set_optimizer_params_grad(named_params_optimizer,
 
 def main():
     parser = get_argument_parser()
-    torch.distributed.init_process_group(backend='nccl')
+    deepspeed.init_distributed(dist_backend='nccl')
 
     # Include DeepSpeed configuration arguments
     parser = deepspeed.add_config_arguments(parser)
diff --git a/Megatron-LM/pretrain_gpt2.py b/Megatron-LM/pretrain_gpt2.py
index 9a0153fbe..14cf0d28c 100755
--- a/Megatron-LM/pretrain_gpt2.py
+++ b/Megatron-LM/pretrain_gpt2.py
@@ -549,20 +549,24 @@ def set_deepspeed_activation_checkpointing(args):
 
 def initialize_distributed(args):
     """Initialize torch.distributed."""
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
+    if args.deepspeed:
+        deepspeed.init_distributed(dist_backend=args.distributed_backend)
+    else:
+        # Manually set the device ids.
+        device = args.rank % torch.cuda.device_count()
+        # Call the init process
+        init_method = 'tcp://'
+        master_ip = os.getenv('MASTER_ADDR', 'localhost')
+        master_port = os.getenv('MASTER_PORT', '6000')
+        init_method += master_ip + ':' + master_port
+        torch.distributed.init_process_group(
+            backend=args.distributed_backend,
+            world_size=args.world_size, rank=args.rank,
+            init_method=init_method)
+
     if args.local_rank is not None:
         device = args.local_rank
     torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
 
     # Set the model-parallel / data-parallel communicators.
     mpu.initialize_model_parallel(args.model_parallel_size)
diff --git a/bing_bert/deepspeed_train.py b/bing_bert/deepspeed_train.py
index ce25913a9..dc54da6ee 100755
--- a/bing_bert/deepspeed_train.py
+++ b/bing_bert/deepspeed_train.py
@@ -390,7 +390,7 @@ def prepare_optimizer_parameters(args, model):
 
 def prepare_model_optimizer(args):
     # Initialize torch distributed
-    # torch.distributed.init_process_group(backend="nccl")
+    deepspeed.init_distributed(dist_backend='nccl')
 
     # Loading Model
     model = BertMultiTask(args)
diff --git a/pipeline_parallelism/train.py b/pipeline_parallelism/train.py
index d4197e55d..8369b8785 100755
--- a/pipeline_parallelism/train.py
+++ b/pipeline_parallelism/train.py
@@ -149,7 +149,7 @@ def train_pipe(args, part='parameters'):
     args = get_args()
 
     torch.cuda.set_device(args.local_rank)
-    dist.init_process_group(backend=args.backend)
+    deepspeed.init_distributed(dist_backend=args.backend)
 
     if args.pipeline_parallel_size == 0:
         train_base(args)