From 158d10b4b9d530091918ee8c5d0d87e66fa3cd08 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Thu, 13 Apr 2023 07:55:00 -0700 Subject: [PATCH 01/11] added rampup batch size support Signed-off-by: Dmytro Pykhtar --- .../conf/megatron_gpt_config.yaml | 1 + .../language_modeling/megatron_base_model.py | 1 + .../language_modeling/megatron_gpt_model.py | 19 ++++++++++++++----- .../modules/common/megatron/megatron_init.py | 3 ++- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 27cb3af3ce91..0ab0feb2c114 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -45,6 +45,7 @@ model: # gradient accumulation will be done automatically based on data_parallel_size micro_batch_size: 4 # limited by GPU memory global_batch_size: 8 # will use more micro batches to reach global batch size + rampup_batch_size: null # Should be a list of 3 values: [, , ] tensor_model_parallel_size: 1 # intra-layer model parallelism pipeline_model_parallel_size: 1 # inter-layer model parallelism virtual_pipeline_model_parallel_size: null # interleaved pipeline diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 97b500ce30bf..7d9058e764aa 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -107,6 +107,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0), micro_batch_size=cfg.get('micro_batch_size'), global_batch_size=cfg.get('global_batch_size'), + rampup_batch_size=cfg.get('rampup_batch_size'), use_fp8=cfg.get('fp8', False), seed=self.cfg.get('seed', 1234), apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30), diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 19dc724c4043..42650f541320 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -54,6 +54,7 @@ from nemo.utils import logging try: + import apex.transformer.pipeline_parallel.utils from apex.transformer import parallel_state from apex.transformer.pipeline_parallel.schedules.common import build_model from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining @@ -63,6 +64,7 @@ from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( forward_backward_pipelining_without_interleaving, ) + from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True except (ImportError, ModuleNotFoundError): @@ -446,15 +448,22 @@ def training_step(self, dataloader_iter, batch_idx): 'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1, ) + consumed_samples = self.compute_consumed_samples(self.trainer.global_step - self.init_global_step) # TODO: make sure compute_consumed_samples works for pipeline parallelism self.log( - 'consumed_samples', - self.compute_consumed_samples(self.trainer.global_step - self.init_global_step), - prog_bar=True, - rank_zero_only=True, 
- batch_size=1, + 'consumed_samples', consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1, ) + if self.cfg.get('rampup_batch_size', None): + micro_batch_size = self.cfg.get('micro_batch_size', 1) + current_global_batch_size = get_num_microbatches() * micro_batch_size + self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1) + + num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR + num_microbatch_calculator.update( + consumed_samples=consumed_samples, consistency_check=True, + ) + return loss_mean def backward(self, *args, **kwargs): diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index e7e3aeec31da..722f5f45b00e 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -59,6 +59,7 @@ def initialize_model_parallel_for_nemo( pipeline_model_parallel_split_rank=None, micro_batch_size=None, global_batch_size=None, + rampup_batch_size=None, use_fp8=False, seed=1234, apex_transformer_log_level=30, @@ -115,7 +116,7 @@ def initialize_model_parallel_for_nemo( global_batch_size=global_batch_size, micro_batch_size=micro_batch_size, data_parallel_size=app_state.data_parallel_size, - rampup_batch_size=None, + rampup_batch_size=rampup_batch_size, ) else: if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatches): From c4f3f6e1e9cfe766ad943e12f67b3008e3044116 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Fri, 14 Apr 2023 05:05:48 -0700 Subject: [PATCH 02/11] added tests for rampup batch size Signed-off-by: Dmytro Pykhtar --- .../collections/nlp/test_rampup_bath_size.py | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 tests/collections/nlp/test_rampup_bath_size.py diff --git a/tests/collections/nlp/test_rampup_bath_size.py b/tests/collections/nlp/test_rampup_bath_size.py new file mode 100644 index 000000000000..d4de2760832e --- /dev/null +++ b/tests/collections/nlp/test_rampup_bath_size.py @@ -0,0 +1,158 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import pytest +import torch +from omegaconf import DictConfig +from pytorch_lightning import Trainer + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy + +DEVICE_CAPABILITY = None +if torch.cuda.is_available(): + DEVICE_CAPABILITY = torch.cuda.get_device_capability() + + +@pytest.fixture() +def model_cfg(test_data_dir): + + model_cfg = { + 'precision': 16, + 'micro_batch_size': 4, + 'global_batch_size': 8, + 'rampup_batch_size': [4, 2, 100], + 'tensor_model_parallel_size': 1, + 'pipeline_model_parallel_size': 1, + 'resume_from_checkpoint': None, + 'encoder_seq_length': 512, + 'max_position_embeddings': 512, + 'num_layers': 1, + 'hidden_size': 128, + 'ffn_hidden_size': 512, + 'num_attention_heads': 2, + 'init_method_std': 0.02, + 'hidden_dropout': 0.1, + 'kv_channels': None, + 'apply_query_key_layer_scaling': True, + 'layernorm_epsilon': 1e-5, + 'make_vocab_size_divisible_by': 128, + 'pre_process': True, + 'post_process': True, + 'persist_layer_norm': True, + 'gradient_as_bucket_view': True, + 'tokenizer': { + 'library': 'megatron', + 'type': 'GPT2BPETokenizer', + 'model': None, + 'vocab_file': os.path.join(test_data_dir, 'nlp/gpt_vocab_merges/vocab.json'), + 'merge_file': os.path.join(test_data_dir, 'nlp/gpt_vocab_merges/merges.txt'), + 'delimiter': None, + }, + 'native_amp_init_scale': 4294967296, + 'native_amp_growth_interval': 1000, + 'hysteresis': 2, + 'fp32_residual_connection': False, + 'fp16_lm_cross_entropy': False, + 'megatron_amp_O2': False, + 'seed': 1234, + 'use_cpu_initialization': False, + 'onnx_safe': False, + 'apex_transformer_log_level': 30, + 'activations_checkpoint_method': None, + 'activations_checkpoint_num_layers': 1, + 'data': { + 'data_prefix': '???', + 'index_mapping_dir': None, + 'data_impl': 'mmap', + 'splits_string': '900,50,50', + 'seq_length': 512, + 'skip_warmup': True, + 'num_workers': 2, + 'dataloader_type': 'single', + 'reset_position_ids': False, + 'reset_attention_mask': False, + 'eod_mask_loss': False, + }, + 'optim': { + 'name': 'fused_adam', + 'lr': 2e-4, + 'weight_decay': 0.01, + 'betas': [0.9, 0.98], + 'sched': {'name': 'CosineAnnealing', 'warmup_steps': 500, 'constant_steps': 50000, 'min_lr': '2e-5'}, + }, + } + return model_cfg + + +@pytest.fixture() +def trainer_cfg(): + + trainer_cfg = { + 'devices': 1, + 'num_nodes': 1, + 'accelerator': 'gpu', + 'precision': 16, + 'logger': False, + 'enable_checkpointing': False, + 'replace_sampler_ddp': False, + 'max_epochs': 1000, + 'max_steps': 100000, + 'log_every_n_steps': 10, + 'val_check_interval': 100, + 'limit_val_batches': 50, + 'limit_test_batches': 500, + 'accumulate_grad_batches': 1, + 'gradient_clip_val': 1.0, + } + + return trainer_cfg + + +@pytest.fixture() +def gpt_model(model_cfg, trainer_cfg): + strategy = NLPDDPStrategy() + trainer = Trainer(strategy=strategy, **trainer_cfg) + cfg = DictConfig(model_cfg) + + try: + model = MegatronGPTModel(cfg, trainer) + return model + except Exception as exception: + return exception + + +@pytest.fixture() +def rampup_batch_size(): + return [4, 2, 100] + + +@pytest.fixture() +def rampup_batch_size_invalid(): + return "Microbatch calculator already initialized." 
+ + +@pytest.mark.run_only_on('GPU') +class TestRampupBatchSize: + @pytest.mark.unit + def test_rampup_bs(self, gpt_model, rampup_batch_size): + assert gpt_model.cfg.rampup_batch_size == rampup_batch_size + + @pytest.mark.unit + def test_invalid_rampup_bs(self, gpt_model, rampup_batch_size_invalid): + assert str(gpt_model) == rampup_batch_size_invalid From 67b7cfc16d23b7ae2a7267eebed45cd13f06b266 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Sat, 15 Apr 2023 13:53:31 -0700 Subject: [PATCH 03/11] fixed the typos Signed-off-by: Dmytro Pykhtar --- examples/nlp/language_modeling/conf/megatron_gpt_config.yaml | 2 +- .../nlp/{test_rampup_bath_size.py => test_rampup_batch_size.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/collections/nlp/{test_rampup_bath_size.py => test_rampup_batch_size.py} (100%) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 0ab0feb2c114..09b30c08dd47 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -45,7 +45,7 @@ model: # gradient accumulation will be done automatically based on data_parallel_size micro_batch_size: 4 # limited by GPU memory global_batch_size: 8 # will use more micro batches to reach global batch size - rampup_batch_size: null # Should be a list of 3 values: [, , ] + rampup_batch_size: null # Should be a list of 3 values: [, , ] tensor_model_parallel_size: 1 # intra-layer model parallelism pipeline_model_parallel_size: 1 # inter-layer model parallelism virtual_pipeline_model_parallel_size: null # interleaved pipeline diff --git a/tests/collections/nlp/test_rampup_bath_size.py b/tests/collections/nlp/test_rampup_batch_size.py similarity index 100% rename from tests/collections/nlp/test_rampup_bath_size.py rename to tests/collections/nlp/test_rampup_batch_size.py From b6a306f144546dcaa49e854265dee57b946d08f2 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Sat, 15 Apr 2023 15:39:03 -0700 Subject: [PATCH 04/11] added assertions Signed-off-by: Dmytro Pykhtar --- .../language_modeling/megatron_gpt_model.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 42650f541320..f39172f0bde4 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -456,7 +456,8 @@ def training_step(self, dataloader_iter, batch_idx): if self.cfg.get('rampup_batch_size', None): micro_batch_size = self.cfg.get('micro_batch_size', 1) - current_global_batch_size = get_num_microbatches() * micro_batch_size + total_gpus_number = self.trainer.num_devices * self.trainer.num_nodes + current_global_batch_size = get_num_microbatches() * micro_batch_size * total_gpus_number self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1) num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR @@ -840,6 +841,28 @@ def setup(self, stage=None): self.init_consumed_samples = init_consumed_samples self.init_global_step = self.trainer.global_step + rampup_batch_size = self.cfg.get('rampup_batch_size', None) + if rampup_batch_size: + start_batch_size = rampup_batch_size[0] + batch_size_increment = rampup_batch_size[1] + 
total_gpus_number = self.trainer.num_devices * self.trainer.num_nodes + + assert start_batch_size % (total_gpus_number) == 0, ( + 'expected' + ' start batch size ({}) to be divisible by total number of GPUs' + ' ({})'.format(start_batch_size, total_gpus_number) + ) + + tensor_model_parallel_size = self.cfg.get('tensor_model_parallel_size', 1) + pipeline_model_parallel_size = self.cfg.get('pipeline_model_parallel_size', 1) + total_data_parallel_size = total_gpus_number // (tensor_model_parallel_size * pipeline_model_parallel_size) + + assert batch_size_increment % (total_data_parallel_size) == 0, ( + 'expected' + ' batch size increment ({}) to be divisible by total data parallel size' + ' ({})'.format(batch_size_increment, total_data_parallel_size) + ) + if stage == 'predict': return else: From 81e23e6e21bf0b49b2991f8031980ba5a88778eb Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Tue, 18 Apr 2023 04:54:22 -0700 Subject: [PATCH 05/11] changed assertion rules Signed-off-by: Dmytro Pykhtar --- .../nlp/models/language_modeling/megatron_gpt_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index f39172f0bde4..65ed78353550 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -853,14 +853,15 @@ def setup(self, stage=None): ' ({})'.format(start_batch_size, total_gpus_number) ) + micro_batch_size = self.cfg.get('micro_batch_size', 1) tensor_model_parallel_size = self.cfg.get('tensor_model_parallel_size', 1) pipeline_model_parallel_size = self.cfg.get('pipeline_model_parallel_size', 1) total_data_parallel_size = total_gpus_number // (tensor_model_parallel_size * pipeline_model_parallel_size) - assert batch_size_increment % (total_data_parallel_size) == 0, ( + assert batch_size_increment % (micro_batch_size * total_data_parallel_size) == 0, ( 'expected' - ' batch size increment ({}) to be divisible by total data parallel size' - ' ({})'.format(batch_size_increment, total_data_parallel_size) + ' batch size increment ({}) to be divisible by micro_batch_size ({}) times total data parallel size' + ' ({})'.format(batch_size_increment, micro_batch_size, total_data_parallel_size) ) if stage == 'predict': From 147208337fe7c158cb01413abc51e02bad2317b9 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Tue, 18 Apr 2023 05:44:40 -0700 Subject: [PATCH 06/11] deleted unused imports Signed-off-by: Dmytro Pykhtar --- tests/collections/nlp/test_rampup_batch_size.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/collections/nlp/test_rampup_batch_size.py b/tests/collections/nlp/test_rampup_batch_size.py index d4de2760832e..f9e33e47a599 100644 --- a/tests/collections/nlp/test_rampup_batch_size.py +++ b/tests/collections/nlp/test_rampup_batch_size.py @@ -19,9 +19,7 @@ from omegaconf import DictConfig from pytorch_lightning import Trainer -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel -from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy DEVICE_CAPABILITY = None From e8b05de61e2185e6f7d0fd5c1ea3450b01b46a86 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Tue, 18 Apr 2023 06:17:52 -0700 
Subject: [PATCH 07/11] changed tests for rampup batch size Signed-off-by: Dmytro Pykhtar --- .../collections/nlp/test_rampup_batch_size.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/tests/collections/nlp/test_rampup_batch_size.py b/tests/collections/nlp/test_rampup_batch_size.py index f9e33e47a599..002190bd5402 100644 --- a/tests/collections/nlp/test_rampup_batch_size.py +++ b/tests/collections/nlp/test_rampup_batch_size.py @@ -95,6 +95,7 @@ def model_cfg(test_data_dir): 'sched': {'name': 'CosineAnnealing', 'warmup_steps': 500, 'constant_steps': 50000, 'min_lr': '2e-5'}, }, } + return model_cfg @@ -128,21 +129,15 @@ def gpt_model(model_cfg, trainer_cfg): trainer = Trainer(strategy=strategy, **trainer_cfg) cfg = DictConfig(model_cfg) - try: - model = MegatronGPTModel(cfg, trainer) - return model - except Exception as exception: - return exception + model = MegatronGPTModel(cfg, trainer) + + return model @pytest.fixture() def rampup_batch_size(): - return [4, 2, 100] - -@pytest.fixture() -def rampup_batch_size_invalid(): - return "Microbatch calculator already initialized." + return [4, 2, 100] @pytest.mark.run_only_on('GPU') @@ -150,7 +145,3 @@ class TestRampupBatchSize: @pytest.mark.unit def test_rampup_bs(self, gpt_model, rampup_batch_size): assert gpt_model.cfg.rampup_batch_size == rampup_batch_size - - @pytest.mark.unit - def test_invalid_rampup_bs(self, gpt_model, rampup_batch_size_invalid): - assert str(gpt_model) == rampup_batch_size_invalid From 9b8237f541282f7921ac25ea10c29ba10f726e57 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Tue, 18 Apr 2023 10:41:27 -0700 Subject: [PATCH 08/11] updated rampup batch size tests Signed-off-by: Dmytro Pykhtar --- .../collections/nlp/test_rampup_batch_size.py | 49 ++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/tests/collections/nlp/test_rampup_batch_size.py b/tests/collections/nlp/test_rampup_batch_size.py index 002190bd5402..c18be6768b4a 100644 --- a/tests/collections/nlp/test_rampup_batch_size.py +++ b/tests/collections/nlp/test_rampup_batch_size.py @@ -19,9 +19,20 @@ from omegaconf import DictConfig from pytorch_lightning import Trainer + from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +try: + import apex.transformer.pipeline_parallel.utils + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + DEVICE_CAPABILITY = None if torch.cuda.is_available(): DEVICE_CAPABILITY = torch.cuda.get_device_capability() @@ -33,8 +44,8 @@ def model_cfg(test_data_dir): model_cfg = { 'precision': 16, 'micro_batch_size': 4, - 'global_batch_size': 8, - 'rampup_batch_size': [4, 2, 100], + 'global_batch_size': 16, + 'rampup_batch_size': [4, 4, 100], 'tensor_model_parallel_size': 1, 'pipeline_model_parallel_size': 1, 'resume_from_checkpoint': None, @@ -110,8 +121,8 @@ def trainer_cfg(): 'logger': False, 'enable_checkpointing': False, 'replace_sampler_ddp': False, - 'max_epochs': 1000, - 'max_steps': 100000, + 'max_epochs': 1, + 'max_steps': 150, 'log_every_n_steps': 10, 'val_check_interval': 100, 'limit_val_batches': 50, @@ -125,6 +136,7 @@ def trainer_cfg(): @pytest.fixture() def gpt_model(model_cfg, trainer_cfg): + strategy = NLPDDPStrategy() trainer = Trainer(strategy=strategy, **trainer_cfg) cfg = DictConfig(model_cfg) @@ -137,11 +149,36 @@ def 
gpt_model(model_cfg, trainer_cfg): @pytest.fixture() def rampup_batch_size(): - return [4, 2, 100] + return [4, 4, 100] + +@pytest.fixture() +def rampup_batch_size_schedule(): + + return [4, 8, 12, 16] @pytest.mark.run_only_on('GPU') class TestRampupBatchSize: @pytest.mark.unit - def test_rampup_bs(self, gpt_model, rampup_batch_size): + def test_rampup_bs(self, gpt_model, trainer_cfg, rampup_batch_size, rampup_batch_size_schedule): + + num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR + micro_batch_size = gpt_model.cfg.micro_batch_size + num_devices = trainer_cfg["devices"] + num_nodes = trainer_cfg["num_nodes"] + max_steps = trainer_cfg["max_steps"] + + global_batch_size_schedule = [] + step, consumed_samples = 0, 0 + while step <= max_steps: + step += 1 + current_global_batch_size = get_num_microbatches() * micro_batch_size * num_devices * num_nodes + consumed_samples += current_global_batch_size + num_microbatch_calculator.update(consumed_samples=consumed_samples, consistency_check=True) + + if current_global_batch_size not in global_batch_size_schedule: + global_batch_size_schedule.append(current_global_batch_size) + assert gpt_model.cfg.rampup_batch_size == rampup_batch_size + assert global_batch_size_schedule == rampup_batch_size_schedule + From 94a53893b62b1de329d06fd726bc6586d4d12e31 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Apr 2023 17:42:44 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/collections/nlp/test_rampup_batch_size.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/collections/nlp/test_rampup_batch_size.py b/tests/collections/nlp/test_rampup_batch_size.py index c18be6768b4a..769d85d26820 100644 --- a/tests/collections/nlp/test_rampup_batch_size.py +++ b/tests/collections/nlp/test_rampup_batch_size.py @@ -151,6 +151,7 @@ def rampup_batch_size(): return [4, 4, 100] + @pytest.fixture() def rampup_batch_size_schedule(): @@ -175,10 +176,9 @@ def test_rampup_bs(self, gpt_model, trainer_cfg, rampup_batch_size, rampup_batch current_global_batch_size = get_num_microbatches() * micro_batch_size * num_devices * num_nodes consumed_samples += current_global_batch_size num_microbatch_calculator.update(consumed_samples=consumed_samples, consistency_check=True) - + if current_global_batch_size not in global_batch_size_schedule: global_batch_size_schedule.append(current_global_batch_size) assert gpt_model.cfg.rampup_batch_size == rampup_batch_size assert global_batch_size_schedule == rampup_batch_size_schedule - From 3fffcd19e6e8b2b4ba696eb0247b2190281da62d Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Tue, 18 Apr 2023 10:45:30 -0700 Subject: [PATCH 10/11] fixed styling Signed-off-by: Dmytro Pykhtar --- tests/collections/nlp/test_rampup_batch_size.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/collections/nlp/test_rampup_batch_size.py b/tests/collections/nlp/test_rampup_batch_size.py index c18be6768b4a..769d85d26820 100644 --- a/tests/collections/nlp/test_rampup_batch_size.py +++ b/tests/collections/nlp/test_rampup_batch_size.py @@ -151,6 +151,7 @@ def rampup_batch_size(): return [4, 4, 100] + @pytest.fixture() def rampup_batch_size_schedule(): @@ -175,10 +176,9 @@ def test_rampup_bs(self, gpt_model, trainer_cfg, rampup_batch_size, rampup_batch current_global_batch_size = get_num_microbatches() * 
micro_batch_size * num_devices * num_nodes consumed_samples += current_global_batch_size num_microbatch_calculator.update(consumed_samples=consumed_samples, consistency_check=True) - + if current_global_batch_size not in global_batch_size_schedule: global_batch_size_schedule.append(current_global_batch_size) assert gpt_model.cfg.rampup_batch_size == rampup_batch_size assert global_batch_size_schedule == rampup_batch_size_schedule - From 9bd543a1d5b65a5e397fa7e268d12ace0ecf9ae4 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar Date: Tue, 18 Apr 2023 17:01:51 -0700 Subject: [PATCH 11/11] rampup batch size tests changes Signed-off-by: Dmytro Pykhtar --- tests/collections/nlp/test_rampup_batch_size.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/collections/nlp/test_rampup_batch_size.py b/tests/collections/nlp/test_rampup_batch_size.py index 769d85d26820..86af6bf51e1d 100644 --- a/tests/collections/nlp/test_rampup_batch_size.py +++ b/tests/collections/nlp/test_rampup_batch_size.py @@ -38,6 +38,10 @@ DEVICE_CAPABILITY = torch.cuda.get_device_capability() +def reset_microbatch_calculator(): + apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None + + @pytest.fixture() def model_cfg(test_data_dir): @@ -141,6 +145,7 @@ def gpt_model(model_cfg, trainer_cfg): trainer = Trainer(strategy=strategy, **trainer_cfg) cfg = DictConfig(model_cfg) + reset_microbatch_calculator() model = MegatronGPTModel(cfg, trainer) return model @@ -161,7 +166,12 @@ def rampup_batch_size_schedule(): @pytest.mark.run_only_on('GPU') class TestRampupBatchSize: @pytest.mark.unit - def test_rampup_bs(self, gpt_model, trainer_cfg, rampup_batch_size, rampup_batch_size_schedule): + def test_rampup_bs(self, gpt_model, rampup_batch_size): + + assert gpt_model.cfg.rampup_batch_size == rampup_batch_size + + @pytest.mark.unit + def test_rampup_bs_schedule(self, gpt_model, trainer_cfg, rampup_batch_size_schedule): num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR micro_batch_size = gpt_model.cfg.micro_batch_size @@ -180,5 +190,6 @@ def test_rampup_bs(self, gpt_model, trainer_cfg, rampup_batch_size, rampup_batch if current_global_batch_size not in global_batch_size_schedule: global_batch_size_schedule.append(current_global_batch_size) - assert gpt_model.cfg.rampup_batch_size == rampup_batch_size + reset_microbatch_calculator() + assert global_batch_size_schedule == rampup_batch_size_schedule
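
Note on the ramp-up schedule exercised by the tests above: with rampup_batch_size given as [start_batch_size, batch_size_increment, rampup_samples], the global batch size starts at start_batch_size and grows by batch_size_increment per ramp step until it reaches global_batch_size, with rampup_samples consumed samples spread evenly across the increments. The sketch below is an illustrative approximation of that logic, not the Apex _GLOBAL_NUM_MICROBATCHES_CALCULATOR API; the helper name and the exact boundary handling are assumptions. It reproduces the [4, 8, 12, 16] schedule that test_rampup_bs_schedule expects for rampup_batch_size=[4, 4, 100] and global_batch_size=16.

    # Illustrative sketch, NOT the Apex API: approximates how a
    # rampup_batch_size = [start, increment, rampup_samples] schedule maps
    # consumed samples to the current global batch size.
    def current_global_batch_size(consumed_samples, start, increment, rampup_samples, final_gbs):
        num_increments = (final_gbs - start) // increment        # e.g. (16 - 4) // 4 = 3
        samples_per_increment = rampup_samples / num_increments  # e.g. 100 / 3 samples per step
        if consumed_samples >= rampup_samples:
            return final_gbs
        steps_done = int(consumed_samples / samples_per_increment)
        return start + steps_done * increment

    # With rampup_batch_size = [4, 4, 100] and global_batch_size = 16 (the values
    # used in test_rampup_bs_schedule), the schedule visited during training is:
    schedule, consumed = [], 0
    while True:
        gbs = current_global_batch_size(consumed, 4, 4, 100, 16)
        if gbs not in schedule:
            schedule.append(gbs)
        if gbs == 16:
            break
        consumed += gbs
    print(schedule)  # [4, 8, 12, 16]

Spreading rampup_samples evenly across the increments is also why the assertions added in megatron_gpt_model.py check that start_batch_size divides by the total number of GPUs and that batch_size_increment divides by micro_batch_size times the data-parallel size: every intermediate global batch size in the schedule must still split into whole microbatches per data-parallel rank.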