add act checkpoint at sub layer level #720

Merged · 15 commits · Nov 13, 2023
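This PR lets the model config request activation checkpointing at the sub-block level: activation_checkpointing_target may name MPTBlock or any module registered in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, or NORM_CLASS_REGISTRY. A minimal sketch of the new knob, assuming the usual llmfoundry model_config shape and the registry key 'grouped_query_attention' (both illustrative, not shown in this diff):

# Hypothetical training-config fragment exercising the new option.
model_config = {
    'name': 'mpt_causal_lm',
    'd_model': 2048,
    'n_heads': 16,
    'n_layers': 24,
    # New in this PR: checkpoint only attention sublayers. Listing 'MPTBlock'
    # (or leaving this key unset) keeps the old whole-block behavior.
    'activation_checkpointing_target': ['grouped_query_attention'],
}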
33 changes: 31 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -45,7 +45,9 @@
 from transformers.models.llama.modeling_llama import \
     LlamaRotaryEmbedding as HFRotaryEmbedding

-from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias
+from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
+                                                attn_bias_shape,
+                                                build_attn_bias)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
@@ -705,7 +707,34 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:

     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        if not hasattr(self.config, 'activation_checkpointing_target'
+                      ) or self.config.activation_checkpointing_target is None:
+            log.info(
+                'activation checkpointing MPTBlock as activation_checkpointing_target is not set in model_config'
+            )
+            return isinstance(module, MPTBlock)
+        else:
+            act_ckpt_list = self.config.activation_checkpointing_target
+            if 'MPTBlock' in act_ckpt_list:
+                act_ckpt_list = ['MPTBlock']
+                warnings.warn(
+                    'activation checkpointing MPTBlock, ignoring other sub-block modules if specified'
+                )
+            mod_types = ()
+            for mod_name in act_ckpt_list:
+                if mod_name.lower() == 'mptblock':
+                    mod_types += (MPTBlock,)
+                elif mod_name in ATTN_CLASS_REGISTRY:
+                    mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
+                elif mod_name in FFN_CLASS_REGISTRY:
+                    mod_types += (FFN_CLASS_REGISTRY[mod_name],)
+                elif mod_name in NORM_CLASS_REGISTRY:
+                    mod_types += (NORM_CLASS_REGISTRY[mod_name],)
+                else:
+                    warnings.warn(
+                        f'module name specified in activation_checkpointing_target ({mod_name}) not recognized, available options are names in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or MPTBlock.'
+                    )
+            return isinstance(module, mod_types)

     def prepare_inputs_for_generation(
         self,
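For context, the hook above is consumed as a predicate over modules. A rough standalone equivalent using PyTorch's checkpoint-wrapper utility (the trainer integration is assumed, not shown in this diff; composer_model is a hypothetical ComposerMPTCausalLM instance whose inner HF model is composer_model.model):

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)

# activation_checkpointing_fn returns True exactly for the module types named
# in activation_checkpointing_target (or for MPTBlock when it is unset), so
# only those submodules get wrapped for recomputation.
apply_activation_checkpointing(
    composer_model.model,
    check_fn=composer_model.activation_checkpointing_fn,
)

Returning isinstance(module, mod_types) against a tuple of types is what lets a single check_fn cover several registries at once.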