
Commit 57977e0

add lora for mlp and unsloth
1 parent 6e0c9f6 commit 57977e0

File tree: 8 files changed (+168, -12 lines)

.ci/scripts/test_llama_lora.sh

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ cmake_build_llama_runner
 # Constants.
 RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
 PROMPT="What happens if you eat watermelon seeds?"
-EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
+EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C and"

 # Export LoRA PTE file.
 MODEL_NAME="llama_3_2_1B_lora"

examples/models/llama/attention.py

Lines changed: 5 additions & 2 deletions
@@ -409,14 +409,17 @@ def __init__(
         )
         self.wo = (
             LoRALinear(
-                in_dim=args.n_kv_heads * args.head_dim,
+                in_dim=args.n_heads * args.head_dim,
                 out_dim=args.dim,
                 rank=args.r,
                 alpha=args.lora_alpha,
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None and "output_proj" in args.target_modules
+            if args.target_modules is not None
+            and (
+                "output_proj" in args.target_modules or "o_proj" in args.target_modules
+            )
             else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )

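For context, a standalone sketch (not part of the commit) of the predicate this change introduces: wo becomes a LoRALinear when target_modules names the attention output projection under either the torchtune spelling ("output_proj") or the unsloth/PEFT spelling ("o_proj").

# Hypothetical helper mirroring the updated condition; not in the diff.
def lora_on_output_proj(target_modules) -> bool:
    return target_modules is not None and (
        "output_proj" in target_modules or "o_proj" in target_modules
    )

assert lora_on_output_proj(["q_proj", "v_proj", "o_proj"])       # unsloth/PEFT naming
assert lora_on_output_proj(["q_proj", "v_proj", "output_proj"])  # torchtune naming
assert not lora_on_output_proj(None)                             # no adapter targets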

examples/models/llama/convert_weights.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+from typing import Dict
+
+import torch
+
+from safetensors.torch import load_file
+from torchtune.models.convert_weights import get_mapped_key
+
+_UNSLOTH_TO_META = {
+    "base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight",
+    "base_model.model.model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight",
+    "base_model.model.model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight",
+}
+
+
+def unsloth_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from unsloth format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in unsloth format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+
+    for key, value in state_dict.items():
+        try:
+            new_key = get_mapped_key(key, _UNSLOTH_TO_META)
+        except Exception as e:
+            raise ValueError(f"Key {key} not found in mapping") from e
+
+        converted_state_dict[new_key] = value
+    return converted_state_dict
+
+
+def load_and_convert_unsloth_to_meta(checkpoint_path: str) -> Dict[str, torch.Tensor]:
+    """
+    Load a checkpoint file and convert it to Meta's format.
+
+    Args:
+        checkpoint_path (str): Path to the checkpoint file.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    state_dict = load_file(checkpoint_path)
+    return unsloth_to_meta(state_dict)
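A usage sketch for the new converter, assuming an unsloth LoRA adapter exported as a safetensors file; the path below is a placeholder, not something the commit prescribes.

# Hypothetical usage: load an unsloth adapter checkpoint and remap its keys to Meta's naming.
from executorch.examples.models.llama.convert_weights import (
    load_and_convert_unsloth_to_meta,
)

adapter = load_and_convert_unsloth_to_meta("adapter_model.safetensors")  # placeholder path
for name, tensor in adapter.items():
    # Keys now follow Meta's scheme, e.g. "layers.0.attention.wq.lora_a.weight".
    print(name, tuple(tensor.shape))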

examples/models/llama/feed_forward.py

Lines changed: 55 additions & 0 deletions
@@ -1,4 +1,7 @@
 import torch.nn.functional as F
+
+from executorch.examples.models.llama.lora import LoRALinear
+from executorch.examples.models.llama.model_args import ModelArgs
 from torch import nn


@@ -11,3 +14,55 @@ def __init__(self, dim: int, hidden_dim: int):

     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class LoRAFeedForward(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, args: ModelArgs):
+        super().__init__()
+
+        if args.r is None or args.lora_alpha is None:
+            raise ValueError(
+                "LoRA rank and alpha must be specified for LoRAFeedForward."
+            )
+
+        self.w1 = (
+            LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "gate_proj" in args.target_modules
+            else nn.Linear(dim, hidden_dim, bias=False)
+        )
+
+        self.w2 = (
+            LoRALinear(
+                in_dim=hidden_dim,
+                out_dim=dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "down_proj" in args.target_modules
+            else nn.Linear(hidden_dim, dim, bias=False)
+        )
+
+        self.w3 = (
+            LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "up_proj" in args.target_modules
+            else nn.Linear(dim, hidden_dim, bias=False)
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
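A construction sketch for LoRAFeedForward; the dimensions and LoRA hyperparameters are illustrative, and it assumes the remaining ModelArgs fields keep their defaults.

import torch

from executorch.examples.models.llama.feed_forward import LoRAFeedForward
from executorch.examples.models.llama.model_args import ModelArgs

# Illustrative values; r, lora_alpha, and target_modules mirror a PEFT adapter_config.json.
args = ModelArgs(r=8, lora_alpha=16, target_modules=["gate_proj", "up_proj", "down_proj"])

ff = LoRAFeedForward(dim=2048, hidden_dim=8192, args=args)
out = ff(torch.randn(1, 4, 2048))  # output keeps the input shape: (1, 4, 2048)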

examples/models/llama/install_requirements.sh

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,8 @@
 # Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
 # Install lm-eval for Model Evaluation with lm-evalution-harness.
-pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+# Install safetensors to load safetensors checkpoints (currently adapter only).
+pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile safetensors

 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py

examples/models/llama/llama_transformer.py

Lines changed: 7 additions & 1 deletion
@@ -18,7 +18,7 @@
     AttentionSkip,
     ForwardOptions,
 )
-from executorch.examples.models.llama.feed_forward import FeedForward
+from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -93,6 +93,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
         ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
+        elif args.target_modules is not None and (
+            "down_proj" in args.target_modules
+            or "up_proj" in args.target_modules
+            or "gate_proj" in args.target_modules
+        ):
+            self.feed_forward = LoRAFeedForward(args.dim, args.hidden_dim, args)
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)


examples/models/llama/model.py

Lines changed: 35 additions & 6 deletions
@@ -15,7 +15,9 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
+
 from executorch.examples.models.llama.llama_transformer import construct_transformer
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope

@@ -140,14 +142,41 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         adapter_checkpoint = {}
         adapter_config = {}
         if adapter_checkpoint_path:
-            adapter_checkpoint = torch.load(
-                adapter_checkpoint_path, map_location=device, mmap=True
-            )
-            from torchtune.models import convert_weights
+            if adapter_checkpoint_path.endswith(".pt"):
+                adapter_checkpoint = torch.load(
+                    adapter_checkpoint_path, map_location=device, mmap=True
+                )
+                from torchtune.models import convert_weights
+
+                adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            elif adapter_checkpoint_path.endswith(".safetensors"):
+                from executorch.examples.models.llama.convert_weights import (
+                    load_and_convert_unsloth_to_meta,
+                )
+
+                adapter_checkpoint = load_and_convert_unsloth_to_meta(
+                    adapter_checkpoint_path
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
+                )

-            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
             with open(adapter_config_path, "r") as f:
-                adapter_config = json.loads(f.read())
+                adapter_config_full = json.loads(f.read())
+                if (
+                    "r" not in adapter_config_full
+                    or "lora_alpha" not in adapter_config_full
+                    or "target_modules" not in adapter_config_full
+                ):
+                    raise ValueError(
+                        "Adapter config must contain r, lora_alpha, and target_modules."
+                    )
+                adapter_config = {
+                    "r": adapter_config_full["r"],
+                    "lora_alpha": adapter_config_full["lora_alpha"],
+                    "target_modules": adapter_config_full["target_modules"],
+                }
             checkpoint.update(adapter_checkpoint)

         output_prune_map = None
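For reference, a sketch of the smallest adapter_config.json the updated loader accepts; the values are illustrative, and only the three checked keys are required.

import json

# Illustrative adapter config; the loader above only keeps "r", "lora_alpha",
# and "target_modules" and raises if any of them is missing.
adapter_config_full = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    "peft_type": "LORA",
}

with open("adapter_config.json", "w") as f:
    json.dump(adapter_config_full, f, indent=2)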

examples/models/llama/model_args.py

Lines changed: 2 additions & 1 deletion
@@ -106,7 +106,8 @@ class ModelArgs:
     # These arguments come directly from a torchtune adapter_config.json file.
     r: Optional[int] = None  # Rank.
     lora_alpha: Optional[int] = None  # Alpha.
-    # Eg. q_proj, k_proj, v_proj, output_proj
+    # Modules that we can apply lora adapters to.
+    # Eg. q_proj, k_proj, v_proj, output_proj/o_proj, down_proj, gate_proj, up_proj
     target_modules: Optional[list] = None
     peft_type: Optional[str] = None  # PEFT type.
     base_model_name_or_path: Optional[str] = None  # Base model name or path.
