Save adapter config and remapped adapter weights for loading into PEFT #933
First changed file (a LoRA finetune recipe):

```diff
@@ -22,6 +22,7 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft.peft_utils import (
     get_adapter_params,
+    get_lora_module_names,
     get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
```

```diff
@@ -258,6 +259,9 @@ def _setup_model(
         self._lora_rank = cfg_model.lora_rank
         self._lora_alpha = cfg_model.lora_alpha
+        self._lora_attn_modules = list(cfg_model.lora_attn_modules)
+        self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
+        self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
         self.adapter_params = get_adapter_params(model)
         set_trainable_params(model, self.adapter_params)
```

```diff
@@ -275,11 +279,10 @@ def _setup_model(
             )
         else:
             lora_missing, lora_unexpected = None, None
         validate_missing_and_unexpected_for_lora(
-            lora_attn_modules=cfg_model.lora_attn_modules,
-            apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
-            apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
+            lora_attn_modules=self._lora_attn_modules,
+            apply_lora_to_mlp=self._apply_lora_to_mlp,
+            apply_lora_to_output=self._apply_lora_to_output,
             base_missing=base_missing,
             base_unexpected=base_unexpected,
             lora_missing=lora_missing,
```

```diff
@@ -417,6 +420,17 @@ def save_checkpoint(self, epoch: int) -> None:
             k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
         }
         ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict})
+        adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
+        ckpt_dict.update({utils.ADAPTER_CONFIG: adapter_config})
         self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
```

Review thread on the new `adapter_config` block:

- Nice!
- Not sure about this, but if the base model used for training was loaded from HF in the HF format (i.e. a `transformers` …
- Ah, good point. I was trying to avoid this initially since it may necessitate some changes to our `load_checkpoint` method; right now we really only retrieve and remap model weights. If it's more of a nice-to-have, I may punt on it for this particular PR to keep things more isolated to `save_checkpoint`. Lmk if this makes sense. Also cc @kartikayk if you have any general thoughts on loading state/metadata through …
- Can you expand a bit more on why we would need … ? If it's a must-have, then is this something we can read from one of the JSON files, or do we need to pass this information along through the recipe?
- We don't strictly need … Other points to consider: …
Second changed file (torchtune's weight conversion utilities):

```diff
@@ -6,7 +6,7 @@
 import re

-from typing import Dict
+from typing import Any, Dict

 import torch
```
```diff
@@ -198,3 +198,85 @@ def _permute(t, n_heads):
         converted_state_dict[new_key] = value

     return converted_state_dict
+
+
+# Mapping from torchtune LoRA module names to PEFT LoRA module names
+_TO_PEFT_KEYS = {
+    "lora_a": "lora_A",
+    "lora_b": "lora_B",
+}
+
+# Mapping from torchtune module names to target modules for PEFT adapter config
+_TO_PEFT_TARGET_MODULES = {
+    "q_proj": "q_proj",
+    "k_proj": "k_proj",
+    "v_proj": "v_proj",
+    "output_proj": "o_proj",
+    "w1": "gate_proj",
+    "w2": "down_proj",
+    "w3": "up_proj",
+    "output": "lm_head",
+}
+
+# Keys expected in PEFT's adapter_config.json
+_PEFT_CONFIG_EXPECTED_KEYS = ["target_modules", "r", "lora_alpha"]
+
+
+def tune_to_peft_adapter_config(
+    adapter_config: Dict[str, Any],
+):
+    if not all([x in adapter_config.keys() for x in _PEFT_CONFIG_EXPECTED_KEYS]):
+        raise ValueError(
+            f"PEFT adapter config requires {_PEFT_CONFIG_EXPECTED_KEYS}, found {adapter_config.keys()}"
+        )
+
+    for k in adapter_config["target_modules"]:
+        if k not in _TO_PEFT_TARGET_MODULES:
+            raise ValueError(f"Unknown target module {k}")
+    adapter_config["target_modules"] = list(
+        map(_TO_PEFT_TARGET_MODULES.get, adapter_config["target_modules"])
+    )
+
+    return adapter_config
```

Review comment on `_TO_PEFT_KEYS` / `_TO_PEFT_TARGET_MODULES`:

- Maybe some quick comments on what these dicts refer to?

Review thread on `_TO_PEFT_TARGET_MODULES`:

- I wonder if a single mapping can be maintained for all supported architectures. I haven't actually tried whether it works, but I just checked the key names for the supported models, and Phi-3 seems to use …
- Good point. I've actually only tested for Llama2 so far; I think you're right that we'll need a separate mapping at least for Phi-3. We do have something here for the full checkpoint mapping already, it will just need to be adapted for PEFT purposes.
- Update: there are other challenges with loading fine-tuned Phi-3 checkpoints into PEFT from torchtune, related to fused vs. non-fused QKV. Namely, if someone fine-tunes in torchtune only on e.g. Q and K, they will not really be able to continue fine-tuning in PEFT in the way they would expect. In that case we can of course zero out the weights of the V chunk of the PEFT QKV LoRA matrix to get something that is in spirit correct, but (a) the user would probably expect only Q and K to remain trainable, which would not be the case, and (b) the learned LoRA weights from the torchtune fine-tune based on Q and K only may put any subsequent PEFT fine-tune that also uses V in a suboptimal initial parameter space. We could enforce up front that Phi-3 LoRA is all-or-nothing on Q, K, and V for PEFT integration, but I feel that's a bit messy. So for the time being I am opting to just raise a warning on checkpoint save that Phi-3 adapter weights cannot be loaded into PEFT, and to save just the usual torchtune adapter weights in that case.
- I see, yes, I think giving a warning is the best solution in this situation. The only issue I have with the warning is that it is only given during checkpointing. I would be afraid that a user starts an expensive training run only to find out the next day that the checkpoint was not saved as expected. Would it be possible to give the warning already at model initialization time?
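As a quick usage illustration of `tune_to_peft_adapter_config` (my own example, not part of the PR; the input dict mirrors the `adapter_config` built in the recipe hunk above with illustrative values):

```python
# torchtune-style adapter config, as assembled in save_checkpoint (illustrative values):
tune_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj", "output_proj"],
    "peft_type": "LORA",
}

peft_config = tune_to_peft_adapter_config(tune_config)
# target_modules is now in PEFT/HF naming:
#   ["q_proj", "v_proj", "o_proj"]
# The remaining keys (r, lora_alpha, peft_type) pass through unchanged.
```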
The diff continues with the weight-conversion function:

```diff
+def tune_to_peft_adapter_weights(
+    state_dict: Dict[str, torch.Tensor],
+    num_heads: int = 32,
+    num_kv_heads: int = 32,
+    dim: int = 4096,
+):
+    converted_state_dict = {}
+    full_mapping = {}
+    # Rather than recreate a separate mapping for LoRA adapter weights, we just
+    # re-use the _FROM_HF mapping for base model weights. We iterate over it twice:
+    # once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices.
+    for k, v in _TO_PEFT_KEYS.items():
+        full_mapping.update(
+            {
+                vv.replace(".weight", f".{k}.weight"): kk.replace(
+                    ".weight", f".{v}.weight"
+                )
+                for kk, vv in _FROM_HF.items()
+                if vv is not None
+            }
+        )
```

Review thread on `tune_to_peft_adapter_weights`:

- @BenjaminBossan I'm curious what your thoughts are on this function. It seems like this (along with other similar conversion functions) is fairly brittle and susceptible to breakages resulting from changes in PEFT/Transformers. A couple of questions: …
- In reply: (1) No, there shouldn't be any frequent changes in this regard, as that would result in incompatibilities of old HF checkpoints as well. Generally, when something changes in the modeling code, we try to preserve the format of the checkpoint and re-map while loading the … (2) This probably wouldn't hurt. I could imagine that if you push a converted checkpoint to the HF Hub (ideally a small model), we can add a test to check if we can load it successfully.

Review comment on the `full_mapping` loop (lines +254 to +263):

- This block can use some comments explaining what's going on here.
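To make the remapping concrete, here is a small illustration of the kind of key the loop above produces (my own example, not from the PR; the exact key strings are assumptions based on torchtune's `_FROM_HF` naming conventions):

```python
# Hypothetical walk-through of one adapter key (assumed key formats):
tune_key = "layers.0.attn.q_proj.lora_a.weight"                  # torchtune adapter key
hf_style_key = "model.layers.0.self_attn.q_proj.lora_A.weight"   # after get_mapped_key with full_mapping
peft_key = "base_model.model." + hf_style_key                    # prefix added before saving for PEFT
print(peft_key)
# base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
```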
The function concludes with the permutation of the LoRA B matrices for `q_proj` and `k_proj`:

```diff
+    head_dim = dim // num_heads
+
+    def _permute_lora_matrix(t, n_heads):
+        rank = t.shape[-1]
+        return (
+            t.view(n_heads, head_dim // 2, 2, rank)
+            .transpose(1, 2)
+            .reshape((head_dim * n_heads), rank)
+        )
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, full_mapping)
+        if "q_proj" in new_key and "lora_B" in new_key:
+            value = _permute_lora_matrix(value, num_heads)
+        elif "k_proj" in new_key and "lora_B" in new_key:
+            value = _permute_lora_matrix(value, num_kv_heads)
+        converted_state_dict["base_model.model." + new_key] = value
+    return converted_state_dict
```

Review thread on `_permute_lora_matrix`:

- So these are permuted as well - nice find!
- Only B matrices though 😃
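A quick way to see why only the B matrices need this permutation: the q/k permutation only reorders output rows, and the LoRA delta is `B @ A`, so permuting the rows of `B` is equivalent to permuting the rows of the full delta. A toy check (my own sketch, not code from the PR; the shapes and the row-reordering helper are assumptions mirroring `_permute_lora_matrix` above):

```python
import torch

# Toy sizes (hypothetical); dim = num_heads * head_dim, head_dim must be even.
num_heads, head_dim, rank = 4, 8, 2
dim = num_heads * head_dim

def permute_rows(w, n_heads):
    # Reorder only the output (row) dimension, mirroring _permute_lora_matrix.
    return (
        w.view(n_heads, head_dim // 2, 2, w.shape[-1])
        .transpose(1, 2)
        .reshape(n_heads * head_dim, w.shape[-1])
    )

lora_a = torch.randn(rank, dim)  # LoRA A: (rank, in_features)
lora_b = torch.randn(dim, rank)  # LoRA B: (out_features, rank)

# Permuting only B is equivalent to permuting the full LoRA delta B @ A.
assert torch.allclose(permute_rows(lora_b, num_heads) @ lora_a,
                      permute_rows(lora_b @ lora_a, num_heads))
```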
General review discussion (on the `apply_lora_to_*` config flags):

- Not related to this PR, but maybe at some point we should consider replacing the `apply_lora_to_*` flags with just adding `mlp` and `output` to the `lora_modules`?
- Yeah, agreed, I think this is likely where we'll head eventually. One thing is that we will probably want to make LoRA in MLP more configurable (i.e. use `w1`, `w2`, `w3` (or hopefully more descriptive names) instead of `mlp`). Otherwise the relationship between e.g. `q_proj` (nn.Linear) and `mlp` (FeedForward) being in the same config is a bit confusing. Anyways, this shouldn't be a huge effort to change.
- I agree that a single list is more intuitive, since, AFAICT, this is just consolidated into a single list under the hood. Changing names later on can invalidate the saved checkpoints, so it would require some versioning for backwards compatibility.
- I guess versioning or some sort of a convertor/mapping? It would be great to figure this change out soon, but this point about checkpoint invalidation is a good one and something we should have a general solution for. I suspect this will come up many times.
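To illustrate the "consolidated into a single list under the hood" point above, here is a rough sketch of that kind of consolidation, mirroring how `get_lora_module_names` is called in the recipe diff (this is my own approximation, not torchtune's actual implementation; the MLP and output module names are assumptions taken from the `_TO_PEFT_TARGET_MODULES` mapping):

```python
from typing import List

def lora_module_names(lora_attn_modules: List[str],
                      apply_lora_to_mlp: bool,
                      apply_lora_to_output: bool) -> List[str]:
    # Collapse the attention-module list plus the two boolean flags into one flat
    # list of torchtune module names (e.g. for the adapter config's target_modules).
    modules = list(lora_attn_modules)
    if apply_lora_to_mlp:
        modules += ["w1", "w2", "w3"]  # assumed torchtune MLP projection names
    if apply_lora_to_output:
        modules.append("output")       # assumed name of the final output projection
    return modules

# Example:
# lora_module_names(["q_proj", "v_proj"], True, False)
# -> ["q_proj", "v_proj", "w1", "w2", "w3"]
```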