V23.05 Megatron-LM with mcr-dl support #1

Merged 3 commits on Apr 21, 2024
examples/detxoify_lm/generate_samples_gpt.py  (12 additions, 7 deletions)

@@ -9,6 +9,7 @@
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir, os.path.pardir)))
 import torch
+import mcr_dl
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import print_rank_0
@@ -57,15 +58,16 @@ def add_text_generate_args(parser):

 def generate_samples_unconditional(model):
     args = get_args()
+    dist = mcr_dl.get_distributed_engine()

-    if torch.distributed.get_rank() == 0:
+    if dist.get_rank() == 0:
         cnt = 0
         num_samples = args.num_samples
         from tqdm import tqdm
         pbar = tqdm(total=num_samples)

     while True:
-        if torch.distributed.get_rank() == 0:
+        if dist.get_rank() == 0:
             sentences = [''] * args.global_batch_size
             print("global batch size", args.global_batch_size)
             max_len = args.out_seq_length
@@ -94,8 +96,9 @@ def generate_samples_unconditional(model):

 def generate_samples_conditional(model):
     args = get_args()
+    dist = mcr_dl.get_distributed_engine()

-    if torch.distributed.get_rank() == 0:
+    if dist.get_rank() == 0:
         num_samples = args.num_samples
         cnt = 0
         from tqdm import tqdm
@@ -108,8 +111,8 @@ def generate_samples_conditional(model):
         input_pos = 0

     while True:
-        torch.distributed.barrier()
-        if torch.distributed.get_rank() == 0:
+        dist.barrier()
+        if dist.get_rank() == 0:
             sentences = []
             print("global batch size", args.global_batch_size)
             for _ in range(args.global_batch_size):
@@ -147,15 +150,17 @@ def generate_samples_conditional(model):

 def generate_and_write_samples_unconditional(model):
     args = get_args()
+    dist = mcr_dl.get_distributed_engine()
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
         for datum in generate_samples_unconditional(model):
-            if torch.distributed.get_rank() == 0:
+            if dist.get_rank() == 0:
                 f.write(json.dumps(datum) + '\n')


 def generate_and_write_samples_conditional(model):
     args = get_args()
+    dist = mcr_dl.get_distributed_engine()
     if args.sample_output_file is None:
         sample_output_file = args.sample_input_file + ".out"
         print('`sample-output-file` not specified, setting '
@@ -164,7 +169,7 @@ def generate_and_write_samples_conditional(model):
         sample_output_file = args.sample_output_file
     with open(sample_output_file, 'w') as f:
         for datum in generate_samples_conditional(model):
-            if torch.distributed.get_rank() == 0:
+            if dist.get_rank() == 0:
                 f.write(json.dumps(datum) + '\n')
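The changes in this file all follow one pattern: fetch the active communication engine once per function with mcr_dl.get_distributed_engine() and route rank checks and barriers through that handle instead of calling torch.distributed directly. Below is a minimal sketch of that calling convention, assuming mcr_dl is installed and the engine has already been initialized by the training entry point; the function and variable names are illustrative, not part of this PR.

import mcr_dl


def write_samples_rank0(samples, path):
    # Grab the engine handle (torch.distributed or MCR-DL, depending on
    # --distributed-engine), mirroring the dist = ... lines in the diff above.
    dist = mcr_dl.get_distributed_engine()

    # Only rank 0 writes output, exactly like generate_and_write_samples_*.
    if dist.get_rank() == 0:
        with open(path, 'w') as f:
            for sample in samples:
                f.write(sample + '\n')

    # Keep all ranks in step, as generate_samples_conditional does per batch.
    dist.barrier()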
examples/pretrain_gpt.sh  (6 additions, 1 deletion)

@@ -42,10 +42,15 @@ OUTPUT_ARGS="
     --eval-interval 1000 \
     --eval-iters 10
 "
+MCRDL_ARGS="
+    --distributed-engine torch \
+    --distributed-backend nccl \
+"

 torchrun pretrain_gpt.py \
     $GPT_ARGS \
     $DATA_ARGS \
+    $MCRDL_ARGS \
     $OUTPUT_ARGS \
     --save $CHECKPOINT_PATH \
-    --load $CHECKPOINT_PATH
+    --load $CHECKPOINT_PATH
examples/pretrain_gpt_distributed.sh  (7 additions, 1 deletion)

@@ -25,6 +25,11 @@ DISTRIBUTED_ARGS="
     --master_port $MASTER_PORT
 "

+MCRDL_ARGS="
+    --distributed-engine torch \
+    --distributed-backend nccl \
+"
+
 GPT_ARGS="
     --num-layers 24 \
     --hidden-size 1024 \
@@ -63,6 +68,7 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
     $GPT_ARGS \
     $DATA_ARGS \
     $OUTPUT_ARGS \
+    $MCRDL_ARGS \
     --distributed-backend nccl \
     --save $CHECKPOINT_PATH \
-    --load $CHECKPOINT_PATH
+    --load $CHECKPOINT_PATH
megatron/arguments.py  (10 additions, 1 deletion)

@@ -35,6 +35,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     parser = _add_inference_args(parser)
     parser = _add_transformer_engine_args(parser)
     parser = _add_retro_args(parser)
+    parser = _add_mcr_dl_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -942,7 +943,7 @@ def _add_distributed_args(parser):
                        help='overlap pipeline parallel communication with forward and backward chunks',
                        dest='overlap_p2p_comm')
     group.add_argument('--distributed-backend', default='nccl',
-                       choices=['nccl', 'gloo'],
+                       choices=['nccl', 'mpi', 'gloo'],
                        help='Which backend to use for distributed training.')
     group.add_argument('--distributed-timeout-minutes', type=int, default=10,
                        help='Timeout minutes for torch.distributed.')
@@ -1226,3 +1227,11 @@ def _add_vision_args(parser):
                        help='warmup teacher temperaure epochs')

     return parser
+
+def _add_mcr_dl_args(parser):
+    group = parser.add_argument_group(title='experimental')
+    group.add_argument("--distributed-engine", type=str, default='torch',
+                       choices=['mcr_dl', 'torch'],
+                       help='Distributed DL framework to use')
+
+    return parser
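The only behavioral change here is to argparse: --distributed-backend gains an mpi choice, and a new experimental --distributed-engine flag selects between torch and mcr_dl. The following standalone sketch mirrors just these two options so they can be exercised in isolation; the parser harness is illustrative and not part of Megatron-LM.

import argparse


def build_engine_parser():
    # Mirrors the two options touched by this PR.
    parser = argparse.ArgumentParser(description='mcr-dl flag sketch')
    parser.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'mpi', 'gloo'],
                        help='Which backend to use for distributed training.')
    parser.add_argument('--distributed-engine', type=str, default='torch',
                        choices=['mcr_dl', 'torch'],
                        help='Distributed DL framework to use')
    return parser


if __name__ == '__main__':
    # e.g. python flag_sketch.py --distributed-engine mcr_dl --distributed-backend mpi
    args = build_engine_parser().parse_args()
    print(args.distributed_engine, args.distributed_backend)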
megatron/checkpointing.py  (22 additions, 17 deletions)

@@ -8,6 +8,7 @@
 import numpy as np

 import torch
+import mcr_dl

 from megatron import update_num_microbatches
 from megatron.core import mpu, tensor_parallel
@@ -164,9 +165,10 @@ def read_metadata(tracker_filename):
         tracker_filename)

     # Get the max iteration retrieved across the ranks.
-    if torch.distributed.is_initialized():
+    dist = mcr_dl.get_distributed_engine()
+    if dist.is_initialized():
         iters_cuda = torch.cuda.LongTensor([iteration])
-        torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
+        dist.all_reduce(iters_cuda, op=dist.ReduceOp.MAX)
         max_iter = iters_cuda[0].item()

     # We should now have all the same iteration.
@@ -196,12 +198,13 @@ def get_rng_state():
         'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

     rng_state_list = None
-    if torch.distributed.is_initialized() and \
+    dist = mcr_dl.get_distributed_engine()
+    if dist.is_initialized() and \
             mpu.get_data_parallel_world_size() > 1 and \
             args.data_parallel_random_init:
         rng_state_list = \
             [None for i in range(mpu.get_data_parallel_world_size())]
-        torch.distributed.all_gather_object(
+        dist.all_gather_object(
             rng_state_list,
             rng_state,
             group=mpu.get_data_parallel_group())
@@ -235,7 +238,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         optimizer.save_parameter_state(optim_checkpoint_name)

     # Collect args, model, RNG.
-    if not torch.distributed.is_initialized() \
+    dist = mcr_dl.get_distributed_engine()
+    if not dist.is_initialized() \
        or mpu.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.
@@ -268,22 +272,22 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         torch.save(state_dict, checkpoint_name)

     # Wait so everyone is done (necessary)
-    if torch.distributed.is_initialized():
-        torch.distributed.barrier()
+    if dist.is_initialized():
+        dist.barrier()

     print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \
                  .format(iteration, args.save))

     # And update the latest iteration
-    if not torch.distributed.is_initialized() \
-       or torch.distributed.get_rank() == 0:
+    if not dist.is_initialized() \
+       or dist.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))

     # Wait so everyone is done (not necessary)
-    if torch.distributed.is_initialized():
-        torch.distributed.barrier()
+    if dist.is_initialized():
+        dist.barrier()


 def _transpose_first_dim(t, num_splits, num_splits_first, model):
@@ -509,7 +513,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         # Conditionally exit at this point.
         if args.exit_on_missing_checkpoint:
             print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
-            torch.distributed.barrier()
+            dist = mcr_dl.get_distributed_engine()
+            dist.barrier()
             sys.exit()

         # Iteration defaults to 0.
@@ -631,8 +636,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             sys.exit()

     # Some utilities want to load a checkpoint without distributed being initialized
-    if torch.distributed.is_initialized():
-        torch.distributed.barrier()
+    if dist.is_initialized():
+        dist.barrier()

     print_rank_0(f' successfully loaded checkpoint from {args.load} '
                  f'at iteration {iteration}')
@@ -660,10 +665,10 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     checkpoint_name = get_checkpoint_name(load_path, iteration,
                                           args.use_distributed_optimizer,
                                           release=False)
-
+    dist = mcr_dl.get_distributed_engine()
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
+            dist.get_rank(), checkpoint_name))

     state_dict = torch.load(model_checkpoint_name, map_location='cpu')
     ret_state_dict = state_dict['model']
@@ -675,7 +680,7 @@

     assert len(model) == 1
     model[0].load_state_dict(ret_state_dict)
-    torch.distributed.barrier()
+    dist.barrier()

     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))
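The collective calls in read_metadata translate one-for-one: the handle returned by mcr_dl.get_distributed_engine() exposes is_initialized, get_rank, barrier, all_reduce, and a ReduceOp enum, matching the torch.distributed calls removed above. Below is a minimal sketch of the max-iteration agreement step from read_metadata, assuming the engine is already initialized and CUDA is available; the function name is illustrative.

import torch
import mcr_dl


def agree_on_max_iteration(local_iteration):
    # Reduce each rank's iteration with MAX so every rank resumes from the
    # same checkpoint, as in the read_metadata hunk above.
    dist = mcr_dl.get_distributed_engine()
    if not dist.is_initialized():
        return local_iteration

    iters_cuda = torch.cuda.LongTensor([local_iteration])
    dist.all_reduce(iters_cuda, op=dist.ReduceOp.MAX)
    return iters_cuda[0].item()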