Save adapter config and remapped adapter weights for loading into PEFT #933
Conversation
Generally looks good - just needs some clean up and comments to make this easy to understand
```python
self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
```
Not related to this PR, but maybe at some point we should consider replacing the `apply_lora_to_*` flags with just adding `mlp` and `output` to the `lora_modules`?
Yeah agreed, I think this is likely where we'll head eventually. One thing is that we will probably want to make LoRA in MLP more configurable (i.e. use `w1`, `w2`, `w3` (or hopefully more descriptive names) instead of `mlp`). Otherwise the relationship between e.g. `q_proj` (nn.Linear) and `mlp` (FeedForward) being in the same config is a bit confusing. Anyways this shouldn't be a huge effort to change.
I agree that a single list is more intuitive, since, AFAICT, this is just consolidated into a single list under the hood.

> or hopefully more descriptive names

Changing names later on can invalidate the saved checkpoints, so would require some versioning for backwards compatibility.
I guess versioning or some sort of a converter/mapping? It would be great to figure this change out soon, but this point about checkpoint invalidation is a good one and something we should have a general solution for. I suspect this will come up many times.
```python
        self._apply_lora_to_mlp,
        self._apply_lora_to_output,
    ),
    "peft_type": "LORA",
```
Nice!
Not sure about this, but if the base model used for training was loaded from HF in the HF format (i.e. a transformers `PretrainedModel`), it should have a `name_or_path` attribute. This could be stored and, if it exists, we could add it to the config here as `base_model_name_or_path`. This is not a required attribute for the `adapter_config.json` but would be nice to have for a few situations.
Ah good point. I was trying to avoid this initially since it may necessitate some changes to our `load_checkpoint` method, as right now we really only retrieve and remap model weights. If it's more of a nice-to-have, I may punt on it for this particular PR to keep things more isolated to `save_checkpoint`. Lmk if this makes sense. Also cc @kartikayk if you have any general thoughts on loading state/metadata through `load_checkpoint` and passing it through our recipe. I imagine this is something we may want to start supporting more for various integrations anyways.
Can you expand a bit more on why we would need `base_model_name_or_path`? Is this to make sure there are no bugs related to selecting the right base model for further training in HF land? If so, I wonder if this is something which is a "must have" rather than a "good to have"? Or let me know if I misunderstand.

If it's a must have, then is this something we can read from one of the json files, or do we need to pass this information along through the recipe?
We don't strictly need `base_model_name_or_path`, but not having it means that the burden is on the user to figure out which base model this adapter belongs to. Of course, this can be solved with good documentation, but having it automatically in the `adapter_config.json` would be quite convenient.

Other points to consider:

- When shared on HF Hub, this metadata can be used for other things (I'm not an expert on this though)
- If `base_model_name_or_path` is present, users can load the adapter + base model in a single line of code (e.g. `AutoModelForCausalLM.from_pretrained(<path-to-adapter>)`; see the sketch below).
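A minimal sketch of the two loading paths being contrasted here (the base model id is a placeholder and not something shipped by this PR):

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Without base_model_name_or_path in adapter_config.json, the user has to know
# which base model the adapter belongs to and load it explicitly first.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder id
model = PeftModel.from_pretrained(base, "<path-to-adapter>")

# With base_model_name_or_path recorded, transformers' PEFT integration can
# resolve the base model itself, so the adapter loads in a single line.
model = AutoModelForCausalLM.from_pretrained("<path-to-adapter>")
```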
```
@@ -198,3 +198,78 @@ def _permute(t, n_heads):
        converted_state_dict[new_key] = value

    return converted_state_dict


_TO_PEFT_KEYS = {
```
Maybe some quick comments on what these dicts refer to?
```python
for k, v in _TO_PEFT_KEYS.items():
    full_mapping.update(
        {
            vv.replace(".weight", f".{k}.weight"): kk.replace(
                ".weight", f".{v}.weight"
            )
            for kk, vv in _FROM_HF.items()
            if vv is not None
        }
    )
```
This block can use some comments explaining what's going on here
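For what it's worth, here is a self-contained toy version of what this block produces, which might double as the requested explanation. The `_FROM_HF` entries are abridged and the exact key templates are illustrative, not copied from the PR:

```python
# torchtune LoRA suffixes -> PEFT LoRA suffixes
_TO_PEFT_KEYS = {"lora_a": "lora_A", "lora_b": "lora_B"}

# _FROM_HF maps HF checkpoint keys -> torchtune keys; None marks HF keys with
# no torchtune counterpart. (Abridged, illustrative entries.)
_FROM_HF = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
}

full_mapping = {}
for k, v in _TO_PEFT_KEYS.items():
    # For each LoRA suffix pair, take every (HF key kk, torchtune key vv) pair
    # and produce an entry mapping the torchtune LoRA key to the PEFT-style key.
    full_mapping.update(
        {
            vv.replace(".weight", f".{k}.weight"): kk.replace(".weight", f".{v}.weight")
            for kk, vv in _FROM_HF.items()
            if vv is not None
        }
    )

print(full_mapping)
# {'layers.{}.attn.q_proj.lora_a.weight': 'model.layers.{}.self_attn.q_proj.lora_A.weight',
#  'layers.{}.attn.q_proj.lora_b.weight': 'model.layers.{}.self_attn.q_proj.lora_B.weight'}
```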
```python
head_dim = dim // num_heads


def _permute_lora_matrix(t, n_heads):
```
So these are permuted as well - nice find!
Only B matrices though 😃
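For anyone curious why only the B matrices need it: the HF conversion reorders rows (the output dimension) of the q/k projection weights, and a row permutation of `B @ A` equals the same row permutation applied to `B` alone. A quick self-contained check; the toy sizes and the exact view/transpose pattern are illustrative, not copied from the PR:

```python
import torch

dim, rank, n_heads = 8, 2, 2          # toy sizes
head_dim = dim // n_heads


def permute_rows(t, n_heads):
    # A pure per-head row reordering, in the spirit of the HF Llama RoPE
    # permutation; the column count (last dim) is left untouched.
    return (
        t.view(n_heads, head_dim // 2, 2, t.shape[-1])
        .transpose(1, 2)
        .reshape(n_heads * head_dim, t.shape[-1])
    )


B = torch.randn(dim, rank)            # LoRA B: rows live in the output dim
A = torch.randn(rank, dim)            # LoRA A: untouched by a row permutation

# Permuting the product == permuting B only, so A never needs remapping.
assert torch.allclose(permute_rows(B @ A, n_heads), permute_rows(B, n_heads) @ A)
```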
Great work, this should hopefully be useful to many users. I don't have any critical comments, just a couple of smaller ones.
It might also be a good idea to add a test for each supported architecture, just to be sure that the re-mappings of the keys are the same for all of them.
test_peft_integration.py (Outdated)

```python
if __name__ == "__main__":
    # test_permute()
```
Is this test still required? As is, it only prints something at the end, no asserts.
Yeah actually this whole test is probably going to be scrapped. As you point out in another comment, it is pretty expensive to run. Really I think I am gonna adopt a version of your other suggestion and will just add a unit test to confirm that key conversions etc are done correctly.
test_peft_integration.py (Outdated)

```python
with torch.no_grad():
    peft_out = peft_model(inputs)
    tt_out = tt_model(inputs)
print(f"Maximum difference: {torch.max(torch.abs(peft_out.logits - tt_out))}")
```
Should this be changed to an assert?
Yeah, will probably wind up removing this anyways. But in the forthcoming unit test I will use asserts
test_peft_integration.py (Outdated)

```python
# Initialize Llama2 and load merged checkpoint
# (just testing that forward lines up)
tt_model = llama2_7b()
```
You could first create the outputs of the PEFT model, then delete, and then load the tune model to save memory. But probably not necessary as you probably have much beefier CI runners than us :)
Yeah this is prob the better way. I was taking advantage of the extra memory for debugging: defining attributes on each model class for intermediate values, then comparing each step along the way. But this is gonna get scrapped anyways.
"lora_b": "lora_B", | ||
} | ||
|
||
_TO_PEFT_TARGET_MODULES = { |
I wonder if a single mapping can be maintained for all supported architectures. I haven't actually tried whether it works, but I just checked the key names for the supported models and Phi-3 seems to use `gate_up_proj` (https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main?show_file_info=model-00001-of-00002.safetensors). So I wonder if one mapping per architecture is required (with this being the default mapping).
Good point. I've actually only tested for Llama2 so far, I think you're right that we'll need a separate mapping at least for Phi-3. We do have something here for the full checkpoint mapping already, will just need to adapt it for PEFT purposes.
Update: there are other challenges with loading fine-tuned phi-3 checkpoints into PEFT from torchtune related to fused vs non-fused QKV. Namely, if someone fine-tunes in torchtune only on e.g. Q and K, they will not really be able to continue fine-tuning in PEFT in the way they would expect. In that case we can of course zero out the weights of the V chunk of the PEFT QKV LoRA matrix to get something that is in spirit correct, but (a) the user would probably expect only Q and K to remain trainable, which would not be the case, and (b) the learned LoRA weights from the torchtune finetune based on Q and K only may put any subsequent PEFT fine-tune using V as well in a suboptimal initial parameter space.
We could enforce up front that phi-3 LoRA is all-or-nothing on Q, K, and V for PEFT integration but I feel that's a bit messy. So for the time being I am opting to just raise a warning on checkpoint save that phi-3 adapter weights cannot be loaded into PEFT, and save just the usual torchtune adapter weights in that case.
I see, yes I think giving a warning is the best solution in this situation.
The only issue I have with the warning is that it is only given during checkpointing. I would be afraid that a user starts an expensive training run only to find out the next day that the checkpoint was not saved as expected. Would it be possible to give the warning already at model initialization time?
recipes/lora_finetune_distributed.py (Outdated)

```
@@ -477,6 +477,19 @@ def save_checkpoint(
        # to be sent to the checkpointer and ultimately written to file
        if self._is_rank_zero:

            # if training is in-progress, checkpoint the optimizer state and recipe state
```
Dumb q: Why move this up?
Mainly to align with the ordering in the single-device recipe, but realistically it doesn't matter too much. Maybe I'll just leave as-is to not mix in extra complexity with the current set of changes. Also I realized somehow my changes to save PEFT config did not get pushed to this recipe, will update that now.
Generally looks good - thanks for persevering through all of the unknowns! A bunch of future-facing questions which would be good to think about.
torchtune/models/convert_weights.py (Outdated)

```python
def tune_to_peft_adapter_config(
    adapter_config: Dict[str, Any],
):
    expected_keys = ["target_modules", "r", "lora_alpha"]
```
Does this need to be a constant like `_TO_PEFT_TARGET_MODULES`?
```python
    return adapter_config


def tune_to_peft_adapter_weights(
```
@BenjaminBossan I'm curious what your thoughts are on this function. It seems like this (along with other similar conversion functions) is fairly brittle and susceptible to breakages resulting from changes in PEFT/Transformers. A couple of questions:

- How brittle is this in practice? Do we expect changes in these keys or permutation logic often?
- Are the unit tests enough to capture this? Do we need to add similar tests on the PEFT side as well?
> How brittle is this in practice? Do we expect changes in these keys or permutation logic often?

No, there shouldn't be any frequent changes in this regard, as that would result in incompatibilities with old HF checkpoints as well. Generally, when something changes in the modeling code, we try to preserve the format of the checkpoint and re-map while loading the `state_dict`. I won't say it never happened in the past, but I think it would generally be considered a bug and we'd fix it if notified.

> Are the unit tests enough to capture this? Do we need to add similar tests on the PEFT side as well?

This probably wouldn't hurt. I could imagine that if you push a converted checkpoint to the HF Hub (ideally a small model), we can add a test to check that we can load it successfully.
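A rough sketch of what such a PEFT-side load test could look like (the Hub repo ids here are placeholders, not anything published by this PR):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder ids for illustration; a real test would point at a tiny base
# model plus a torchtune-converted adapter pushed to the HF Hub.
BASE_MODEL_ID = "some-org/tiny-llama-for-tests"
ADAPTER_ID = "some-org/tiny-llama-torchtune-lora"


def test_load_torchtune_adapter():
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
    # Loading should succeed and attach the LoRA weights converted from torchtune.
    model = PeftModel.from_pretrained(base, ADAPTER_ID)
    inputs = torch.randint(0, base.config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(inputs)
    assert torch.isfinite(out.logits).all()
```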
```
@@ -482,12 +484,57 @@ def save_checkpoint(
                f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
                f"saved to {output_path}"
            )
            # Phi-3-mini uses fused QKV in PEFT, this will not work as expected
            # if only a proper subset of Q, K, V have been fine-tuned
            if self._model_type == ModelType.PHI3_MINI:
```
Definitely getting very unwieldy. As we discussed, there is probably an opportunity to refactor the checkpointers with a focus on adapters.
PEFT integration
This is a PR for integration with PEFT. With this integration, you can take a fine-tuned checkpoint from torchtune and load it into a PEFT model using `from_pretrained` for continued fine-tuning or inference.

First, finetune a model in torchtune using the tune CLI.

Then in Python:
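A minimal sketch of the loading step, assuming a Llama-2-7B base model and a local checkpoint directory containing the saved `adapter_config.json` and `adapter_model.bin` (both paths are placeholders):

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder paths/ids; point these at your own base model and the output
# directory written by torchtune's HF checkpointer.
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
peft_model = PeftModel.from_pretrained(base_model, "/path/to/torchtune/output")
```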
And that's it! You can now use `peft_model` as you would any other PEFT model class.

We automatically output the necessary files/formats for PEFT integration whenever using torchtune's HF checkpointer, so make sure to use that in your fine-tuning config if you want to load your torchtune checkpoints into PEFT (example config usage).
Implementation
We save a file `adapter_config.json`, along with `adapter_model.bin`, to match the format expected by PEFT. We also remap the LoRA weights to match the HF format (due to differences in RoPE implementations).

The save logic differs depending on checkpointer and model.
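For reference, a rough sketch of the kind of adapter config that gets written, shown as the Python dict before JSON serialization. The key names follow the fields referenced in this PR's diff (`r`, `lora_alpha`, `target_modules`, `peft_type`); the values are illustrative:

```python
import json

# Illustrative values only; the real config is built from the recipe's LoRA
# settings (rank, alpha, and which modules LoRA was applied to).
adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "peft_type": "LORA",
}

with open("adapter_config.json", "w") as f:
    json.dump(adapter_config, f, indent=2)
```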
Testing:
Unit tests
Added unit test in `test_checkpointer.py` to verify that adapter config and PEFT-compatible weights are saved as expected.

Recipe tests
Manual E2E test
First create the file test_peft_integration.py as in this gist.
(1) ✅ Permute of LoRA weights works as expected (i.e. `_permute_lora_matrix(B) * A = _permute(B*A)`, which I think is what we want).

(2) ✅ Uploaded adapter weights can be loaded into a transformers model via `from_pretrained`.

(3) ✅ Model forwards match within a reasonable tolerance across PEFT-loaded and torchtune-loaded checkpoints.
For (3):
Test case 1: default config (Q and V only)

Run the fine-tuning recipe to save a fine-tuned LoRA checkpoint with adapter config and adapter weights in PEFT format. Then, to compare the forward pass when loading our fine-tuned checkpoint into PEFT vs. into torchtune:
Test case 2: all layers, custom LoRA rank and alpha
Then