diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 4e6e745cf6b..7dd8ad0d564 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -33,6 +33,7 @@ import os import subprocess + def model_provider(pre_process=True, post_process=True): """Build the model.""" @@ -41,9 +42,10 @@ def model_provider(pre_process=True, post_process=True): args = get_args() with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), - remote_device=None if args.remote_device=='none' else args.remote_device, - config=args.deepspeed_config, - enabled=args.zero_stage==3): + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): if args.deepspeed: model = GPTModelPipe( num_tokentypes=0, @@ -59,14 +61,14 @@ def model_provider(pre_process=True, post_process=True): attention_mask = torch.tril(torch.ones( (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view( 1, 1, args.seq_length, args.seq_length) - + # Convert attention mask to binary: attention_mask = (attention_mask < 0.5) if args.fp16: attention_mask = attention_mask.half() elif args.bf16: attention_mask = attention_mask.bfloat16() - + args.attn_mask = attention_mask else: @@ -111,6 +113,7 @@ def get_batch(data_iterator): return tokens, labels, loss_mask, attention_mask, position_ids + def get_batch_pipe(data): """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" args = get_args() @@ -138,6 +141,7 @@ def get_batch_pipe(data): return (tokens, position_ids, attention_mask), (labels, loss_mask) + def loss_func(loss_mask, output_tensor): losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() @@ -184,10 +188,12 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): return train_ds, valid_ds, test_ds + def command_exists(cmd): result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) return result.wait() == 0 + def git_ds_info(): from deepspeed.env_report import main as ds_report ds_report()