add rampup batch size support for Megatron GPT (NVIDIA#6424)
* added rampup batch size support

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* added tests for rampup batch size

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* fixed the typos

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* added assertions

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* changed assertion rules

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* deleted unused imports

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* changed tests for rampup batch size

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* updated rampup batch size tests

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed styling

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

* rampup batch size tests changes

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>

---------

Signed-off-by: Dmytro Pykhtar <dpykhtar@nvidia.com>
Signed-off-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <dpykhtar@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
4 people authored and hsiehjackson committed Jun 2, 2023
1 parent 639aa5a commit 4a743a1
Showing 5 changed files with 236 additions and 6 deletions.
@@ -45,6 +45,7 @@ model:
# gradient accumulation will be done automatically based on data_parallel_size
micro_batch_size: 4 # limited by GPU memory
global_batch_size: 8 # will use more micro batches to reach global batch size
rampup_batch_size: null # Should be a list of 3 values: [<start_batch_size>, <batch_size_increment>, <rampup_samples>]
tensor_model_parallel_size: 1 # intra-layer model parallelism
pipeline_model_parallel_size: 1 # inter-layer model parallelism
virtual_pipeline_model_parallel_size: null # interleaved pipeline
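The new rampup_batch_size entry above takes three values: the global batch size to start training with, the increment added at each ramp step, and the number of consumed samples over which the batch size ramps up to the full global_batch_size. A minimal sketch of plausible settings (the concrete numbers mirror the test fixture added in this commit; everything else is illustrative only):

# Illustrative values only; rampup_batch_size = [start_batch_size, batch_size_increment, rampup_samples].
batch_size_cfg = {
    "micro_batch_size": 4,             # per-GPU micro batch size
    "global_batch_size": 16,           # final global batch size once the ramp finishes
    "rampup_batch_size": [4, 4, 100],  # start at 4, grow by 4 per step, ramp over 100 consumed samples
}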
@@ -120,6 +120,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0),
micro_batch_size=cfg.get('micro_batch_size'),
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size'),
use_fp8=cfg.get('fp8', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
@@ -55,6 +55,7 @@
from nemo.utils import logging

try:
import apex.transformer.pipeline_parallel.utils
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

HAVE_APEX = True
@@ -427,15 +428,23 @@ def training_step(self, dataloader_iter, batch_idx):
'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1,
)

consumed_samples = self.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
# TODO: make sure compute_consumed_samples works for pipeline parallelism
self.log(
'consumed_samples',
self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
prog_bar=True,
rank_zero_only=True,
batch_size=1,
'consumed_samples', consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1,
)

if self.cfg.get('rampup_batch_size', None):
micro_batch_size = self.cfg.get('micro_batch_size', 1)
total_gpus_number = self.trainer.num_devices * self.trainer.num_nodes
current_global_batch_size = get_num_microbatches() * micro_batch_size * total_gpus_number
self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1)

num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
num_microbatch_calculator.update(
consumed_samples=consumed_samples, consistency_check=True,
)

return loss_mean

def backward(self, *args, **kwargs):
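The global_batch_size value logged above is recomputed each step from the current number of microbatches; a rough arithmetic sketch with hypothetical values (single node, tensor and pipeline parallel sizes of 1 assumed, so every GPU holds a data-parallel replica):

# Hypothetical numbers only, mirroring the logging arithmetic in training_step above.
micro_batch_size = 4
num_devices, num_nodes = 8, 1
num_microbatches = 2                         # an example value get_num_microbatches() could return mid-ramp
total_gpus_number = num_devices * num_nodes
current_global_batch_size = num_microbatches * micro_batch_size * total_gpus_number
print(current_global_batch_size)             # 64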
@@ -815,6 +824,29 @@ def setup(self, stage=None):
self.init_consumed_samples = init_consumed_samples
self.init_global_step = self.trainer.global_step

rampup_batch_size = self.cfg.get('rampup_batch_size', None)
if rampup_batch_size:
start_batch_size = rampup_batch_size[0]
batch_size_increment = rampup_batch_size[1]
total_gpus_number = self.trainer.num_devices * self.trainer.num_nodes

assert start_batch_size % (total_gpus_number) == 0, (
'expected'
' start batch size ({}) to be divisible by total number of GPUs'
' ({})'.format(start_batch_size, total_gpus_number)
)

micro_batch_size = self.cfg.get('micro_batch_size', 1)
tensor_model_parallel_size = self.cfg.get('tensor_model_parallel_size', 1)
pipeline_model_parallel_size = self.cfg.get('pipeline_model_parallel_size', 1)
total_data_parallel_size = total_gpus_number // (tensor_model_parallel_size * pipeline_model_parallel_size)

assert batch_size_increment % (micro_batch_size * total_data_parallel_size) == 0, (
'expected'
' batch size increment ({}) to be divisible by micro_batch_size ({}) times total data parallel size'
' ({})'.format(batch_size_increment, micro_batch_size, total_data_parallel_size)
)

if stage == 'predict':
return
else:
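As a worked example of the two assertions added to setup() (hypothetical cluster shape, not taken from the diff): with 2 nodes of 8 GPUs each, tensor parallel size 2 and pipeline parallel size 2, the data-parallel size is 16 // 4 = 4, so start_batch_size must divide evenly by 16 and batch_size_increment by micro_batch_size * 4:

# Hypothetical values walking through the rampup_batch_size checks above.
num_devices, num_nodes = 8, 2
micro_batch_size = 2
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 2

total_gpus_number = num_devices * num_nodes                 # 16
total_data_parallel_size = total_gpus_number // (
    tensor_model_parallel_size * pipeline_model_parallel_size
)                                                           # 4

start_batch_size, batch_size_increment = 16, 8
assert start_batch_size % total_gpus_number == 0                                   # 16 % 16 == 0
assert batch_size_increment % (micro_batch_size * total_data_parallel_size) == 0   # 8 % 8 == 0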
@@ -65,6 +65,7 @@ def initialize_model_parallel_for_nemo(
pipeline_model_parallel_split_rank=None,
micro_batch_size=None,
global_batch_size=None,
rampup_batch_size=None,
use_fp8=False,
seed=1234,
apex_transformer_log_level=30,
@@ -121,7 +122,7 @@ def initialize_model_parallel_for_nemo(
global_batch_size=global_batch_size,
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=None,
rampup_batch_size=rampup_batch_size,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatches):
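Forwarding rampup_batch_size into the microbatch calculator setup is what makes the per-step ramp take effect: the calculator maps the current global batch size to a number of microbatches. A rough sketch of that relationship (the formula is assumed from the Megatron/Apex batch-size bookkeeping; values are hypothetical):

# Hypothetical mid-ramp values; num_microbatches is what get_num_microbatches() would then report.
micro_batch_size = 4
data_parallel_size = 1
current_global_batch_size = 8
num_microbatches = current_global_batch_size // (micro_batch_size * data_parallel_size)
print(num_microbatches)  # 2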
195 changes: 195 additions & 0 deletions tests/collections/nlp/test_rampup_batch_size.py
@@ -0,0 +1,195 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
from omegaconf import DictConfig
from pytorch_lightning import Trainer


from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy

try:
import apex.transformer.pipeline_parallel.utils
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

HAVE_APEX = True

except (ImportError, ModuleNotFoundError):

HAVE_APEX = False

DEVICE_CAPABILITY = None
if torch.cuda.is_available():
DEVICE_CAPABILITY = torch.cuda.get_device_capability()


def reset_microbatch_calculator():
apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


@pytest.fixture()
def model_cfg(test_data_dir):

model_cfg = {
'precision': 16,
'micro_batch_size': 4,
'global_batch_size': 16,
'rampup_batch_size': [4, 4, 100],
'tensor_model_parallel_size': 1,
'pipeline_model_parallel_size': 1,
'resume_from_checkpoint': None,
'encoder_seq_length': 512,
'max_position_embeddings': 512,
'num_layers': 1,
'hidden_size': 128,
'ffn_hidden_size': 512,
'num_attention_heads': 2,
'init_method_std': 0.02,
'hidden_dropout': 0.1,
'kv_channels': None,
'apply_query_key_layer_scaling': True,
'layernorm_epsilon': 1e-5,
'make_vocab_size_divisible_by': 128,
'pre_process': True,
'post_process': True,
'persist_layer_norm': True,
'gradient_as_bucket_view': True,
'tokenizer': {
'library': 'megatron',
'type': 'GPT2BPETokenizer',
'model': None,
'vocab_file': os.path.join(test_data_dir, 'nlp/gpt_vocab_merges/vocab.json'),
'merge_file': os.path.join(test_data_dir, 'nlp/gpt_vocab_merges/merges.txt'),
'delimiter': None,
},
'native_amp_init_scale': 4294967296,
'native_amp_growth_interval': 1000,
'hysteresis': 2,
'fp32_residual_connection': False,
'fp16_lm_cross_entropy': False,
'megatron_amp_O2': False,
'seed': 1234,
'use_cpu_initialization': False,
'onnx_safe': False,
'apex_transformer_log_level': 30,
'activations_checkpoint_method': None,
'activations_checkpoint_num_layers': 1,
'data': {
'data_prefix': '???',
'index_mapping_dir': None,
'data_impl': 'mmap',
'splits_string': '900,50,50',
'seq_length': 512,
'skip_warmup': True,
'num_workers': 2,
'dataloader_type': 'single',
'reset_position_ids': False,
'reset_attention_mask': False,
'eod_mask_loss': False,
},
'optim': {
'name': 'fused_adam',
'lr': 2e-4,
'weight_decay': 0.01,
'betas': [0.9, 0.98],
'sched': {'name': 'CosineAnnealing', 'warmup_steps': 500, 'constant_steps': 50000, 'min_lr': '2e-5'},
},
}

return model_cfg


@pytest.fixture()
def trainer_cfg():

trainer_cfg = {
'devices': 1,
'num_nodes': 1,
'accelerator': 'gpu',
'precision': 16,
'logger': False,
'enable_checkpointing': False,
'replace_sampler_ddp': False,
'max_epochs': 1,
'max_steps': 150,
'log_every_n_steps': 10,
'val_check_interval': 100,
'limit_val_batches': 50,
'limit_test_batches': 500,
'accumulate_grad_batches': 1,
'gradient_clip_val': 1.0,
}

return trainer_cfg


@pytest.fixture()
def gpt_model(model_cfg, trainer_cfg):

strategy = NLPDDPStrategy()
trainer = Trainer(strategy=strategy, **trainer_cfg)
cfg = DictConfig(model_cfg)

reset_microbatch_calculator()
model = MegatronGPTModel(cfg, trainer)

return model


@pytest.fixture()
def rampup_batch_size():

return [4, 4, 100]


@pytest.fixture()
def rampup_batch_size_schedule():

return [4, 8, 12, 16]


@pytest.mark.run_only_on('GPU')
class TestRampupBatchSize:
@pytest.mark.unit
def test_rampup_bs(self, gpt_model, rampup_batch_size):

assert gpt_model.cfg.rampup_batch_size == rampup_batch_size

@pytest.mark.unit
def test_rampup_bs_schedule(self, gpt_model, trainer_cfg, rampup_batch_size_schedule):

num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
micro_batch_size = gpt_model.cfg.micro_batch_size
num_devices = trainer_cfg["devices"]
num_nodes = trainer_cfg["num_nodes"]
max_steps = trainer_cfg["max_steps"]

global_batch_size_schedule = []
step, consumed_samples = 0, 0
while step <= max_steps:
step += 1
current_global_batch_size = get_num_microbatches() * micro_batch_size * num_devices * num_nodes
consumed_samples += current_global_batch_size
num_microbatch_calculator.update(consumed_samples=consumed_samples, consistency_check=True)

if current_global_batch_size not in global_batch_size_schedule:
global_batch_size_schedule.append(current_global_batch_size)

reset_microbatch_calculator()

assert global_batch_size_schedule == rampup_batch_size_schedule
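For reference, the expected schedule [4, 8, 12, 16] follows from rampup_batch_size = [4, 4, 100] with global_batch_size = 16 on a single GPU: three increments of 4 are needed, so the batch size steps up roughly every 100 / 3 ≈ 33 consumed samples and settles at 16 once 100 samples have been consumed. A standalone sketch of that arithmetic (an approximation of the behaviour the test exercises, not the Apex calculator itself):

# Standalone arithmetic mirroring the ramp the test expects.
start_batch_size, batch_size_increment, rampup_samples = 4, 4, 100
global_batch_size = 16

num_increments = (global_batch_size - start_batch_size) // batch_size_increment  # 3
samples_per_increment = rampup_samples / num_increments                          # ~33.3

def batch_size_at(consumed_samples):
    if consumed_samples > rampup_samples:
        return global_batch_size
    steps = int(consumed_samples // samples_per_increment)
    return min(start_batch_size + steps * batch_size_increment, global_batch_size)

print([batch_size_at(s) for s in (0, 40, 80, 120)])  # [4, 8, 12, 16]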
