
Commit

Merge pull request #1 from OSU-Nowlab/v23.05_MCR-DL_support
V23.05 Megatron-LM with mcr-dl support
Quentin-Anthony authored Apr 21, 2024
2 parents 0604155 + c0977b5 commit f34132d
Showing 67 changed files with 749 additions and 470 deletions.
19 changes: 12 additions & 7 deletions examples/detxoify_lm/generate_samples_gpt.py
@@ -9,6 +9,7 @@
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir, os.path.pardir)))
import torch
import mcr_dl
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
@@ -57,15 +58,16 @@ def add_text_generate_args(parser):

def generate_samples_unconditional(model):
args = get_args()
dist = mcr_dl.get_distributed_engine()

if torch.distributed.get_rank() == 0:
if dist.get_rank() == 0:
cnt = 0
num_samples = args.num_samples
from tqdm import tqdm
pbar = tqdm(total=num_samples)

while True:
if torch.distributed.get_rank() == 0:
if dist.get_rank() == 0:
sentences = [''] * args.global_batch_size
print("global batch size", args.global_batch_size)
max_len = args.out_seq_length
@@ -94,8 +96,9 @@ def generate_samples_unconditional(model):

def generate_samples_conditional(model):
args = get_args()
dist = mcr_dl.get_distributed_engine()

if torch.distributed.get_rank() == 0:
if dist.get_rank() == 0:
num_samples = args.num_samples
cnt = 0
from tqdm import tqdm
@@ -108,8 +111,8 @@ def generate_samples_conditional(model):
input_pos = 0

while True:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
dist.barrier()
if dist.get_rank() == 0:
sentences = []
print("global batch size", args.global_batch_size)
for _ in range(args.global_batch_size):
@@ -147,15 +150,17 @@ def generate_samples_conditional(model):

def generate_and_write_samples_unconditional(model):
args = get_args()
dist = mcr_dl.get_distributed_engine()
assert args.genfile is not None
with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model):
if torch.distributed.get_rank() == 0:
if dist.get_rank() == 0:
f.write(json.dumps(datum) + '\n')


def generate_and_write_samples_conditional(model):
args = get_args()
dist = mcr_dl.get_distributed_engine()
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
@@ -164,7 +169,7 @@ def generate_and_write_samples_conditional(model):
sample_output_file = args.sample_output_file
with open(sample_output_file, 'w') as f:
for datum in generate_samples_conditional(model):
if torch.distributed.get_rank() == 0:
if dist.get_rank() == 0:
f.write(json.dumps(datum) + '\n')


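The change in this file follows one pattern: fetch MCR-DL's distributed engine once per function, then call it wherever torch.distributed was called before. Below is a minimal sketch of that call-site pattern, not part of the diff; it assumes mcr_dl is installed and its engine has already been initialized during Megatron startup, and write_rank0 is a hypothetical helper used only for illustration.

import json
import mcr_dl

def write_rank0(samples, path):
    # The engine exposes the same collective surface used in this PR:
    # get_rank(), barrier(), all_reduce(), all_gather_object(), ...
    dist = mcr_dl.get_distributed_engine()
    dist.barrier()               # keep ranks in step before I/O
    if dist.get_rank() == 0:     # only rank 0 writes generated samples
        with open(path, 'w') as f:
            for datum in samples:
                f.write(json.dumps(datum) + '\n')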
7 changes: 6 additions & 1 deletion examples/pretrain_gpt.sh
@@ -42,10 +42,15 @@ OUTPUT_ARGS="
--eval-interval 1000 \
--eval-iters 10
"
MCRDL_ARGS="
--distributed-engine torch \
--distributed-backend nccl \
"

torchrun pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$MCRDL_ARGS \
$OUTPUT_ARGS \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
8 changes: 7 additions & 1 deletion examples/pretrain_gpt_distributed.sh
@@ -25,6 +25,11 @@ DISTRIBUTED_ARGS="
--master_port $MASTER_PORT
"

MCRDL_ARGS="
--distributed-engine torch \
--distributed-backend nccl \
"

GPT_ARGS="
--num-layers 24 \
--hidden-size 1024 \
@@ -63,6 +68,7 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
$MCRDL_ARGS \
--distributed-backend nccl \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
11 changes: 10 additions & 1 deletion megatron/arguments.py
@@ -35,6 +35,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)
parser = _add_retro_args(parser)
parser = _add_mcr_dl_args(parser)

# Custom arguments.
if extra_args_provider is not None:
@@ -942,7 +943,7 @@ def _add_distributed_args(parser):
help='overlap pipeline parallel communication with forward and backward chunks',
dest='overlap_p2p_comm')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
choices=['nccl', 'mpi', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
@@ -1226,3 +1227,11 @@ def _add_vision_args(parser):
help='warmup teacher temperaure epochs')

return parser

def _add_mcr_dl_args(parser):
group = parser.add_argument_group(title='experimental')
group.add_argument("--distributed-engine", type=str, default='torch',
choices=['mcr_dl', 'torch'],
help='Distributed DL framework to use')

return parser
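
Together with the example scripts above, engine selection comes down to two flags: --distributed-engine (torch or mcr_dl) and --distributed-backend (nccl, mpi, or gloo). The following standalone sketch shows how the two flags parse, using only the argument definitions visible in this diff; the mcr_dl/mpi combination is an illustration, not a setting taken from the PR.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='experimental')
group.add_argument("--distributed-engine", type=str, default='torch',
                   choices=['mcr_dl', 'torch'],
                   help='Distributed DL framework to use')
parser.add_argument('--distributed-backend', default='nccl',
                    choices=['nccl', 'mpi', 'gloo'],
                    help='Which backend to use for distributed training.')

# e.g. run MCR-DL over MPI instead of the default torch + nccl pair
args = parser.parse_args(['--distributed-engine', 'mcr_dl',
                          '--distributed-backend', 'mpi'])
print(args.distributed_engine, args.distributed_backend)   # -> mcr_dl mpi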
39 changes: 22 additions & 17 deletions megatron/checkpointing.py
@@ -8,6 +8,7 @@
import numpy as np

import torch
import mcr_dl

from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
@@ -164,9 +165,10 @@ def read_metadata(tracker_filename):
tracker_filename)

# Get the max iteration retrieved across the ranks.
if torch.distributed.is_initialized():
dist = mcr_dl.get_distributed_engine()
if dist.is_initialized():
iters_cuda = torch.cuda.LongTensor([iteration])
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
dist.all_reduce(iters_cuda, op=dist.ReduceOp.MAX)
max_iter = iters_cuda[0].item()

# We should now have all the same iteration.
@@ -196,12 +198,13 @@ def get_rng_state():
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

rng_state_list = None
if torch.distributed.is_initialized() and \
dist = mcr_dl.get_distributed_engine()
if dist.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init:
rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather_object(
dist.all_gather_object(
rng_state_list,
rng_state,
group=mpu.get_data_parallel_group())
@@ -235,7 +238,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
optimizer.save_parameter_state(optim_checkpoint_name)

# Collect args, model, RNG.
if not torch.distributed.is_initialized() \
dist = mcr_dl.get_distributed_engine()
if not dist.is_initialized() \
or mpu.get_data_parallel_rank() == 0:

# Arguments, iteration, and model.
@@ -268,22 +272,22 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
torch.save(state_dict, checkpoint_name)

# Wait so everyone is done (necessary)
if torch.distributed.is_initialized():
torch.distributed.barrier()
if dist.is_initialized():
dist.barrier()

print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \
.format(iteration, args.save))

# And update the latest iteration
if not torch.distributed.is_initialized() \
or torch.distributed.get_rank() == 0:
if not dist.is_initialized() \
or dist.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))

# Wait so everyone is done (not necessary)
if torch.distributed.is_initialized():
torch.distributed.barrier()
if dist.is_initialized():
dist.barrier()


def _transpose_first_dim(t, num_splits, num_splits_first, model):
@@ -509,7 +513,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Conditionally exit at this point.
if args.exit_on_missing_checkpoint:
print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
torch.distributed.barrier()
dist = mcr_dl.get_distributed_engine()
dist.barrier()
sys.exit()

# Iteration defaults to 0.
@@ -631,8 +636,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
sys.exit()

# Some utilities want to load a checkpoint without distributed being initialized
if torch.distributed.is_initialized():
torch.distributed.barrier()
if dist.is_initialized():
dist.barrier()

print_rank_0(f' successfully loaded checkpoint from {args.load} '
f'at iteration {iteration}')
@@ -660,10 +665,10 @@ def load_biencoder_checkpoint(model, only_query_model=False,
checkpoint_name = get_checkpoint_name(load_path, iteration,
args.use_distributed_optimizer,
release=False)

dist = mcr_dl.get_distributed_engine()
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
dist.get_rank(), checkpoint_name))

state_dict = torch.load(model_checkpoint_name, map_location='cpu')
ret_state_dict = state_dict['model']
@@ -675,7 +680,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,

assert len(model) == 1
model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier()
dist.barrier()

if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
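The checkpointing changes reuse two guards throughout: call collectives only when the engine reports it is initialized, and let only rank 0 (or data-parallel rank 0) touch the filesystem. Here is a compact sketch of the read_metadata-style reduction above; illustrative only, with max_iteration_across_ranks as a made-up helper name, and it assumes a CUDA device plus an initialized MCR-DL engine.

import torch
import mcr_dl

def max_iteration_across_ranks(iteration):
    # Agree on the newest checkpoint iteration by taking a MAX
    # all-reduce across ranks, mirroring read_metadata() above.
    dist = mcr_dl.get_distributed_engine()
    if not dist.is_initialized():
        return iteration          # single-process utilities skip collectives
    iters_cuda = torch.cuda.LongTensor([iteration])
    dist.all_reduce(iters_cuda, op=dist.ReduceOp.MAX)
    return iters_cuda[0].item()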
