Save adapter config and remapped adapter weights for loading into PEFT #933

Merged · 18 commits · May 21, 2024

Changes from all commits
19 changes: 19 additions & 0 deletions recipes/lora_finetune_distributed.py
@@ -29,6 +29,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_state_dict_for_lora,
@@ -278,6 +279,12 @@ def _setup_model(
the correct device.
"""

self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha
self._lora_attn_modules = list(cfg_model.lora_attn_modules)
self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)

if self._is_rank_zero:
log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
init_start = time.perf_counter()
@@ -510,6 +517,18 @@ def save_checkpoint(
}
)

adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
"target_modules": get_lora_module_names(
self._lora_attn_modules,
self._apply_lora_to_mlp,
self._apply_lora_to_output,
),
"peft_type": "LORA",
}
checkpoint_dict.update({utils.ADAPTER_CONFIG: adapter_config})

self._checkpointer.save_checkpoint(
checkpoint_dict,
epoch=epoch,
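For readers skimming the diff: under a hypothetical config (the rank, alpha, and module list below are made-up values, and the helper is a stand-in that only mirrors the behavior the new checkpointer test assumes for get_lora_module_names), the adapter_config handed to the checkpointer would look roughly like this. The names are still torchtune-style at this point; they get translated to PEFT/HF names (o_proj, gate_proj, etc.) by the convert_weights changes further down.

```python
# Sketch only: lora_module_names_sketch is a hypothetical stand-in, not torchtune's
# get_lora_module_names; it just mirrors what the new checkpointer test expects.
from typing import List


def lora_module_names_sketch(
    lora_attn_modules: List[str],
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
) -> List[str]:
    names = list(lora_attn_modules)
    if apply_lora_to_mlp:
        names += ["w1", "w2", "w3"]  # torchtune's FeedForward projections
    if apply_lora_to_output:
        names.append("output")  # final output projection
    return names


adapter_config = {
    "r": 8,  # made-up rank
    "lora_alpha": 16,  # made-up alpha
    "target_modules": lora_module_names_sketch(
        ["q_proj", "v_proj"], apply_lora_to_mlp=True, apply_lora_to_output=False
    ),
    "peft_type": "LORA",
}
print(adapter_config["target_modules"])
# ['q_proj', 'v_proj', 'w1', 'w2', 'w3']
```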
22 changes: 18 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -22,6 +22,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
@@ -258,6 +259,9 @@ def _setup_model(

self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha
self._lora_attn_modules = list(cfg_model.lora_attn_modules)
self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
Comment on lines +263 to +264
Contributor:

not related to this PR, but maybe at some point we should consider replacing the apply_lora_to_* flags with just adding mlp and output to the lora_modules?

Contributor Author:

Yeah, agreed, I think this is likely where we'll head eventually. One thing is that we will probably want to make LoRA in MLP more configurable (i.e. use w1, w2, w3 (or hopefully more descriptive names) instead of mlp). Otherwise the relationship between e.g. q_proj (nn.Linear) and mlp (FeedForward) being in the same config is a bit confusing. Anyway, this shouldn't be a huge effort to change.


I agree that a single list is more intuitive, since, AFAICT, this is just consolidated into a single list under the hood.

> or hopefully more descriptive names

Changing names later on can invalidate the saved checkpoints, so would require some versioning for backwards compatibility.

Contributor:

I guess versioning or some sort of a converter/mapping? It would be great to figure this change out soon, but this point about checkpoint invalidation is a good one and something we should have a general solution for. I suspect this will come up many times.

self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

@@ -275,11 +279,10 @@
)
else:
lora_missing, lora_unexpected = None, None

validate_missing_and_unexpected_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
base_missing=base_missing,
base_unexpected=base_unexpected,
lora_missing=lora_missing,
@@ -417,6 +420,17 @@ def save_checkpoint(self, epoch: int) -> None:
k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
}
ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict})
adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
"target_modules": get_lora_module_names(
self._lora_attn_modules,
self._apply_lora_to_mlp,
self._apply_lora_to_output,
),
"peft_type": "LORA",
Contributor:

Nice!


Not sure about this, but if the base model used for training was loaded from HF in the HF format (i.e. a transformers PretrainedModel), it should have a name_or_path attribute. This could be stored and, if it exists, added to the config here as base_model_name_or_path. This is not a required attribute for adapter_config.json, but it would be nice to have in a few situations.

Contributor Author:

Ah good point. I was trying to avoid this initially since it may necessitate some changes to our load_checkpoint method, as right now we really only retrieve and remap model weights. If it's more of a nice-to-have, I may punt on it for this particular PR to keep things more isolated to save_checkpoint. Lmk if this makes sense. Also cc @kartikayk if you have any general thoughts on loading state/metadata through load_checkpointer and passing through our recipe. I imagine this is something we may want to start supporting more for various integrations anyways.

Contributor:

Can you expand a bit more on why we would need base_model_name_or_path? Is this to make sure there are no bugs related to selecting the right base model for further training in HF land? If so, I wonder if this is something which is a "must have" rather than a "good to have"? Or let me know if I misunderstand.

If it's a must have, then is this something we can read from one of the json files or do we need to pass this information along through the recipe?


We don't strictly need base_model_name_or_path, but not having it means that the burden is on the user to figure out which base model this adapter belongs to. Of course, this can be solved with good documentation, but having it automatically in the adapter_config.json would be quite convenient.

Other points to consider:

  • When shared on HF Hub, this metadata can be used for other things (I'm not an expert on this though)
  • If base_model_name_or_path is present, users can load the adapter + base model in a single line of code (e.g. AutoModelForCausalLM.from_pretrained(<path-to-adapter>)).

}
ckpt_dict.update({utils.ADAPTER_CONFIG: adapter_config})
self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
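To make the discussion above concrete, here is a rough sketch of the PEFT-side loading flow the reviewer is describing. The model id and output path are placeholders, and it assumes the checkpointer has written adapter_model.bin and adapter_config.json into the output directory, as the new test below exercises.

```python
# Hypothetical usage sketch, not part of this PR. Swap in whatever base model the
# torchtune run actually fine-tuned and wherever the checkpointer wrote its output.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder base model id
    torch_dtype=torch.bfloat16,
)
# Because adapter_config.json does not (yet) carry base_model_name_or_path, the user
# has to pair the adapter with the correct base model themselves -- the reviewer's point.
peft_model = PeftModel.from_pretrained(base, "/path/to/torchtune/output_dir")
```

With base_model_name_or_path present, the single-call AutoModelForCausalLM.from_pretrained(<path-to-adapter>) route mentioned above would also become available.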
144 changes: 144 additions & 0 deletions tests/torchtune/utils/test_checkpointer.py
@@ -14,8 +14,14 @@
from torch import randn

from torchtune.models import llama2, mistral
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_lora_module_names,
validate_missing_and_unexpected_for_lora,
)
from torchtune.utils._checkpointing import FullModelHFCheckpointer
from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load
from torchtune.utils.constants import ADAPTER_CONFIG, ADAPTER_KEY
from torchtune.utils.seed import set_seed

_VOCAB_SIZE = 100
@@ -293,6 +299,144 @@ def test_save_load_checkpoint_multiple_file(
assert len(output_state_dict_1.keys()) + 1 == len(orig_state_dict_1.keys())
assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())

def test_save_checkpoint_in_peft_format(
self,
single_file_checkpointer: FullModelHFCheckpointer,
llama2_hf_checkpoints: Tuple[Path, Path],
):
"""
Test save_checkpoint method within the FullModelCheckpointer for
integration with HF PEFT (i.e. save_in_peft_format=True).

We test that:
* The file adapter_config.json contains the fields required by PEFT
and the correct values
* The state dict keys of the saved adapter checkpoint are remapped as expected
* The state dict values of the saved adapter checkpoint (after key remapping)
match those in torchtune for parameters that are not permuted by HF
* The state dict values of the saved adapter checkpoint (after key remapping)
do not match those in torchtune for parameters that are permuted by HF, but the
sums along the dimension of permutation match
"""

# Define LoRA params for this test
lora_attn_modules = ["q_proj", "output_proj"]
apply_lora_to_mlp = True
apply_lora_to_output = True
lora_rank = 4
lora_alpha = 8

checkpoint_file, _ = llama2_hf_checkpoints
state_dict = single_file_checkpointer.load_checkpoint()

# Build LoRA Llama2 model and load in base model weights
model = llama2.lora_llama2(
lora_attn_modules=lora_attn_modules,
apply_lora_to_mlp=apply_lora_to_mlp,
apply_lora_to_output=apply_lora_to_output,
vocab_size=_VOCAB_SIZE,
num_layers=1,
num_heads=_NUM_HEADS,
num_kv_heads=_NUM_KV_HEADS,
embed_dim=_DIM,
max_seq_len=128,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
)
missing, unexpected = model.load_state_dict(state_dict["model"], strict=False)
validate_missing_and_unexpected_for_lora(
lora_attn_modules=lora_attn_modules,
apply_lora_to_mlp=apply_lora_to_mlp,
apply_lora_to_output=apply_lora_to_output,
base_missing=missing,
base_unexpected=unexpected,
)

# LoRA B params are zero-initialized, randomly initialize them to make
# the test of their permutation on checkpoint save nontrivial
lora_b_sd = {
k: torch.randn_like(v)
for k, v in model.state_dict().items()
if "lora_b" in k
}
model.load_state_dict(lora_b_sd, strict=False)

# Construct the adapter weights and config and save using checkpointer
adapter_params = get_adapter_params(model)
adapter_key_filter = lambda x: x in adapter_params
expected_adapter_state_dict = {
k: v for k, v in model.state_dict().items() if adapter_key_filter(k)
}
adapter_config = {
"r": lora_rank,
"lora_alpha": lora_alpha,
"target_modules": get_lora_module_names(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
),
"peft_type": "LORA",
}
state_dict.update({ADAPTER_KEY: expected_adapter_state_dict})
state_dict.update({ADAPTER_CONFIG: adapter_config})
single_file_checkpointer.save_checkpoint(state_dict, epoch=1)

# Load saved adapter weights and config from file for comparison
adapter_weights_file = Path.joinpath(
checkpoint_file.parent, "adapter_model.bin"
)
actual_adapter_state_dict = safe_torch_load(adapter_weights_file)

adapter_config_file = Path.joinpath(
checkpoint_file.parent, "adapter_config.json"
)
with open(adapter_config_file, "r") as f:
adapter_config = json.load(f)

expected_target_modules = [
"down_proj",
"gate_proj",
"lm_head",
"o_proj",
"q_proj",
"up_proj",
]
assert sorted(adapter_config["target_modules"]) == expected_target_modules

# Map PEFT keys back to torchtune keys
peft_to_tt = {
"o_proj": "output_proj",
"gate_proj": "w1",
"down_proj": "w2",
"up_proj": "w3",
"lm_head": "output",
}
for k, v in actual_adapter_state_dict.items():
new_k = k.replace("base_model.model.", "").replace("self_attn", "attn")
if "lm_head" not in new_k:
new_k = new_k.replace("model.", "")
for kk, vv in peft_to_tt.items():
if kk in k:
new_k = new_k.replace(kk, vv)
new_k = new_k.replace("lora_A", "lora_a").replace("lora_B", "lora_b")

# LoRA B matrix for Q should not match due to Q and K permutation
# However, since they're permuted along embed dim, their sum along that axis should match
if "lora_b" in new_k and "q_proj" in new_k:
assert not torch.allclose(
actual_adapter_state_dict[k], expected_adapter_state_dict[new_k]
)
torch.testing.assert_close(
actual_adapter_state_dict[k].sum(dim=0),
expected_adapter_state_dict[new_k].sum(dim=0),
)

# All other matrices should match exactly
if "lora_b" not in new_k:
torch.testing.assert_close(
actual_adapter_state_dict[k], expected_adapter_state_dict[new_k]
)


class TestHFMistralRewardModelFullModelCheckpointer:
@pytest.fixture
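For orientation, the key remapping that test_save_checkpoint_in_peft_format walks in reverse boils down to pairs like the one below. The layer index and module are arbitrary, and the exact PEFT-side prefix is my reading of the convert_weights changes in the next file, so treat it as illustrative rather than verified output.

```python
# Illustrative key pair only; derived by eye from the remapping logic, not from a real run.
tune_key = "layers.0.attn.q_proj.lora_a.weight"
peft_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"

# The test recovers tune_key from peft_key by stripping the "base_model.model." prefix,
# dropping the remaining "model." segment (except for lm_head), renaming self_attn -> attn,
# mapping PEFT module names back to torchtune ones, and lowercasing lora_A / lora_B.
```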
84 changes: 83 additions & 1 deletion torchtune/models/convert_weights.py
@@ -6,7 +6,7 @@

import re

from typing import Dict
from typing import Any, Dict

import torch

@@ -198,3 +198,85 @@ def _permute(t, n_heads):
converted_state_dict[new_key] = value

return converted_state_dict


# Mapping from torchtune LoRA module names to PEFT LoRA module names
_TO_PEFT_KEYS = {
Contributor:

Maybe some quick comments on what these dicts refer to?

"lora_a": "lora_A",
"lora_b": "lora_B",
}

# Mapping from torchtune module names to target modules for PEFT adapter config
_TO_PEFT_TARGET_MODULES = {


I wonder if a single mapping can be maintained for all supported architectures. I haven't actually tried if it works, but just checked the key names for the supported models and Phi3 seems to use gate_up_proj (https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main?show_file_info=model-00001-of-00002.safetensors). So I wonder if one mapping per architecture is required (with this being the default mapping).

Contributor Author:

Good point. I've actually only tested for Llama2 so far, I think you're right that we'll need a separate mapping at least for Phi-3. We do have something here for the full checkpoint mapping already, will just need to adapt it for PEFT purposes.

Contributor Author:

Update: there are other challenges with loading fine-tuned phi-3 checkpoints into PEFT from torchtune related to fused vs non-fused QKV. Namely, if someone fine-tunes in torchtune only on e.g. Q and K, they will not really be able to continue fine-tuning in PEFT in the way they would expect. In that case we can of course zero out the weights of the V chunk of the PEFT QKV LoRA matrix to get something that is in spirit correct, but (a) the user would probably expect only Q and K to remain trainable, which would not be the case, and (b) the learned LoRA weights from the torchtune finetune based on Q and K only may put any subsequent PEFT fine-tune using V as well in a suboptimal initial parameter space.

We could enforce up front that phi-3 LoRA is all-or-nothing on Q, K, and V for PEFT integration but I feel that's a bit messy. So for the time being I am opting to just raise a warning on checkpoint save that phi-3 adapter weights cannot be loaded into PEFT, and save just the usual torchtune adapter weights in that case.


I see, yes I think giving a warning is the best solution in this situation.

The only issue I have with the warning is that it is only given during checkpointing. I would be afraid that a user starts an expensive training run only to find out the next day that the checkpoint was not saved as expected. Would it be possible to give the warning already at model initialization time?

"q_proj": "q_proj",
"k_proj": "k_proj",
"v_proj": "v_proj",
"output_proj": "o_proj",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"output": "lm_head",
}

# Keys expected in PEFT's adapter_config.json
_PEFT_CONFIG_EXPECTED_KEYS = ["target_modules", "r", "lora_alpha"]


def tune_to_peft_adapter_config(
adapter_config: Dict[str, Any],
):
if not all([x in adapter_config.keys() for x in _PEFT_CONFIG_EXPECTED_KEYS]):
raise ValueError(
f"PEFT adapter config requires {_PEFT_CONFIG_EXPECTED_KEYS}, found {adapter_config.keys()}"
)

for k in adapter_config["target_modules"]:
if k not in _TO_PEFT_TARGET_MODULES:
raise ValueError(f"Unknown target module {k}")
adapter_config["target_modules"] = list(
map(_TO_PEFT_TARGET_MODULES.get, adapter_config["target_modules"])
)

return adapter_config


def tune_to_peft_adapter_weights(
Contributor:

@BenjaminBossan I'm curious what your thoughts are on this function. It seems like this (along with other similar conversion functions) is fairly brittle and susceptible to breakage resulting from changes in PEFT/Transformers. A couple of questions:

  • How brittle is this in practice? Do we expect changes in these keys or permutation logic often?
  • Are the unit tests enough to capture this? Do we need to add similar tests on the PEFT side as well?


  • How brittle is this in practice? Do we expect changes in these keys or permutation logic often?

No, there shouldn't be any frequent changes in this regard, as that would result in incompatibilities of old HF checkpoints as well. Generally, when something changes in the modeling code, we try to preserve the format of the checkpoint and re-map while loading the state_dict. I won't say it never happened in the past but I think it would generally be considered a bug and we'd fix it if notified.

  • Are the unit tests enough to capture this? Do we need to add similar tests on the PEFT side as well?

This probably wouldn't hurt. I could imagine that if you push a converted checkpoint to the HF Hub (ideally a small model), we can add a test to check if we can load it successfully.

state_dict: Dict[str, torch.Tensor],
num_heads: int = 32,
num_kv_heads: int = 32,
dim: int = 4096,
):
converted_state_dict = {}
full_mapping = {}
# Rather than recreate a separate mapping for LoRA adapter weights, we just
# re-use the _FROM_HF mapping for base model weights. We iterate over it twice:
# once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices.
for k, v in _TO_PEFT_KEYS.items():
full_mapping.update(
{
vv.replace(".weight", f".{k}.weight"): kk.replace(
".weight", f".{v}.weight"
)
for kk, vv in _FROM_HF.items()
if vv is not None
}
)
Comment on lines +254 to +263
Contributor:

This block can use some comments explaining what's going on here


head_dim = dim // num_heads

def _permute_lora_matrix(t, n_heads):
Contributor:

So these are permuted as well - nice find!

Contributor Author:

Only B matrices though 😃

rank = t.shape[-1]
return (
t.view(n_heads, head_dim // 2, 2, rank)
.transpose(1, 2)
.reshape((head_dim * n_heads), rank)
)

for key, value in state_dict.items():
new_key = get_mapped_key(key, full_mapping)
if "q_proj" in new_key and "lora_B" in new_key:
value = _permute_lora_matrix(value, num_heads)
elif "k_proj" in new_key and "lora_B" in new_key:
value = _permute_lora_matrix(value, num_kv_heads)
converted_state_dict["base_model.model." + new_key] = value
return converted_state_dict
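A quick check of why only the LoRA B matrices need permuting (the A matrices map into rank space, which the Q/K permutation never touches): the reshuffle in _permute_lora_matrix is a pure row permutation, so applying it to B alone reproduces the permuted low-rank update. The snippet below is a standalone sanity check with arbitrary small shapes, not code from this PR; permute_rows mirrors the reshuffle above, generalized to any trailing width.

```python
# Standalone sanity check (arbitrary shapes), not part of the PR.
import torch

n_heads, head_dim, rank, in_dim = 4, 8, 2, 16
dim = n_heads * head_dim


def permute_rows(t: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    # Same reshuffle as _permute_lora_matrix, applied to a matrix of any width.
    width = t.shape[-1]
    return (
        t.view(n_heads, head_dim // 2, 2, width)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, width)
    )


lora_a = torch.randn(rank, in_dim)
lora_b = torch.randn(dim, rank)

# Permuting B and then forming the update matches permuting the merged update B @ A,
# because a row permutation commutes with right-multiplication by A.
torch.testing.assert_close(
    permute_rows(lora_b, n_heads, head_dim) @ lora_a,
    permute_rows(lora_b @ lora_a, n_heads, head_dim),
)
print("row-permutation check passed")
```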
1 change: 1 addition & 0 deletions torchtune/utils/__init__.py
@@ -29,6 +29,7 @@
from .argparse import TuneRecipeArgumentParser
from .collate import padded_collate, padded_collate_dpo
from .constants import ( # noqa
ADAPTER_CONFIG,
ADAPTER_KEY,
EPOCHS_KEY,
MAX_STEPS_KEY,