Commit

[feature] support mlora v0.5.0 (#12)
mikecovlee authored Aug 9, 2024
1 parent 88ff7cd commit 278c750
Showing 5 changed files with 21 additions and 19 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -137,13 +137,13 @@ Compared with LoRA, MixLoRA have some additional configurations.
```
This is an example of LoRA training configuration.

- MixLoRA have two routing strategies: top-k routing (like *Mixtral*) and top-1 switch routing (like *Switch Transformers*), can be configured with `"routing_strategy": "mixtral"` or `"routing_strategy": "switch"`.
+ MixLoRA have two routing strategies: top-k routing (like *Mixtral*) and top-1 switch routing (like *Switch Transformers*), can be configured with `"routing_strategy": "mixlora"` or `"routing_strategy": "mixlora-switch"`.

**Top-k Routing**
```json
{
...
"routing_strategy": "mixtral",
"routing_strategy": "mixlora",
"router_init_range": 0.02,
"num_experts": 8,
"top_k": 2,
Expand All @@ -157,7 +157,7 @@ MixLoRA have two routing strategies: top-k routing (like *Mixtral*) and top-1 sw
```json
{
...
"routing_strategy": "switch",
"routing_strategy": "mixlora-switch",
"router_init_range": 0.02,
"num_experts": 8,
"expert_capacity": 32,
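For quick reference, here is a small Python sketch (not part of the diff) collecting the routing fields from the two README examples above under their renamed strategies. Field names and values are copied from the hunks; training fields not shown there are omitted. Note that `mixlora/config.py` in this commit lists only `"mixlora"` in `available_routing_strategies`, so `"mixlora-switch"` appears to be consumed on the mLoRA trainer side rather than by this package.

```python
# Routing blocks as documented after this commit; values copied from the README hunks.
top_k_routing = {
    "routing_strategy": "mixlora",  # formerly "mixtral"
    "router_init_range": 0.02,
    "num_experts": 8,
    "top_k": 2,
}

top_1_switch_routing = {
    "routing_strategy": "mixlora-switch",  # formerly "switch"
    "router_init_range": 0.02,
    "num_experts": 8,
    "expert_capacity": 32,
}
```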
20 changes: 9 additions & 11 deletions mixlora/config.py
@@ -1,6 +1,6 @@
import copy
from dataclasses import dataclass
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, Union

import torch
from transformers.activations import ACT2FN
@@ -145,7 +145,7 @@ def export(self) -> Dict[str, any]:
return config


- available_routing_strategies = ["mixtral"]
+ available_routing_strategies = ["mixlora"]


@dataclass
@@ -159,8 +159,7 @@ class MixLoraConfig(LoraConfig):
jitter_noise_: float = None
router_loss_: bool = True
num_experts_: int = None
- act_fn_: Optional[str] = None
- # mixtral config
+ act_fn_: Optional[Union[str, torch.nn.Module]] = None
top_k_: int = None

def check(self) -> "MixLoraConfig":
@@ -184,7 +183,7 @@ def check(self) -> "MixLoraConfig":
assert self.act_fn_ is None or (
isinstance(self.act_fn_, str) and self.act_fn_ in ACT2FN
)
- if self.routing_strategy_ == "mixtral":
+ if self.routing_strategy_ == "mixlora":
assert isinstance(self.top_k_, int) and self.top_k_ > 0
else:
raise NotImplementedError()
@@ -198,8 +197,8 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
assert (
lora_config.peft_type_ == "MIXLORA"
and lora_config.routing_strategy_ is not None
- and lora_config.routing_strategy_ == "mixtral"
- ), "MixLoraConfig only supports MixLoRA models with 'mixtral' routing_strategy."
+ and lora_config.routing_strategy_ == "mixlora"
+ ), "MixLoraConfig only supports MixLoRA models with 'mixlora' routing_strategy."
if "expert_lora" in config:
expert_config = copy.deepcopy(config)
expert_config.update(config["expert_lora"])
@@ -209,10 +208,9 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
) # for training
lora_config.router_loss_ = config.get("router_loss", True)
lora_config.num_experts_ = config["num_experts"]
- # silu for mixtral or gelu_new for switch transformers
# left blank to automatically use the original act_fn of FFN
lora_config.act_fn_ = config.get("act_fn", None)
- if lora_config.routing_strategy_ == "mixtral":
+ if lora_config.routing_strategy_ == "mixlora":
lora_config.router_init_range_ = config.get("router_init_range", 0.02)
lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
lora_config.top_k_ = config.get("top_k", 2)
@@ -231,9 +229,9 @@ def export(self) -> Dict[str, any]:
config["expert_lora"] = expert_config
config["routing_strategy"] = self.routing_strategy_
config["num_experts"] = self.num_experts_
- if self.act_fn_ is not None:
+ if self.act_fn_ is not None and isinstance(self.act_fn_, str):
config["act_fn"] = self.act_fn_
- if self.routing_strategy_ == "mixtral":
+ if self.routing_strategy_ == "mixlora":
config["top_k"] = self.top_k_
else:
raise NotImplementedError()
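A minimal usage sketch of the updated config path, not part of this commit: it assumes `MixLoraConfig` is importable from `mixlora.config`, that `from_config` is exposed as a static method as its signature suggests, and that the LoRA rank key is named `"r"`; the remaining keys mirror the fields visible in this file and in `tests/test_moe_layer.py`.

```python
from mixlora.config import MixLoraConfig  # assumed import path

config_dict = {
    "peft_type": "MIXLORA",         # required by the assert in from_config
    "r": 8,                         # assumed name of the LoRA rank key
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "target_modules": [],           # as in the dummy test config
    "routing_strategy": "mixlora",  # was "mixtral" before this commit
    "num_experts": 8,
    "top_k": 2,
    "act_fn": "silu",               # optional; omit to reuse the FFN's own activation
}

config = MixLoraConfig.from_config(config_dict).check()

exported = config.export()
assert exported["routing_strategy"] == "mixlora"
assert exported["act_fn"] == "silu"  # only exported when act_fn_ is a string
```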
10 changes: 7 additions & 3 deletions mixlora/model.py
@@ -51,7 +51,11 @@ def __init__(
self.gate_: torch.Tensor = None
self.base_layer_: torch.nn.Module = base_layer
self.experts_: Dict[str, LoraLinear] = {}
- self.act_fn_ = ACT2FN[config.act_fn_]
+ self.act_fn_ = (
+     ACT2FN[config.act_fn_]
+     if isinstance(config.act_fn_, str)
+     else config.act_fn_
+ )
self.num_experts_: int = config.num_experts_
self.topk_: int = config.top_k_
self.jitter_noise_: float = config.jitter_noise_
@@ -288,7 +292,7 @@ def _inject_mlp_module(
weights: Dict[str, torch.Tensor],
):
moe_layer = MixLoraSparseMoe(mlp, config)
- moe_layer.gate_ = weights[f"mixlora.layers.{layer_idx}.gate.weight"].to(
+ moe_layer.gate_ = weights[f"mixlora.layers.{layer_idx}.mlp.moe_gate.weight"].to(
config.dtype_
)

@@ -304,7 +308,7 @@ def _inject_mlp_module(
base_layer = getattr(mlp, proj_name)
for expert_idx in range(config.num_experts_):
layer_prefix_name = (
f"mixlora.layers.{layer_idx}.experts.{expert_idx}.{proj_name}"
f"mixlora.layers.{layer_idx}.mlp.{proj_name}.experts.{expert_idx}"
)
moe_layer.experts_[f"experts.{expert_idx}.{proj_name}"] = LoraLinear(
base_layer,
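The renamed weight keys above change which tensors `_inject_mlp_module` looks up in the adapter state dict. Below is a short sketch of the new naming scheme; the layer index, expert count, and projection names are illustrative, and the per-expert strings are prefixes (the `LoraLinear` weights presumably hang off them), not complete tensor names.

```python
# Gate router: "...layers.{i}.gate.weight" became "...layers.{i}.mlp.moe_gate.weight".
# Expert LoRA: "...layers.{i}.experts.{e}.{proj}" became "...layers.{i}.mlp.{proj}.experts.{e}".
def mixlora_weight_names(layer_idx: int, num_experts: int,
                         proj_names=("gate_proj", "up_proj", "down_proj")):
    gate_key = f"mixlora.layers.{layer_idx}.mlp.moe_gate.weight"
    expert_prefixes = [
        f"mixlora.layers.{layer_idx}.mlp.{proj_name}.experts.{expert_idx}"
        for proj_name in proj_names
        for expert_idx in range(num_experts)
    ]
    return gate_key, expert_prefixes

gate_key, prefixes = mixlora_weight_names(layer_idx=0, num_experts=2)
print(gate_key)     # mixlora.layers.0.mlp.moe_gate.weight
print(prefixes[0])  # mixlora.layers.0.mlp.gate_proj.experts.0
```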
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mixlora"
version = "0.2.0"
version = "0.2.1"
description = "State-of-the-art Parameter-Efficient MoE Fine-tuning Method"
readme = "README.md"
requires-python = ">=3.8"
2 changes: 1 addition & 1 deletion tests/test_moe_layer.py
@@ -23,7 +23,7 @@ def dummy_moe_layer(
"lora_alpha": 16,
"lora_dropout": 0.05,
"target_modules": [],
"routing_strategy": "mixtral",
"routing_strategy": "mixlora",
"num_experts": 8,
"act_fn": "silu",
"top_k": 2,
