Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rampup batch size support for Megatron GPT #6424

Merged
merged 18 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ model:
# gradient accumulation will be done automatically based on data_parallel_size
micro_batch_size: 4 # limited by GPU memory
global_batch_size: 8 # will use more micro batches to reach global batch size
rampup_batch_size: null # Should be a list of 3 values: [<start_batch_size>, <batch_size_increment>, <rampup_samples>]
tensor_model_parallel_size: 1 # intra-layer model parallelism
pipeline_model_parallel_size: 1 # inter-layer model parallelism
virtual_pipeline_model_parallel_size: null # interleaved pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0),
micro_batch_size=cfg.get('micro_batch_size'),
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size'),
use_fp8=cfg.get('fp8', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from nemo.utils import logging

try:
import apex.transformer.pipeline_parallel.utils
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

HAVE_APEX = True
Expand Down Expand Up @@ -427,15 +428,23 @@ def training_step(self, dataloader_iter, batch_idx):
'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1,
)

consumed_samples = self.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
# TODO: make sure compute_consumed_samples works for pipeline parallelism
self.log(
'consumed_samples',
self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
prog_bar=True,
rank_zero_only=True,
batch_size=1,
'consumed_samples', consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1,
)

if self.cfg.get('rampup_batch_size', None):
micro_batch_size = self.cfg.get('micro_batch_size', 1)
total_gpus_number = self.trainer.num_devices * self.trainer.num_nodes
current_global_batch_size = get_num_microbatches() * micro_batch_size * total_gpus_number
self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1)

num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
num_microbatch_calculator.update(
consumed_samples=consumed_samples, consistency_check=True,
)

return loss_mean

def backward(self, *args, **kwargs):
Expand Down Expand Up @@ -815,6 +824,29 @@ def setup(self, stage=None):
self.init_consumed_samples = init_consumed_samples
self.init_global_step = self.trainer.global_step

rampup_batch_size = self.cfg.get('rampup_batch_size', None)
if rampup_batch_size:
start_batch_size = rampup_batch_size[0]
batch_size_increment = rampup_batch_size[1]
total_gpus_number = self.trainer.num_devices * self.trainer.num_nodes

assert start_batch_size % (total_gpus_number) == 0, (
'expected'
' start batch size ({}) to be divisible by total number of GPUs'
' ({})'.format(start_batch_size, total_gpus_number)
)

micro_batch_size = self.cfg.get('micro_batch_size', 1)
tensor_model_parallel_size = self.cfg.get('tensor_model_parallel_size', 1)
pipeline_model_parallel_size = self.cfg.get('pipeline_model_parallel_size', 1)
total_data_parallel_size = total_gpus_number // (tensor_model_parallel_size * pipeline_model_parallel_size)

assert batch_size_increment % (micro_batch_size * total_data_parallel_size) == 0, (
'expected'
' batch size increment ({}) to be divisible by micro_batch_size ({}) times total data parallel size'
' ({})'.format(batch_size_increment, micro_batch_size, total_data_parallel_size)
)

if stage == 'predict':
return
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def initialize_model_parallel_for_nemo(
pipeline_model_parallel_split_rank=None,
micro_batch_size=None,
global_batch_size=None,
rampup_batch_size=None,
use_fp8=False,
seed=1234,
apex_transformer_log_level=30,
Expand Down Expand Up @@ -121,7 +122,7 @@ def initialize_model_parallel_for_nemo(
global_batch_size=global_batch_size,
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=None,
rampup_batch_size=rampup_batch_size,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatches):
Expand Down
195 changes: 195 additions & 0 deletions tests/collections/nlp/test_rampup_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
from omegaconf import DictConfig
from pytorch_lightning import Trainer


from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy

try:
import apex.transformer.pipeline_parallel.utils
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

HAVE_APEX = True

Check notice

Code scanning / CodeQL

Unused global variable

The global variable 'HAVE_APEX' is not used.

except (ImportError, ModuleNotFoundError):

HAVE_APEX = False

Check notice

Code scanning / CodeQL

Unused global variable

The global variable 'HAVE_APEX' is not used.

DEVICE_CAPABILITY = None
if torch.cuda.is_available():
DEVICE_CAPABILITY = torch.cuda.get_device_capability()


def reset_microbatch_calculator():
apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


@pytest.fixture()
def model_cfg(test_data_dir):

model_cfg = {
'precision': 16,
'micro_batch_size': 4,
'global_batch_size': 16,
'rampup_batch_size': [4, 4, 100],
'tensor_model_parallel_size': 1,
'pipeline_model_parallel_size': 1,
'resume_from_checkpoint': None,
'encoder_seq_length': 512,
'max_position_embeddings': 512,
'num_layers': 1,
'hidden_size': 128,
'ffn_hidden_size': 512,
'num_attention_heads': 2,
'init_method_std': 0.02,
'hidden_dropout': 0.1,
'kv_channels': None,
'apply_query_key_layer_scaling': True,
'layernorm_epsilon': 1e-5,
'make_vocab_size_divisible_by': 128,
'pre_process': True,
'post_process': True,
'persist_layer_norm': True,
'gradient_as_bucket_view': True,
'tokenizer': {
'library': 'megatron',
'type': 'GPT2BPETokenizer',
'model': None,
'vocab_file': os.path.join(test_data_dir, 'nlp/gpt_vocab_merges/vocab.json'),
'merge_file': os.path.join(test_data_dir, 'nlp/gpt_vocab_merges/merges.txt'),
'delimiter': None,
},
'native_amp_init_scale': 4294967296,
'native_amp_growth_interval': 1000,
'hysteresis': 2,
'fp32_residual_connection': False,
'fp16_lm_cross_entropy': False,
'megatron_amp_O2': False,
'seed': 1234,
'use_cpu_initialization': False,
'onnx_safe': False,
'apex_transformer_log_level': 30,
'activations_checkpoint_method': None,
'activations_checkpoint_num_layers': 1,
'data': {
'data_prefix': '???',
'index_mapping_dir': None,
'data_impl': 'mmap',
'splits_string': '900,50,50',
'seq_length': 512,
'skip_warmup': True,
'num_workers': 2,
'dataloader_type': 'single',
'reset_position_ids': False,
'reset_attention_mask': False,
'eod_mask_loss': False,
},
'optim': {
'name': 'fused_adam',
'lr': 2e-4,
'weight_decay': 0.01,
'betas': [0.9, 0.98],
'sched': {'name': 'CosineAnnealing', 'warmup_steps': 500, 'constant_steps': 50000, 'min_lr': '2e-5'},
},
}

return model_cfg


@pytest.fixture()
def trainer_cfg():

trainer_cfg = {
'devices': 1,
'num_nodes': 1,
'accelerator': 'gpu',
'precision': 16,
'logger': False,
'enable_checkpointing': False,
'replace_sampler_ddp': False,
'max_epochs': 1,
'max_steps': 150,
'log_every_n_steps': 10,
'val_check_interval': 100,
'limit_val_batches': 50,
'limit_test_batches': 500,
'accumulate_grad_batches': 1,
'gradient_clip_val': 1.0,
}

return trainer_cfg


@pytest.fixture()
def gpt_model(model_cfg, trainer_cfg):

strategy = NLPDDPStrategy()
trainer = Trainer(strategy=strategy, **trainer_cfg)
cfg = DictConfig(model_cfg)

reset_microbatch_calculator()
model = MegatronGPTModel(cfg, trainer)

return model


@pytest.fixture()
def rampup_batch_size():

return [4, 4, 100]


@pytest.fixture()
def rampup_batch_size_schedule():

return [4, 8, 12, 16]


@pytest.mark.run_only_on('GPU')
class TestRampupBatchSize:
@pytest.mark.unit
def test_rampup_bs(self, gpt_model, rampup_batch_size):

assert gpt_model.cfg.rampup_batch_size == rampup_batch_size

@pytest.mark.unit
def test_rampup_bs_schedule(self, gpt_model, trainer_cfg, rampup_batch_size_schedule):

num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
micro_batch_size = gpt_model.cfg.micro_batch_size
num_devices = trainer_cfg["devices"]
num_nodes = trainer_cfg["num_nodes"]
max_steps = trainer_cfg["max_steps"]

global_batch_size_schedule = []
step, consumed_samples = 0, 0
while step <= max_steps:
step += 1
current_global_batch_size = get_num_microbatches() * micro_batch_size * num_devices * num_nodes
consumed_samples += current_global_batch_size
num_microbatch_calculator.update(consumed_samples=consumed_samples, consistency_check=True)

if current_global_batch_size not in global_batch_size_schedule:
global_batch_size_schedule.append(current_global_batch_size)

reset_microbatch_calculator()

assert global_batch_size_schedule == rampup_batch_size_schedule