Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add act checkpoint at sub layer level #720

Merged
merged 15 commits into from
Nov 13, 2023
Merged

add act checkpoint at sub layer level #720

merged 15 commits into from
Nov 13, 2023

Conversation

cli99
Copy link
Contributor

@cli99 cli99 commented Nov 7, 2023

This PR allows users to specify what module to do activation checkpoint through fsdp when activation_checkpointing is true in fsdp_config. It checks the field activation_checkpointing_target in model_config, if nothing is specified but activation_checkpointing is set to true, then MPTBLOCK is act checkpointed. To checkpoint the submodules in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY,NORM_CLASS_REGISTRY, add the corresponding names to activation_checkpointing_target.
For example,

activation_checkpointing_target: 
    - grouped_query_attention

checkpoints the activation of GroupedQueryAttention.
If multiple submodules are up for act checkpointing, add their names to activation_checkpointing_target

llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two high level comments:
(1) Please add some tests. Not sure how hard it is to test that act ckpt works the way you think it does, but at least test that everything is constructed the way you expect.
(2) Should we add this to all HF models too? I don't see why not. That would go here:

model.activation_checkpointing_fn = lambda module: isinstance(

llmfoundry/models/mpt/modeling_mpt.py Show resolved Hide resolved
llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
@cli99 cli99 requested a review from vchiley November 8, 2023 18:15
@cli99
Copy link
Contributor Author

cli99 commented Nov 8, 2023

Two high level comments: (1) Please add some tests. Not sure how hard it is to test that act ckpt works the way you think it does, but at least test that everything is constructed the way you expect. (2) Should we add this to all HF models too? I don't see why not. That would go here:

model.activation_checkpointing_fn = lambda module: isinstance(

(1) are there some existing act ckpt tests to leverage? Or shall we just check to make sure the module names passing works as expected?
(2) sub-block act checkpointing requires we know the module definitions inside the block , it's model dependent.

@cli99 cli99 requested a review from dakinggg November 8, 2023 18:24
@dakinggg
Copy link
Collaborator

dakinggg commented Nov 8, 2023

(1) Nope. Probably ok to just test that the module passing works as expected. Would be great if we actually tested activation checkpointing, but that does not exist right now, no.
(2) Right, so maybe you have to provide the full import path of the module you want to checkpoint? Is that reasonable?

@vchiley vchiley self-requested a review November 8, 2023 21:13
@cli99 cli99 requested a review from vchiley November 8, 2023 22:57
@vchiley
Copy link
Contributor

vchiley commented Nov 8, 2023

@dakinggg re (2) this member method is a member method of the MPT model.

enabling this for HF models would requires a custom str to class conversion for a list for every hf model class before it is used in this call...
probably outside the scope of this PR 🤷‍♂️
(not sure how maybe you have to provide the full import path of the module you want to checkpoint would work)

Copy link
Contributor

@vchiley vchiley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

I'd wait for @dakinggg approval before merging

@dakinggg
Copy link
Collaborator

dakinggg commented Nov 8, 2023

@vchiley you would do something like import('string name of class') to get the class. but yeah, doesn't need to be part of this PR. @cli99 can you make a jira for adding custom act ckpt to generic HF models and put it in the LLM Foundry backlog jira epic please?

Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also would like basic unit tests testing the use of this param before merging

llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
llmfoundry/models/mpt/modeling_mpt.py Outdated Show resolved Hide resolved
@cli99
Copy link
Contributor Author

cli99 commented Nov 9, 2023

@dakinggg , added a unit test

tests/test_fsdp_act_checkpoint.py Outdated Show resolved Hide resolved
@cli99 cli99 enabled auto-merge (squash) November 13, 2023 21:16
@cli99 cli99 merged commit 8ba697c into main Nov 13, 2023
12 checks passed
@cli99 cli99 deleted the cli99/act-checkpoint branch November 13, 2023 22:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants