fix dependency, add legacy weights convertor (#13)
mikecovlee authored Aug 14, 2024
1 parent 278c750 commit 89fa9d8
Showing 3 changed files with 100 additions and 4 deletions.
94 changes: 94 additions & 0 deletions misc/legacy_weights_convertor.py
@@ -0,0 +1,94 @@
import json
import os
from typing import Dict, List, Optional

import fire
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM

import mixlora

# Map legacy Mixtral-style projection names (w1/w2/w3) to the current names.
legacy_proj_names = {
    "w1_proj": "gate_proj",
    "w2_proj": "down_proj",
    "w3_proj": "up_proj",
}

# Reverse mapping (current -> legacy), used as a fallback when expert weights
# were stored under the old w1/w2/w3 names.
modern_proj_names = {
    "gate_proj": "w1_proj",
    "down_proj": "w2_proj",
    "up_proj": "w3_proj",
}


def from_legacy(name_or_path: str, output_dir: Optional[str] = None):
    # Treat a non-existent local path as a Hugging Face Hub repo id.
    if not os.path.exists(name_or_path):
        assert output_dir is not None
        name_or_path = snapshot_download(repo_id=name_or_path, repo_type="model")

    if output_dir is None:
        output_dir = name_or_path

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Rewrite the legacy config: rename the routing strategy and keep only
    # supported target modules, translating legacy projection names on the way.
    with open(
        name_or_path + os.sep + "adapter_config.json", "r", encoding="utf8"
    ) as fp:
        config = json.load(fp)
        assert "routing_strategy" in config and config["routing_strategy"] == "mixtral"
        config["routing_strategy"] = "mixlora"
        target_modules: List[str] = []
        assert isinstance(config["target_modules"], List)
        for target in config["target_modules"]:
            if target in legacy_proj_names:
                target = legacy_proj_names[target]
            if target in mixlora.config.lora_target_modules:
                target_modules.append(target)
        config["target_modules"] = target_modules
        config = mixlora.MixLoraConfig.from_config(config)

    config.check()

    weights: Dict[str, torch.Tensor] = torch.load(
        name_or_path + os.sep + "adapter_model.bin", map_location="cpu"
    )

    # The base model is only needed to enumerate decoder layers and to verify
    # which MLP projections actually exist on each layer.
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_,
        torch_dtype=torch.float16,
        device_map="cpu",
    )

    # Rename weight keys into the new layout:
    #   mixlora.layers.{i}.gate.weight          -> mixlora.layers.{i}.mlp.moe_gate.weight
    #   mixlora.layers.{i}.experts.{j}.{proj}.* -> mixlora.layers.{i}.mlp.{proj}.experts.{j}.*
    for layer_idx, layer in enumerate(model.model.layers):
        weights[f"mixlora.layers.{layer_idx}.mlp.moe_gate.weight"] = weights.pop(
            f"mixlora.layers.{layer_idx}.gate.weight"
        )
        for proj_name, inject in config.target_modules_.items():
            if not inject or not hasattr(layer.mlp, proj_name):
                continue
            for expert_idx in range(config.num_experts_):
                new_layer_prefix_name = (
                    f"mixlora.layers.{layer_idx}.mlp.{proj_name}.experts.{expert_idx}"
                )
                old_layer_prefix_name = (
                    f"mixlora.layers.{layer_idx}.experts.{expert_idx}.{proj_name}"
                )
                # Fall back to the legacy w1/w2/w3 key names if needed.
                if f"{old_layer_prefix_name}.lora_A.weight" not in weights:
                    old_layer_prefix_name = f"mixlora.layers.{layer_idx}.experts.{expert_idx}.{modern_proj_names[proj_name]}"
                weights[f"{new_layer_prefix_name}.lora_A.weight"] = weights.pop(
                    f"{old_layer_prefix_name}.lora_A.weight"
                )
                weights[f"{new_layer_prefix_name}.lora_B.weight"] = weights.pop(
                    f"{old_layer_prefix_name}.lora_B.weight"
                )

    torch.save(weights, output_dir + os.sep + "adapter_model.bin")

    with open(output_dir + os.sep + "adapter_config.json", "w") as f:
        json.dump(config.export(), f, indent=4)


if __name__ == "__main__":
    fire.Fire(from_legacy)
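
For reference, a minimal usage sketch of the converter (the paths and repo id below are placeholders, not real checkpoints). Because the script wraps from_legacy in fire.Fire, the same arguments also work as CLI flags, e.g. python misc/legacy_weights_convertor.py --name_or_path=<adapter> --output_dir=<out>:

# Usage sketch only; paths and repo id are placeholders.
# Assumes the script's directory (misc/) is on sys.path so the module imports.
from legacy_weights_convertor import from_legacy

# Local legacy adapter; output_dir defaults to the input directory.
from_legacy(name_or_path="/path/to/legacy_adapter")

# Hub repo id: the snapshot is downloaded first, so output_dir must be given.
from_legacy(name_or_path="user/legacy-mixlora-adapter", output_dir="/path/to/out")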
4 changes: 3 additions & 1 deletion mixlora/model.py
@@ -351,7 +351,9 @@ def load_adapter_weights(
     config.adapter_name_ = adapter_name
     config.dtype_ = dtype
 
+    config.check()
+
-    weights = torch.load(
+    weights: Dict[str, torch.Tensor] = torch.load(
         name_or_path + os.sep + "adapter_model.bin", map_location=device
     )

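Since load_adapter_weights is what ultimately consumes the converted adapter_model.bin, here is a hedged sanity check that a converted checkpoint uses the new key layout (the path, layer index, and projection name are placeholders and assume gate_proj was a target module):

# Sanity-check sketch; placeholders throughout.
import torch

weights = torch.load("/path/to/out/adapter_model.bin", map_location="cpu")
assert "mixlora.layers.0.mlp.moe_gate.weight" in weights
assert "mixlora.layers.0.mlp.gate_proj.experts.0.lora_A.weight" in weights
# Legacy keys nest experts above the projection and lack the ".mlp." segment.
assert not any(".experts." in k and ".mlp." not in k for k in weights)
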
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mixlora"
-version = "0.2.1"
+version = "0.2.2"
 description = "State-of-the-art Parameter-Efficient MoE Fine-tuning Method"
 readme = "README.md"
 requires-python = ">=3.8"
@@ -14,8 +14,8 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "torch>=2.3.0,<2.4.0",
-    "transformers>=4.43.0,<4.44.0",
+    "torch>=2.2.0",
+    "transformers>=4.43.0",
     "huggingface_hub",
 ]

