enable LoRA + FSDP2 #855
Sorry, not able to comment above, but shouldn't the docstring of this function be updated since we're no longer initializing on CPU?
The docstring used to be "Instantiating Model on CPU" (left) and I removed the mention of CPU. I did not mention meta device because it measures meta init + checkpoint loading now. Happy to improve if you are referring to this docstring.
Oh, just got your point. Updated the docstring for `_setup_model`.
`if isinstance(m, modules.TransformerDecoderLayer):` is the equivalent of `auto_wrap_policy` in FSDP1.
Sorry for the noob question, but can you help me understand what's going on here? Why do I need to `fully_shard` the `TransformerDecoderLayer` and then call `fully_shard` on the model?

An unrelated question: if I have enough GPU memory, should I be thinking about using something similar to `SHARD_GRAD_OP` with FSDP2?
In FSDP1, we wrap each `TransformerDecoderLayer` and then the root `model` as well. It's blackboxed in `auto_wrap_policy=utils.lora_fsdp_wrap_policy(modules_to_wrap={modules.TransformerDecoderLayer})`.

In FSDP2, we un-blackboxed it into this for-loop. If you prefer, this can be factored into a util function in torchtune so users can call `util.fully_shard(model, modules_to_wrap)`. Personally I'm biased towards the un-blackboxed approach, since people can modify the for-loop to achieve different wrapping.
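For concreteness, a minimal sketch of the un-blackboxed wrapping (the import path and the `shard_model` helper name are illustrative and may differ across torch/torchtune versions, so treat this as a sketch rather than the recipe's exact code):

```python
# Sketch: per-layer sharding with FSDP2, mirroring what auto_wrap_policy did in FSDP1.
# Assumes a distributed process group is already initialized.
from torch.distributed._composable.fsdp import fully_shard

from torchtune import modules


def shard_model(model):
    # Shard each transformer layer individually...
    for m in reversed(list(model.modules())):
        if isinstance(m, modules.TransformerDecoderLayer):
            fully_shard(m)
    # ...then shard the root module so the remaining parameters
    # (token embeddings, final norm, output projection) are grouped too.
    fully_shard(model)
    return model
```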
The equivalent of `SHARD_GRAD_OP` in FSDP2 is `reshard_after_forward=False`, i.e. `fully_shard(model, reshard_after_forward=False)`. Do you want it as a config in the .yaml?
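Applied to the per-layer loop above, that would look roughly like this (illustrative sketch, not the recipe code):

```python
# Sketch: FSDP2 analogue of FSDP1's ShardingStrategy.SHARD_GRAD_OP.
# Keeping parameters unsharded after forward trades memory for less communication.
for m in model.modules():
    if isinstance(m, modules.TransformerDecoderLayer):
        fully_shard(m, reshard_after_forward=False)
fully_shard(model, reshard_after_forward=False)
```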
Thanks for the explanation! I love the un-blackboxed approach here - just needs more documentation and explanation :) After reading the FSDP2 RFC, this became a lot clearer.
Pros and cons of meta init: the pro is a 4.5x speedup during model init and thus shorter TTFB. The con is that the user needs to call `initialize_parameters` on `LoRALinear` explicitly to move those params from `meta` to `gpu`.
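Roughly, the flow looks like this (simplified single-device sketch; `build_model` is a placeholder for the model builder, and the real recipe does this with sharded params):

```python
# Sketch: meta-device init, then explicit materialization of LoRA params.
import torch
from torchtune.modules.peft import LoRALinear

with torch.device("meta"):
    model = build_model()  # placeholder: params are allocated without real storage

# Base weights come from the checkpoint state dict (loaded elsewhere in the recipe).
# lora_a / lora_b are not in the original checkpoint, so they must be
# materialized and initialized explicitly:
for m in model.modules():
    if isinstance(m, LoRALinear):
        m.to_empty(device="cuda")   # allocate real, uninitialized storage
        m.initialize_parameters()   # fill lora_a / lora_b in place
```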
Is this because these params are not being loaded from the checkpoint? Or do I misunderstand?
If this is indeed the reason, how do we handle this code block when the LoRA params are being loaded from a checkpoint (e.g. when resuming training)?
You are right. When finetuning from an original HF checkpoint, `lora_weights_state_dict = None`. For resuming training, `lora_weights_state_dict is not None` and we avoid calling `m.initialize_parameters()` again.
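In code, the branch looks roughly like this (variable names follow the discussion above; the actual recipe handles sharded tensors, so this is only a sketch):

```python
# Sketch: initialize LoRA params only when they are not coming from a checkpoint.
if lora_weights_state_dict is None:
    # Fresh finetune from the original HF checkpoint: lora_a / lora_b exist only
    # on the meta device, so materialize and initialize them here.
    for m in model.modules():
        if isinstance(m, LoRALinear):
            m.to_empty(device="cuda")
            m.initialize_parameters()
else:
    # Resuming training: the saved adapter state dict provides lora_a / lora_b,
    # so loading it materializes them and initialize_parameters() is skipped.
    model.load_state_dict(lora_weights_state_dict, strict=False, assign=True)
```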
Got you, thanks so much for the explanation! I think something that would be super helpful would be to document here, in the form of comments, the relationship between `fully_shard`, `initialize_parameters`, and `reset_parameters`.

Also, I think there was a technical reason with FSDP1 to call the function `reset_parameters`. Is that still true? Or can we standardize this with `initialize_parameters` in the modules code? Happy to chat about this offline!
Good point! Will add a comment to explain `fully_shard`, meta init, and `reset/initialize_parameters`.

FSDP1 calls `reset_parameters` for the exact same reason FSDP2 calls `reset/initialize_parameters`: RoPE is not covered in checkpoints, and `lora_a` and `lora_b` are not covered in checkpoints when `resume_training=False`.

It's just that FSDP1 has a contract to call the overridden `nn.Module.reset_parameters` through `FSDP(model, param_init_fn=)`, while FSDP2 does not impose overriding `reset_parameters`. Now you can name it `reset_parameters` or `initialize_parameters`.
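To make the contrast concrete, a rough side-by-side (the FSDP1 kwarg follows its public API; the FSDP2 side and the two function names are illustrative of the pattern in this PR, not the exact code):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def fsdp1_style(model):
    # FSDP1: meta-device init goes through a contract, either the module's own
    # reset_parameters() or a user-supplied param_init_fn.
    return FSDP(
        model,
        param_init_fn=lambda m: m.to_empty(
            device=torch.device("cuda"), recurse=False
        ),
    )


def fsdp2_style(model):
    # FSDP2: no naming contract. After fully_shard(), walk the modules yourself
    # and call whichever init method each module defines (reset_parameters,
    # initialize_parameters, ...) for params still left on the meta device.
    for m in model.modules():
        if hasattr(m, "initialize_parameters") and any(
            p.is_meta for p in m.parameters(recurse=False)
        ):
            m.to_empty(device="cuda", recurse=False)
            m.initialize_parameters()
    return model
```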