Akoumparouli/mcore microbatch calculator fix (#10780)
* move tests/lightning/{,_}io

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* add microbatch calculator context manager

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* use microbatch calculator context manager

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* add on_load_checkpoint test to ValidateModelRestoration; use ctx manager to reconfigure microbatch calculator; update save/restore path; add cleanup step at the end

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* remove unused var

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
akoumpa authored Oct 7, 2024
1 parent edb06ae commit 8a238b8
Showing 8 changed files with 246 additions and 163 deletions.
4 files renamed without changes (the tests/lightning/{,_}io move).
27 changes: 27 additions & 0 deletions tests/lightning/mcore_microbatch_utils.py
@@ -0,0 +1,27 @@
import contextlib


# @akoumparouli: use a context manager that saves/restores gbs/mbs when using
# reconfigure_num_microbatches_calculator to avoid interference between tests.
@contextlib.contextmanager
def reconfigure_num_microbatches_calculator_manager(*args, **kwargs):
    import megatron.core.num_microbatches_calculator as mb_calc

    # Store current mbs, gbs values
    if not mb_calc._GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
        _mbs = mb_calc.get_micro_batch_size()
        _gbs = mb_calc.get_current_global_batch_size()

        # use user's settings
        mb_calc.reconfigure_num_microbatches_calculator(*args, **kwargs)
    else:
        _mbs, _gbs = 1, 1

    try:
        # run user's code
        yield
    # @akoumparouli: no catch
    finally:
        # restore old mbs, gbs
        if not mb_calc._GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
            mb_calc.reconfigure_num_microbatches_calculator(0, None, _gbs, _mbs, data_parallel_size=1)
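
A minimal usage sketch of the helper above; the test name and batch sizes are hypothetical, and the positional arguments are the same ones the commit passes elsewhere (0, None, global batch size, micro batch size, data_parallel_size=1):

# Hypothetical usage sketch (not part of the commit): a test pins its own
# gbs/mbs for the duration of the with block so they do not leak into other
# tests running in the same process.
from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager


def test_example_with_local_batch_sizes():  # hypothetical test name
    gbs, mbs = 4, 2  # hypothetical values for this test only
    with reconfigure_num_microbatches_calculator_manager(0, None, gbs, mbs, data_parallel_size=1):
        ...  # build the model/trainer and call trainer.fit(...) here
    # on exit the previous gbs/mbs (or the 1/1 fallback) are restored
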
137 changes: 66 additions & 71 deletions tests/lightning/test_dist_ckpt.py
@@ -24,7 +24,6 @@ def set_env():
import pytest
import pytorch_lightning as pl
import torch
from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator

import nemo.lightning as nl
from nemo.collections import llm
@@ -43,13 +42,9 @@ def _get_last_checkpoint_dir(model: pl.LightningModule, suffix: str = '') -> Pat
return f'epoch={model.trainer.current_epoch - 1}-step={model.trainer.max_steps - 1}{suffix}'


def get_model_and_data():
micro_batch_size = 2
global_batch_size = 2
def get_model_and_data(mbs=2, gbs=2):
seq_length = 128
data = llm.MockDataModule(
seq_length=seq_length, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size
)
data = llm.MockDataModule(seq_length=seq_length, micro_batch_size=mbs, global_batch_size=gbs)

config = llm.GPTConfig(
num_layers=2,
@@ -59,13 +54,6 @@ def get_model_and_data():
seq_length=seq_length,
apply_query_key_layer_scaling=1,
)
reconfigure_num_microbatches_calculator(
0,
None,
global_batch_size,
micro_batch_size,
data_parallel_size=1,
)
return llm.GPTModel(config, tokenizer=data.tokenizer), data


@@ -76,21 +64,25 @@ def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path):

set_env()
assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '1'
model, data = get_model_and_data()
gbs, mbs = 2, 2
model, data = get_model_and_data(mbs, gbs)
from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager

strategy = _get_strategy()
with reconfigure_num_microbatches_calculator_manager(0, None, gbs, mbs, data_parallel_size=1):

trainer = nl.Trainer(
devices=1,
accelerator="gpu",
strategy=strategy,
enable_checkpointing=True,
max_steps=2,
default_root_dir=str(tmp_path),
logger=False,
)
strategy = _get_strategy()

trainer = nl.Trainer(
devices=1,
accelerator="gpu",
strategy=strategy,
enable_checkpointing=True,
max_steps=2,
default_root_dir=str(tmp_path),
logger=False,
)

trainer.fit(model, data)
trainer.fit(model, data)

assert isinstance(trainer.strategy.checkpoint_io, MegatronCheckpointIO)
# Ckpt path doesn't contain the .ckpt suffix
@@ -104,51 +96,54 @@ def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path):
def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path):
set_env()
assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '1'
model, data = get_model_and_data()

sync_ckpt_dir = tmp_path / 'sync_checkpoints'
async_ckpt_dir = tmp_path / 'async_checkpoints'

sync_checkpoint_io = MegatronCheckpointIO('torch_dist')
async_checkpoint_io = AsyncFinalizableCheckpointIO(MegatronCheckpointIO('torch_dist', async_save=True))

# dummy_trainer just to initialize NCCL
dummy_trainer = pl.Trainer(
devices=1,
logger=False,
max_steps=2,
strategy=_get_strategy(),
)
dummy_trainer.fit(model, data)
strategy = _get_strategy()
tmp_path = strategy.broadcast(tmp_path)

## reset the model and data and train with sync checkpointing
model, data = get_model_and_data()
sync_test_trainer = pl.Trainer(
devices=1,
enable_checkpointing=True,
logger=False,
max_steps=2,
strategy=_get_strategy(),
plugins=[sync_checkpoint_io],
default_root_dir=str(sync_ckpt_dir),
)
sync_test_trainer.fit(model, data)

## reset the model and data and train with sync checkpointing
model, data = get_model_and_data()
async_test_trainer = pl.Trainer(
devices=1,
enable_checkpointing=True,
logger=False,
max_steps=2,
strategy=_get_strategy(),
plugins=[async_checkpoint_io],
callbacks=AsyncFinalizerCallback(),
default_root_dir=str(async_ckpt_dir),
)
async_test_trainer.fit(model, data)
gbs, mbs = 2, 2
model, data = get_model_and_data(mbs, gbs)
from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager

with reconfigure_num_microbatches_calculator_manager(0, None, gbs, mbs, data_parallel_size=1):

sync_ckpt_dir = tmp_path / 'sync_checkpoints'
async_ckpt_dir = tmp_path / 'async_checkpoints'

sync_checkpoint_io = MegatronCheckpointIO('torch_dist')
async_checkpoint_io = AsyncFinalizableCheckpointIO(MegatronCheckpointIO('torch_dist', async_save=True))

# dummy_trainer just to initialize NCCL
dummy_trainer = pl.Trainer(
devices=1,
logger=False,
max_steps=2,
strategy=_get_strategy(),
)
dummy_trainer.fit(model, data)
strategy = _get_strategy()

## reset the model and data and train with sync checkpointing
model, data = get_model_and_data(mbs, gbs)
sync_test_trainer = pl.Trainer(
devices=1,
enable_checkpointing=True,
logger=False,
max_steps=2,
strategy=_get_strategy(),
plugins=[sync_checkpoint_io],
default_root_dir=str(sync_ckpt_dir),
)
sync_test_trainer.fit(model, data)

## reset the model and data and train with sync checkpointing
model, data = get_model_and_data(mbs, gbs)
async_test_trainer = pl.Trainer(
devices=1,
enable_checkpointing=True,
logger=False,
max_steps=2,
strategy=_get_strategy(),
plugins=[async_checkpoint_io],
callbacks=AsyncFinalizerCallback(),
default_root_dir=str(async_ckpt_dir),
)
async_test_trainer.fit(model, data)

checkpoint = {'sharded_state_dict': model.sharded_state_dict()}

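Because the restore happens in the context manager's finally block, a test that raises inside the with block still puts the calculator back, which is what keeps these checkpointing tests from leaking their gbs/mbs into whatever runs next in the same process. Purely as an illustration (not part of this commit; the fixture name is made up), the same guard could also be packaged as a pytest fixture:

# Hypothetical pytest fixture built on the same context manager; not part of
# this commit, shown only to illustrate the per-test isolation it provides.
import pytest

from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager


@pytest.fixture
def local_microbatch_calculator():
    gbs, mbs = 2, 2  # the values the tests in this file use
    with reconfigure_num_microbatches_calculator_manager(0, None, gbs, mbs, data_parallel_size=1):
        yield  # test body runs here; previous settings are restored afterwards
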
55 changes: 31 additions & 24 deletions tests/lightning/test_nemo_resume_from_ckpt.py
@@ -27,7 +27,6 @@ def set_env():

import pytest
import torch
from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator
from megatron.core.optimizer import OptimizerConfig

import nemo.lightning as nl
@@ -90,7 +89,7 @@ def compare_ckpts(a, b, path=[]):
raise ValueError("Unexpected value type " + str(type(a)))


def setup_data_model_optim(log_dir, n_steps, data_path, gbs=2, mbs=1):
def setup_data(log_dir, n_steps, data_path, gbs=2, mbs=1):
seq_length = 2048
tokenizer = get_nmt_tokenizer(
"megatron",
@@ -108,14 +107,11 @@ def setup_data_model_optim(log_dir, n_steps, data_path, gbs=2, mbs=1):
tokenizer=tokenizer,
split='9999,1,1',
)
# Other tests might have different configs, so need to configure explicitly.
reconfigure_num_microbatches_calculator(
0,
None,
gbs,
mbs,
data_parallel_size=1,
)
return data


def setup_model_optim(log_dir, n_steps, tokenizer, gbs=2, mbs=1):
seq_length = 2048
gpt_config = llm.GPTConfig(
num_layers=2,
hidden_size=128,
@@ -131,7 +127,7 @@ def setup_data_model_optim(log_dir, n_steps, data_path, gbs=2, mbs=1):
masked_softmax_fusion=False,
)

model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer)
model = llm.GPTModel(gpt_config, tokenizer=tokenizer)

opt_config = OptimizerConfig(
optimizer='adam',
@@ -148,7 +144,7 @@ def setup_data_model_optim(log_dir, n_steps, data_path, gbs=2, mbs=1):
)
optim = MegatronOptimizerModule(config=opt_config)

return gpt_config, data, model, optim
return gpt_config, model, optim


def setup_trainer_and_logger(log_dir):
@@ -248,18 +244,29 @@ def train(n_steps, resume):
log_dir = f'/tmp/mcore_logs_{n_steps}steps'
os.makedirs(log_dir, exist_ok=True)
data_path = [DATA_PATH]
gpt_config, data, model, optim = setup_data_model_optim(log_dir, n_steps, data_path)
trainer, nemo_logger = setup_trainer_and_logger(log_dir)
llm.train(
model=model,
data=data,
trainer=trainer,
log=nemo_logger,
resume=resume,
tokenizer='data',
optim=optim,
)
trainer._teardown()
data = setup_data(log_dir, n_steps, data_path, gbs=2, mbs=1)
# Other tests might have different configs, so need to configure explicitly.
from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager

with reconfigure_num_microbatches_calculator_manager(
0,
None,
2, # gbs
1, # mbs
data_parallel_size=1,
):
gpt_config, model, optim = setup_model_optim(log_dir, n_steps, data.tokenizer)
trainer, nemo_logger = setup_trainer_and_logger(log_dir)
llm.train(
model=model,
data=data,
trainer=trainer,
log=nemo_logger,
resume=resume,
tokenizer='data',
optim=optim,
)
trainer._teardown()

set_env()
assert os.environ['NVTE_FLASH_ATTN'] == '0'
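Finally, an illustrative check of the guard's behavior (not part of the commit, and assuming megatron-core is importable and the calculator has already been initialized by earlier code): values configured inside the block are visible there, and the previous values come back once the block exits.

# Illustrative check (not part of the commit): settings applied inside the
# with block are visible there, and the previous ones return afterwards.
import megatron.core.num_microbatches_calculator as mb_calc

from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager

# assumes earlier code already initialized the calculator with some gbs/mbs
before = (mb_calc.get_current_global_batch_size(), mb_calc.get_micro_batch_size())

with reconfigure_num_microbatches_calculator_manager(0, None, 2, 1, data_parallel_size=1):
    assert mb_calc.get_current_global_batch_size() == 2
    assert mb_calc.get_micro_batch_size() == 1

# outside the block the original settings are back
assert (mb_calc.get_current_global_batch_size(), mb_calc.get_micro_batch_size()) == before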
