diff --git a/README.md b/README.md
index 36226ad..e454a56 100644
--- a/README.md
+++ b/README.md
@@ -137,13 +137,13 @@ Compared with LoRA, MixLoRA have some additional configurations.
 ```
 This is an example of LoRA training configuration.
 
-MixLoRA have two routing strategies: top-k routing (like *Mixtral*) and top-1 switch routing (like *Switch Transformers*), can be configured with `"routing_strategy": "mixtral"` or `"routing_strategy": "switch"`.
+MixLoRA have two routing strategies: top-k routing (like *Mixtral*) and top-1 switch routing (like *Switch Transformers*), can be configured with `"routing_strategy": "mixlora"` or `"routing_strategy": "mixlora-switch"`.
 
 **Top-k Routing**
 ```json
 {
   ...
-  "routing_strategy": "mixtral",
+  "routing_strategy": "mixlora",
   "router_init_range": 0.02,
   "num_experts": 8,
   "top_k": 2,
@@ -157,7 +157,7 @@ MixLoRA have two routing strategies: top-k routing (like *Mixtral*) and top-1 sw
 ```json
 {
   ...
-  "routing_strategy": "switch",
+  "routing_strategy": "mixlora-switch",
   "router_init_range": 0.02,
   "num_experts": 8,
   "expert_capacity": 32,
diff --git a/mixlora/config.py b/mixlora/config.py
index 598e254..c2c6e06 100644
--- a/mixlora/config.py
+++ b/mixlora/config.py
@@ -1,6 +1,6 @@
 import copy
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 import torch
 from transformers.activations import ACT2FN
@@ -145,7 +145,7 @@ def export(self) -> Dict[str, any]:
         return config
 
 
-available_routing_strategies = ["mixtral"]
+available_routing_strategies = ["mixlora"]
 
 
 @dataclass
@@ -159,8 +159,7 @@ class MixLoraConfig(LoraConfig):
     jitter_noise_: float = None
     router_loss_: bool = True
    num_experts_: int = None
-    act_fn_: Optional[str] = None
-    # mixtral config
+    act_fn_: Optional[Union[str, torch.nn.Module]] = None
     top_k_: int = None
 
     def check(self) -> "MixLoraConfig":
@@ -184,7 +183,7 @@ def check(self) -> "MixLoraConfig":
         assert self.act_fn_ is None or (
             isinstance(self.act_fn_, str) and self.act_fn_ in ACT2FN
         )
-        if self.routing_strategy_ == "mixtral":
+        if self.routing_strategy_ == "mixlora":
             assert isinstance(self.top_k_, int) and self.top_k_ > 0
         else:
             raise NotImplementedError()
@@ -198,8 +197,8 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
         assert (
             lora_config.peft_type_ == "MIXLORA"
             and lora_config.routing_strategy_ is not None
-            and lora_config.routing_strategy_ == "mixtral"
-        ), "MixLoraConfig only supports MixLoRA models with 'mixtral' routing_strategy."
+            and lora_config.routing_strategy_ == "mixlora"
+        ), "MixLoraConfig only supports MixLoRA models with 'mixlora' routing_strategy."
         if "expert_lora" in config:
             expert_config = copy.deepcopy(config)
             expert_config.update(config["expert_lora"])
@@ -209,10 +208,9 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
         )  # for training
         lora_config.router_loss_ = config.get("router_loss", True)
         lora_config.num_experts_ = config["num_experts"]
-        # silu for mixtral or gelu_new for switch transformers
         # left blank to automatically use the original act_fn of FFN
         lora_config.act_fn_ = config.get("act_fn", None)
-        if lora_config.routing_strategy_ == "mixtral":
+        if lora_config.routing_strategy_ == "mixlora":
             lora_config.router_init_range_ = config.get("router_init_range", 0.02)
             lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
             lora_config.top_k_ = config.get("top_k", 2)
@@ -231,9 +229,9 @@ def export(self) -> Dict[str, any]:
             config["expert_lora"] = expert_config
         config["routing_strategy"] = self.routing_strategy_
         config["num_experts"] = self.num_experts_
-        if self.act_fn_ is not None:
+        if self.act_fn_ is not None and isinstance(self.act_fn_, str):
             config["act_fn"] = self.act_fn_
-        if self.routing_strategy_ == "mixtral":
+        if self.routing_strategy_ == "mixlora":
             config["top_k"] = self.top_k_
         else:
             raise NotImplementedError()
diff --git a/mixlora/model.py b/mixlora/model.py
index 1923ef3..6a9c394 100644
--- a/mixlora/model.py
+++ b/mixlora/model.py
@@ -51,7 +51,11 @@ def __init__(
         self.gate_: torch.Tensor = None
         self.base_layer_: torch.nn.Module = base_layer
         self.experts_: Dict[str, LoraLinear] = {}
-        self.act_fn_ = ACT2FN[config.act_fn_]
+        self.act_fn_ = (
+            ACT2FN[config.act_fn_]
+            if isinstance(config.act_fn_, str)
+            else config.act_fn_
+        )
         self.num_experts_: int = config.num_experts_
         self.topk_: int = config.top_k_
         self.jitter_noise_: float = config.jitter_noise_
@@ -288,7 +292,7 @@ def _inject_mlp_module(
     weights: Dict[str, torch.Tensor],
 ):
     moe_layer = MixLoraSparseMoe(mlp, config)
-    moe_layer.gate_ = weights[f"mixlora.layers.{layer_idx}.gate.weight"].to(
+    moe_layer.gate_ = weights[f"mixlora.layers.{layer_idx}.mlp.moe_gate.weight"].to(
         config.dtype_
     )
 
@@ -304,7 +308,7 @@ def _inject_mlp_module(
         base_layer = getattr(mlp, proj_name)
         for expert_idx in range(config.num_experts_):
             layer_prefix_name = (
-                f"mixlora.layers.{layer_idx}.experts.{expert_idx}.{proj_name}"
+                f"mixlora.layers.{layer_idx}.mlp.{proj_name}.experts.{expert_idx}"
             )
             moe_layer.experts_[f"experts.{expert_idx}.{proj_name}"] = LoraLinear(
                 base_layer,
diff --git a/pyproject.toml b/pyproject.toml
index 71ce089..ca38bd4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mixlora"
-version = "0.2.0"
+version = "0.2.1"
 description = "State-of-the-art Parameter-Efficient MoE Fine-tuning Method"
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/tests/test_moe_layer.py b/tests/test_moe_layer.py
index 4bcc39e..a88606e 100644
--- a/tests/test_moe_layer.py
+++ b/tests/test_moe_layer.py
@@ -23,7 +23,7 @@ def dummy_moe_layer(
         "lora_alpha": 16,
         "lora_dropout": 0.05,
         "target_modules": [],
-        "routing_strategy": "mixtral",
+        "routing_strategy": "mixlora",
         "num_experts": 8,
         "act_fn": "silu",
         "top_k": 2,