add act checkpoint at sub layer level #720

Merged 15 commits on Nov 13, 2023
34 changes: 32 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -45,7 +45,9 @@
 from transformers.models.llama.modeling_llama import \
     LlamaRotaryEmbedding as HFRotaryEmbedding
 
-from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias
+from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
+                                                attn_bias_shape,
+                                                build_attn_bias)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
@@ -733,7 +735,35 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
 
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
+                                None) or ['MPTBlock']
+
+        if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
+            if len(act_ckpt_list) > 1:
+                log.info(
+                    'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).'
+                )
+            return isinstance(module, MPTBlock)
+
+        mod_types = ()
+        for mod_name in act_ckpt_list:
+            if mod_name.lower() == 'mptblock':
+                mod_types += (MPTBlock,)
+            elif mod_name in ATTN_CLASS_REGISTRY:
+                mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in FFN_CLASS_REGISTRY:
+                mod_types += (FFN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in NORM_CLASS_REGISTRY:
+                mod_types += (NORM_CLASS_REGISTRY[mod_name],)
+            else:
+                msg = ', '.join(
+                    list(ATTN_CLASS_REGISTRY.keys()) +
+                    list(FFN_CLASS_REGISTRY.keys()) +
+                    list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
+                raise ValueError(
+                    f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
+                )
+        return isinstance(module, mod_types)
 
     def prepare_inputs_for_generation(
         self,
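For orientation: the targets are read straight off the model config via getattr, so enabling sub-block checkpointing is purely a configuration change. The sketch below is illustrative only and not part of the diff; the keys mirror the new test further down, and everything else about the training run is omitted. It shows a model config that checkpoints just the grouped-query-attention modules, together with the FSDP settings that must be enabled for the target list to take effect.

# Illustrative sketch, not part of the diff: config keys mirror the test below.
from omegaconf import OmegaConf as om

model_cfg = om.create({
    'name': 'mpt_causal_lm',
    'd_model': 128,
    'n_heads': 4,
    'n_layers': 2,
    'expansion_ratio': 1,
    'max_seq_len': 16,
    'vocab_size': 50368,
    'attn_config': {
        'attn_type': 'grouped_query_attention',
        'kv_n_heads': 2,
    },
    # Resolved against the ATTN/FFN/NORM class registries (or 'mptblock').
    'activation_checkpointing_target': ['grouped_query_attention'],
})

fsdp_config = {
    'activation_checkpointing': True,  # the target list only matters when this is on
    'activation_checkpointing_reentrant': False,
    'activation_cpu_offload': False,
}

If activation_checkpointing is False in the FSDP config, the target list is ignored, which is exactly what the activation_checkpointing=False cases in the test assert.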
73 changes: 73 additions & 0 deletions tests/test_fsdp_act_checkpoint.py
@@ -0,0 +1,73 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
from composer import Trainer
from composer.utils import get_device
from omegaconf import OmegaConf as om
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
    CheckpointWrapper

from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM


@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('activation_checkpointing', [True, False])
@pytest.mark.parametrize(
    'activation_checkpointing_target',
    [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']])
def test_fsdp_act_checkpoint(activation_checkpointing: bool,
                             activation_checkpointing_target: list):
    device = get_device('gpu')
    model_cfg = {
        'name': 'mpt_causal_lm',
        'd_model': 128,
        'n_heads': 4,
        'n_layers': 2,
        'expansion_ratio': 1,
        'max_seq_len': 16,
        'vocab_size': 50368,
        'attn_config': {
            'attn_type': 'grouped_query_attention',
            'kv_n_heads': 2,
        },
        'activation_checkpointing_target': activation_checkpointing_target
    }
    model_cfg = om.create(model_cfg)

    fsdp_config = {
        'activation_checkpointing': activation_checkpointing,
        'activation_checkpointing_reentrant': False,
        'activation_cpu_offload': False,
    }

    model = ComposerMPTCausalLM(model_cfg)
    model = device.module_to_device(model)

    trainer = Trainer(
        model=model,
        device='gpu',
        fsdp_config=fsdp_config,
    )

    assert trainer.state.fsdp_enabled
    if not activation_checkpointing:
        assert not isinstance(
            trainer.state.model.model._fsdp_wrapped_module.transformer.
            blocks[0], CheckpointWrapper)
    elif (not activation_checkpointing_target
         ) or activation_checkpointing_target == [
             'mptblock', 'grouped_query_attention'
         ]:
        assert isinstance(
            trainer.state.model.model._fsdp_wrapped_module.transformer.
            blocks[0]._fsdp_wrapped_module, CheckpointWrapper)
    elif activation_checkpointing_target == ['grouped_query_attention']:
        assert isinstance(
            trainer.state.model.model._fsdp_wrapped_module.transformer.
            blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
    else:
        raise ValueError(
            f'Unknown activation_checkpointing_target: {activation_checkpointing_target}'
        )