add act checkpoint at sub layer level #720
Conversation
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Two high level comments:
(1) Please add some tests. Not sure how hard it is to test that act ckpt works the way you think it does, but at least test that everything is constructed the way you expect.
(2) Should we add this to all HF models too? I don't see why not. That would go here:
llm-foundry/llmfoundry/models/hf/hf_fsdp.py, line 200 at commit 84c86e3:
model.activation_checkpointing_fn = lambda module: isinstance(
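For context, the truncated line above assigns a predicate that an FSDP-style activation-checkpointing wrapper calls on each submodule; only modules for which it returns True get checkpointed. A minimal sketch of that pattern, where `TransformerBlock` and `LayerNorm` are hypothetical stand-ins (the real `isinstance` target is cut off in the snippet above):

```python
# Sketch of a per-class activation-checkpointing predicate: an
# FSDP-style wrapper calls check_fn(module) for every submodule and
# checkpoints only those for which it returns True.
# TransformerBlock / LayerNorm are hypothetical stand-in classes.

class TransformerBlock:
    """Stand-in for the block class a model wants checkpointed."""

class LayerNorm:
    """Stand-in for a submodule that should NOT be checkpointed."""

# Analogous to: model.activation_checkpointing_fn = lambda module: isinstance(...)
activation_checkpointing_fn = lambda module: isinstance(module, TransformerBlock)

print(activation_checkpointing_fn(TransformerBlock()))  # True
print(activation_checkpointing_fn(LayerNorm()))         # False
```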
(1) Are there some existing act ckpt tests to leverage? Or shall we just check that the module-name passing works as expected?
(1) Nope. Probably ok to just test that the module passing works as expected. Would be great if we actually tested activation checkpointing, but that does not exist right now, no.
@dakinggg re (2): this is a member method of the MPT model. Enabling this for HF models would require a custom str-to-class conversion for every HF model class before it is used in this call...
lgtm
I'd wait for @dakinggg approval before merging
I'd also like basic unit tests covering the use of this param before merging.
@dakinggg, added a unit test
This PR allows users to specify which modules FSDP applies activation checkpointing to when `activation_checkpointing` is `true` in `fsdp_config`. It checks the field `activation_checkpointing_target` in `model_config`; if nothing is specified but `activation_checkpointing` is set to `true`, then `MPTBlock` is activation checkpointed. To checkpoint the submodules in `ATTN_CLASS_REGISTRY`, `FFN_CLASS_REGISTRY`, or `NORM_CLASS_REGISTRY`, add the corresponding registry names to `activation_checkpointing_target`. For example, adding the registry name of `GroupedQueryAttention` checkpoints the activations of `GroupedQueryAttention`. If multiple submodules should be activation checkpointed, add all of their names to `activation_checkpointing_target`.
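The name-to-class lookup described above can be sketched as follows. This is an illustrative assumption of how such a mechanism might work, not the PR's actual code: the registry contents, class names, and the `resolve_ckpt_targets` helper are all hypothetical stand-ins.

```python
# Hedged sketch: resolve activation_checkpointing_target names (as they
# might appear in model_config) into classes, defaulting to MPTBlock
# when no target is specified, then build an FSDP-style check_fn.
# Registry contents and class names are illustrative stand-ins.

class MPTBlock: ...
class MultiheadAttention: ...
class GroupedQueryAttention: ...
class MPTMLP: ...

ATTN_CLASS_REGISTRY = {
    "multihead_attention": MultiheadAttention,
    "grouped_query_attention": GroupedQueryAttention,
}
FFN_CLASS_REGISTRY = {"mptmlp": MPTMLP}

def resolve_ckpt_targets(target_names):
    """Map config names to classes; default to MPTBlock when unset."""
    if not target_names:
        return (MPTBlock,)
    combined = {**ATTN_CLASS_REGISTRY, **FFN_CLASS_REGISTRY}
    try:
        return tuple(combined[name] for name in target_names)
    except KeyError as exc:
        raise ValueError(f"Unknown activation_checkpointing_target: {exc}")

targets = resolve_ckpt_targets(["grouped_query_attention"])
check_fn = lambda module: isinstance(module, targets)

print(check_fn(GroupedQueryAttention()))  # True
print(check_fn(MPTBlock()))               # False
```

Supporting multiple submodules then falls out naturally: passing several names yields a tuple of classes, and `isinstance` accepts a tuple, so any listed class is checkpointed.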