Fixes #1713
Summary
We should not be using `get_adapter_params` on models with activation checkpointing enabled. This is because `get_adapter_params` iterates over `named_modules`, not `state_dict`. This is a good thing from a memory perspective, but it also means that the state dict hooks that usually unwrap activation checkpointing wrapper modules (or previously FSDP wrapper modules) are never triggered. We could consider adding an option to `get_adapter_params` to handle this inside the utility, since it is a bit tricky, but that also feels hacky.

Actually, this is really only an issue in our single-device recipe when `save_adapter_weights_only=True`. Otherwise we filter the state dict on the keys of `self.adapter_params`, which is correct because (a) `self.adapter_params` is defined before any AC wrapping and (b) we are using the state dict, so the relevant hooks have fired.

But we don't need to use `get_adapter_params` here at all: we already have references to all the adapter weights and the correct (non-AC-wrapped) keys in `self.adapter_params`, as already defined in the recipe. So we can just use that, and we can do it regardless of the value of `save_adapter_weights_only` (which makes the code a bit cleaner).

Test plan
Why didn't we catch this before? All of our existing recipe tests set `enable_activation_checkpointing=False`, so we missed it. For that reason I've updated the `test_recipe_state_on_resume` tests to use `enable_activation_checkpointing=True`. Running those tests on main with the new values for `enable_activation_checkpointing`:

If we instead run with the changes on this PR:
Also, the command from #1713 that was failing before now passes:
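For reference, the key mismatch described in the Summary can be sketched in plain Python, without torch. This is a toy illustration, not the recipe code: the prefix string `_checkpoint_wrapped_module.`, the example key names, and the helper `filter_adapter_state` are all assumptions made for the sketch. It shows why keys collected by walking wrapped modules do not line up with state dict keys, and why filtering the state dict on keys recorded before AC wrapping avoids the problem.

```python
# Assumed wrapper prefix, for illustration only; the real wrapper's name may differ.
AC_PREFIX = "_checkpoint_wrapped_module."

# Toy keys as they might look when walking the modules of an AC-wrapped model:
# the wrapper's attribute name leaks into each key.
module_walk_keys = [
    "layers.0._checkpoint_wrapped_module.attn.q_proj.lora_a.weight",
    "layers.0._checkpoint_wrapped_module.attn.q_proj.lora_b.weight",
]

# Toy state dict, standing in for the result after the state dict hooks
# have stripped the wrapper prefix.
state_dict = {
    "layers.0.attn.q_proj.weight": 0.0,
    "layers.0.attn.q_proj.lora_a.weight": 1.0,
    "layers.0.attn.q_proj.lora_b.weight": 2.0,
}

# Adapter keys recorded before any AC wrapping (the role played by
# self.adapter_params in the description above).
adapter_keys = {
    "layers.0.attn.q_proj.lora_a.weight",
    "layers.0.attn.q_proj.lora_b.weight",
}


def filter_adapter_state(state_dict, adapter_keys):
    """Hypothetical helper: keep only entries whose key is a known adapter key."""
    return {k: v for k, v in state_dict.items() if k in adapter_keys}


# Keys from the module walk that cannot be found in the state dict.
mismatched = [k for k in module_walk_keys if k not in state_dict]

# Filtering on the pre-wrapping keys yields exactly the adapter weights.
adapter_state = filter_adapter_state(state_dict, adapter_keys)

print("mismatched keys:", mismatched)
print("adapter state keys:", sorted(adapter_state))
```

In this sketch every key from the module walk fails to match, while the filter on pre-wrapping keys returns only the two adapter entries.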