
Commit 777dbd2

add lora for mlp and unsloth

1 parent 45336ce commit 777dbd2

File tree (6 files changed, +131 -9 lines changed):

examples/models/llama/convert_weights.py
examples/models/llama/feed_forward.py
examples/models/llama/install_requirements.sh
examples/models/llama/llama_transformer.py
examples/models/llama/model.py
examples/models/llama/model_args.py
examples/models/llama/convert_weights.py

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+from typing import Dict
+
+import torch
+
+from torchtune.models.convert_weights import get_mapped_key
+
+_UNSLOTH_TO_META = {
+    "base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight",
+    "base_model.model.model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight",
+    "base_model.model.model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight",
+}
+
+
+def unsloth_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from unsloth format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in unsloth format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+
+    for key, value in state_dict.items():
+        try:
+            new_key = get_mapped_key(key, _UNSLOTH_TO_META)
+        except:
+            raise ValueError(f"Key {key} not found in mapping")
+
+        converted_state_dict[new_key] = value
+    return converted_state_dict
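For orientation, a minimal usage sketch of the converter above: the adapter keys and tensor shapes are made up for illustration, and the import path follows the examples/models/llama/convert_weights.py location implied by the model.py change below.

import torch

from executorch.examples.models.llama.convert_weights import unsloth_to_meta

# Hypothetical unsloth/PEFT-style entries for layer 0 with a rank-8 LoRA on q_proj.
fake_adapter = {
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(8, 4096),
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(4096, 8),
}

converted = unsloth_to_meta(fake_adapter)
# Keys come out in Meta format, e.g. "layers.0.attention.wq.lora_a.weight".
print(sorted(converted.keys()))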

examples/models/llama/feed_forward.py

Lines changed: 50 additions & 0 deletions

@@ -1,4 +1,7 @@
 import torch.nn.functional as F
+
+from executorch.examples.models.llama.lora import LoRALinear
+from executorch.examples.models.llama.model_args import ModelArgs
 from torch import nn
 
 
@@ -11,3 +14,50 @@ def __init__(self, dim: int, hidden_dim: int):
 
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class LoRAFeedForward(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, args: ModelArgs):
+        super().__init__()
+
+        self.w1 = (
+            LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "gate_proj" in args.target_modules
+            else nn.Linear(dim, hidden_dim, bias=False)
+        )
+
+        self.w2 = (
+            LoRALinear(
+                in_dim=hidden_dim,
+                out_dim=dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "down_proj" in args.target_modules
+            else nn.Linear(hidden_dim, dim, bias=False)
+        )
+
+        self.w3 = (
+            LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "up_proj" in args.target_modules
+            else nn.Linear(dim, hidden_dim, bias=False)
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
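A hedged sketch of constructing LoRAFeedForward directly. It assumes ModelArgs can be built with keyword arguments for dim, hidden_dim, r, lora_alpha, and target_modules (the last three are the adapter fields shown in model_args.py below); the sizes are illustrative.

import torch

from executorch.examples.models.llama.feed_forward import LoRAFeedForward
from executorch.examples.models.llama.model_args import ModelArgs

# Only gate_proj and up_proj are LoRA-adapted here; down_proj stays a plain nn.Linear.
args = ModelArgs(
    dim=64,
    hidden_dim=256,
    r=8,
    lora_alpha=16,
    target_modules=["gate_proj", "up_proj"],
)

ff = LoRAFeedForward(dim=args.dim, hidden_dim=args.hidden_dim, args=args)
out = ff(torch.randn(2, 4, 64))  # output keeps the input shape: (batch, seq, dim)
print(out.shape)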

examples/models/llama/install_requirements.sh

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,8 @@
 # Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
 # Install lm-eval for Model Evaluation with lm-evalution-harness.
-pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+# Install safetensors to load safetensors checkpoints (currently adapter only).
+pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile safetensors
 
 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
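As a quick sanity check that the new safetensors dependency works, an adapter file can be loaded directly; the filename below is the usual PEFT/unsloth default and is only an assumption here.

from safetensors.torch import load_file

# "adapter_model.safetensors" is the name PEFT/unsloth typically writes; adjust to your export.
adapter = load_file("adapter_model.safetensors")
print(f"loaded {len(adapter)} adapter tensors")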

examples/models/llama/llama_transformer.py

Lines changed: 7 additions & 1 deletion

@@ -18,7 +18,7 @@
     AttentionSkip,
     ForwardOptions,
 )
-from executorch.examples.models.llama.feed_forward import FeedForward
+from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -93,6 +93,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
         ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
+        elif args.target_modules is not None and (
+            "down_proj" in args.target_modules
+            or "up_proj" in args.target_modules
+            or "gate_proj" in args.target_modules
+        ):
+            self.feed_forward = LoRAFeedForward(args.dim, args.hidden_dim, args)
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
 
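The MLP dispatch added above reduces to a simple membership test on target_modules. A small sketch with a hypothetical helper name, mirroring the branch in TransformerBlock:

from typing import List, Optional

_LORA_MLP_TARGETS = {"down_proj", "up_proj", "gate_proj"}


def uses_lora_mlp(target_modules: Optional[List[str]]) -> bool:
    # True when any MLP projection is listed, i.e. when LoRAFeedForward is constructed.
    return target_modules is not None and bool(_LORA_MLP_TARGETS & set(target_modules))


print(uses_lora_mlp(["q_proj", "v_proj"]))                   # False -> plain FeedForward
print(uses_lora_mlp(["gate_proj", "up_proj", "down_proj"]))  # True  -> LoRAFeedForward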

examples/models/llama/model.py

Lines changed: 25 additions & 6 deletions

@@ -15,7 +15,10 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
+
+from executorch.examples.models.llama.convert_weights import unsloth_to_meta
 from executorch.examples.models.llama.llama_transformer import construct_transformer
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 
@@ -140,14 +143,30 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         adapter_checkpoint = {}
         adapter_config = {}
         if adapter_checkpoint_path:
-            adapter_checkpoint = torch.load(
-                adapter_checkpoint_path, map_location=device, mmap=True
-            )
-            from torchtune.models import convert_weights
+            if adapter_checkpoint_path.endswith(".pt"):
+                adapter_checkpoint = torch.load(
+                    adapter_checkpoint_path, map_location=device, mmap=True
+                )
+                from torchtune.models import convert_weights
+
+                adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            elif adapter_checkpoint_path.endswith(".safetensors"):
+                from safetensors.torch import load_file
+
+                adapter_checkpoint = load_file(adapter_checkpoint_path)
+                adapter_checkpoint = unsloth_to_meta(adapter_checkpoint)
+            else:
+                raise ValueError(
+                    f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
+                )
 
-            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
             with open(adapter_config_path, "r") as f:
-                adapter_config = json.loads(f.read())
+                adapter_config_full = json.loads(f.read())
+                adapter_config = {
+                    "r": adapter_config_full["r"],
+                    "lora_alpha": adapter_config_full["lora_alpha"],
+                    "target_modules": adapter_config_full["target_modules"],
+                }
             checkpoint.update(adapter_checkpoint)
 
         output_prune_map = None
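For context, an unsloth/PEFT adapter_config.json usually carries more fields than the three the loader keeps; the values below are illustrative, not taken from this commit. A sketch of the same subsetting:

import json

# Illustrative adapter_config.json contents; a real export has additional fields.
example_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    "peft_type": "LORA",
    "base_model_name_or_path": "some/base-model",  # placeholder
}

# Keep only what ModelArgs consumes, mirroring the model.py change above.
adapter_config = {k: example_config[k] for k in ("r", "lora_alpha", "target_modules")}
print(json.dumps(adapter_config, indent=2))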

examples/models/llama/model_args.py

Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ class ModelArgs:
     # These arguments come directly from a torchtune adapter_config.json file.
     r: Optional[int] = None  # Rank.
     lora_alpha: Optional[int] = None  # Alpha.
-    # Eg. q_proj, k_proj, v_proj, output_proj
+    # Eg. q_proj, k_proj, v_proj, output_proj/o_proj, down_proj, gate_proj, up_proj
     target_modules: Optional[list] = None
     peft_type: Optional[str] = None  # PEFT type.
     base_model_name_or_path: Optional[str] = None  # Base model name or path.
