Merge pull request #5 from TUDB-Labs/autoclasses
[feature] improve auto classes integrations with transformers
mikecovlee authored Jul 30, 2024
2 parents e40cbdc + 027912c commit a8da949
Showing 9 changed files with 339 additions and 204 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/python-test.yml
@@ -0,0 +1,34 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Test on Main

on: [push, pull_request]

permissions:
  contents: read

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v3
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip3 install -e .
          pip3 install black isort flake8 pytest
      - name: Code lint
        run: |
          black mixlora
          isort mixlora --profile black
          flake8 mixlora --show-source --statistics --max-line-length=128 --max-complexity 15 --ignore=E203,W503,E722
      - name: Run tests
        run: |
          pytest
11 changes: 8 additions & 3 deletions mixlora/__init__.py
@@ -1,5 +1,9 @@
from .config import MixLoraConfig
from .model import MixLoraModel, MixLoraSparseMoe
from .model import (
MixLoraModelForCausalLM,
inject_adapter_in_model,
load_adapter_weights,
)
from .prompter import Prompter
from .utils import is_package_available

@@ -10,7 +14,8 @@

__all__ = [
"MixLoraConfig",
"MixLoraModel",
"MixLoraSparseMoe",
"MixLoraModelForCausalLM",
"inject_adapter_in_model",
"load_adapter_weights",
"Prompter",
]
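
For orientation, the reworked public surface can be exercised with a plain import. The sketch below uses only the names listed in the updated `__all__`; the call signatures of the new helpers are not visible in this hunk, so only the imports are shown.

```python
# Hypothetical usage of the reworked exports (names taken from the new __all__);
# the helpers' signatures are not part of this diff, so nothing is called here.
from mixlora import (
    MixLoraConfig,
    MixLoraModelForCausalLM,
    Prompter,
    inject_adapter_in_model,
    load_adapter_weights,
)
```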
53 changes: 42 additions & 11 deletions mixlora/config.py
@@ -8,11 +8,45 @@

@dataclass
class AdapterConfig:
base_model_: str = None
task_type_: str = None
peft_type_: str = None
adapter_name_: str = None
hidden_size_: int = None
model_type_: str = None
dtype_: torch.dtype = None

@property
def base_model_name_or_path(self):
return self.base_model_

@property
def adapter_name(self):
return self.adapter_name_

def check(self) -> "AdapterConfig":
assert isinstance(self.base_model_, str)
assert isinstance(self.task_type_, str)
assert isinstance(self.peft_type_, str)

return self

@staticmethod
def from_config(config: Dict[str, any]) -> "AdapterConfig":
return AdapterConfig(
base_model_=config["base_model_name_or_path"],
task_type_=config["task_type"],
peft_type_=config["peft_type"],
)

def export(self) -> Dict[str, any]:
config = {}
config["bias"] = "none"
config["peft_type"] = self.peft_type_
config["task_type"] = self.task_type_
config["base_model_name_or_path"] = self.base_model_

return config


lora_target_modules = {
# LLaMA names
@@ -53,6 +87,7 @@ class LoraConfig(AdapterConfig):
target_modules_: Dict[str, bool] = None

def check(self) -> "LoraConfig":
super().check()
assert isinstance(self.use_dora_, bool)
assert isinstance(self.use_rslora_, bool)
assert isinstance(self.lora_init_, str) and self.lora_init_ in [
@@ -71,7 +106,7 @@ def check(self) -> "LoraConfig":

@staticmethod
def from_config(config: Dict[str, any]) -> "LoraConfig":
lora_config = LoraConfig()
lora_config = LoraConfig(**AdapterConfig.from_config(config).__dict__)
lora_config.use_dora_ = config.get("use_dora", False)
lora_config.use_rslora_ = config.get("use_rslora", False)
lora_config.lora_init_ = config.get("lora_init", "original")
@@ -93,13 +128,11 @@ def from_config(config: Dict[str, any]) -> "LoraConfig":
return lora_config

def export(self) -> Dict[str, any]:
config = {}
config = super().export()
if self.use_dora_:
config["use_dora"] = True
if self.use_rslora_:
config["use_rslora"] = True
config["bias"] = "none"
config["peft_type"] = "LORA"
config["r"] = self.lora_r_
config["lora_alpha"] = self.lora_alpha_
config["lora_dropout"] = self.lora_dropout_
@@ -161,11 +194,11 @@ def check(self) -> "MixLoraConfig":
@staticmethod
def from_config(config: Dict[str, any]) -> "MixLoraConfig":
lora_config = MixLoraConfig(**LoraConfig.from_config(config).__dict__)
lora_config.routing_strategy_ = config.get("routing_strategy", None)
assert (
"peft_type" in config
and config["peft_type"] == "MIXLORA"
and "routing_strategy" in config
and config["routing_strategy"] == "mixtral"
lora_config.peft_type_ == "MIXLORA"
and lora_config.routing_strategy_ is not None
and lora_config.routing_strategy_ == "mixtral"
), "MixLoraConfig only supports MixLoRA models with 'mixtral' routing_strategy."
if "expert_lora" in config:
expert_config = copy.deepcopy(config)
@@ -174,7 +207,6 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
lora_config.router_aux_loss_coef_ = config.get(
"router_aux_loss_coef", 0.001
) # for training
lora_config.routing_strategy_ = config["routing_strategy"]
lora_config.router_loss_ = config.get("router_loss", True)
lora_config.num_experts_ = config["num_experts"]
# silu for mixtral or gelu_new for switch transformers
@@ -213,5 +245,4 @@ def expert_config(self, expert_idx: int) -> LoraConfig:
config = copy.deepcopy(super())
else:
config = copy.deepcopy(self.expert_config_)
config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
return config
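
Taken together, the new `AdapterConfig` base class gives every adapter config a PEFT-style `from_config`/`export` round trip plus a `check()` validator that the subclasses now call via `super().check()`. A minimal round-trip sketch is shown below; the dictionary values are illustrative placeholders, not values taken from this diff.

```python
from mixlora.config import AdapterConfig

# Round-trip sketch for the AdapterConfig helpers added above.
# The values are placeholders (assumptions); only the key names come from the diff.
raw = {
    "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "task_type": "CAUSAL_LM",
    "peft_type": "MIXLORA",
}

cfg = AdapterConfig.from_config(raw).check()  # check() validates the three string fields
exported = cfg.export()                       # adds "bias": "none" and echoes the rest back

assert cfg.base_model_name_or_path == raw["base_model_name_or_path"]  # property accessor
assert exported["peft_type"] == raw["peft_type"]
print(exported)
```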
126 changes: 44 additions & 82 deletions mixlora/adapter.py → mixlora/lora_linear.py
@@ -2,17 +2,17 @@

import torch
import torch.nn as nn
from transformers.utils import is_bitsandbytes_available

from .config import LoraConfig
from .utils import is_package_available

if is_bitsandbytes_available():
if is_package_available("bitsandbytes"):
import bitsandbytes as bnb
from bitsandbytes.nn import Linear4bit, Linear8bitLt
else:
from .utils import Linear8bitLt, Linear4bit

from typing import Dict, Optional, Tuple
from typing import Tuple


def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
@@ -73,35 +73,31 @@ def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter:
return weight


g_cached_range_tensor: Dict[torch.device, torch.Tensor] = {}
# also max batch size
g_max_range = 128


def get_range_tensor(device: torch.device, batch_size: int = 1024):
global g_cached_range_tensor
global g_max_range
if device not in g_cached_range_tensor or batch_size > g_max_range:
g_max_range = g_max_range if g_max_range > batch_size else batch_size
g_cached_range_tensor[device] = torch.arange(
0, g_max_range, step=1, device=device
)
return g_cached_range_tensor[device]


class Lora(nn.Module):
class LoraLinear(nn.Module):
def __init__(
self,
base_layer: nn.Module,
shape: Tuple[int, int],
config: LoraConfig,
device: str,
weight: Tuple[torch.Tensor, torch.Tensor] = (None, None),
device: str = None,
):

super().__init__()

if not isinstance(base_layer, nn.Linear):
assert isinstance(base_layer, Linear8bitLt) or isinstance(
base_layer, Linear4bit
), f"Unsupported base layer type '{type(base_layer)}'."

if isinstance(base_layer, Linear4bit):
out_dim, in_dim = (
base_layer.out_features,
base_layer.in_features,
)
else:
out_dim, in_dim = base_layer.weight.shape

self.base_layer_ = base_layer
self.device_ = torch.device(device)
self.device_ = torch.device(device) if device else base_layer.weight.device
self.dtype_ = config.dtype_

self.initializer_ = config.lora_init_
@@ -113,19 +109,20 @@ def __init__(
else:
self.scaling_ = self.alpha_ / self.r_

self.in_features_, self.out_features_ = shape
self.in_features_ = in_dim
self.out_features_ = out_dim

assert config.lora_dropout_ > 0.0
self.dropout_ = nn.Dropout(p=config.lora_dropout_)

self.lora_a_ = nn.Linear(
self.lora_A = nn.Linear(
self.in_features_,
self.r_,
bias=False,
dtype=self.dtype_,
device=self.device_,
)
self.lora_b_ = nn.Linear(
self.lora_B = nn.Linear(
self.r_,
self.out_features_,
bias=False,
@@ -136,35 +133,38 @@ def __init__(
self.use_dora_: bool = config.use_dora_
self.magnitude_vector_: nn.Parameter = None

self.reset_parameters(weight)

def _get_weight_norm(self) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = dequantize_module_weight(self.base_layer_).to(self.dtype_)
lora_weight = self.lora_b_.weight @ self.lora_a_.weight
lora_weight = self.lora_B.weight @ self.lora_A.weight
weight = weight + self.scaling_ * lora_weight
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
return weight_norm

def reset_parameters(self, lora_tensor=(None, None)) -> None:
def reset_parameters(
self, weight: Tuple[torch.Tensor, torch.Tensor] = (None, None)
) -> None:
# if the lora_tensor is not (None, None), use it to init the lora weight
assert isinstance(lora_tensor, Tuple)
assert len(lora_tensor) == 2
assert ((lora_tensor[0] is None) and (lora_tensor[1] is None)) or (
isinstance(lora_tensor[0], torch.Tensor)
and isinstance(lora_tensor[1], torch.Tensor)
assert isinstance(weight, Tuple)
assert len(weight) == 2
assert ((weight[0] is None) and (weight[1] is None)) or (
isinstance(weight[0], torch.Tensor) and isinstance(weight[1], torch.Tensor)
)

if lora_tensor == (None, None):
if weight == (None, None):
if self.initializer_ == "original":
nn.init.kaiming_uniform_(self.lora_a_.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
elif self.initializer_ == "gaussian":
nn.init.normal_(self.lora_a_.weight, std=1 / self.r_)
nn.init.normal_(self.lora_A.weight, std=1 / self.r_)
else:
raise ValueError(f"Unknown initialization {self.initializer_}")
nn.init.zeros_(self.lora_b_.weight)
nn.init.zeros_(self.lora_B.weight)
else:
with torch.no_grad():
self.lora_a_.weight.copy_(lora_tensor[0])
self.lora_b_.weight.copy_(lora_tensor[1])
self.lora_A.weight.copy_(weight[0])
self.lora_B.weight.copy_(weight[1])

if self.use_dora_:
self.magnitude_vector_ = nn.Parameter(
@@ -180,56 +180,18 @@ def apply_dora(
mag_norm_scale = (self.magnitude_vector_ / weight_norm).view(1, -1)
return mag_norm_scale * residual + mag_norm_scale * result_lora

def forward(
def lora_forward(
self, residual: torch.Tensor, hidden_states: torch.Tensor
) -> torch.Tensor:
result_lora = (
self.lora_b_(self.lora_a_(self.dropout_(hidden_states.to(self.dtype_))))
self.lora_B(self.lora_A(self.dropout_(hidden_states.to(self.dtype_))))
* self.scaling_
)
if self.use_dora_:
return self.apply_dora(residual, result_lora).to(hidden_states.dtype)
else:
return residual + result_lora.to(residual.dtype)


def init_lora_weight(
base_layer: nn.Module,
lora_config: LoraConfig,
lora_tensor=(None, None),
device: Optional[str] = None,
) -> Lora:
if not isinstance(base_layer, nn.Linear):
assert isinstance(base_layer, Linear8bitLt) or isinstance(
base_layer, Linear4bit
), f"Unsupported base layer type '{type(base_layer)}'."

if isinstance(base_layer, Linear4bit):
out_dim, in_dim = (
base_layer.out_features,
base_layer.in_features,
)
else:
out_dim, in_dim = base_layer.weight.shape

lora_layer = Lora(
base_layer,
(in_dim, out_dim),
lora_config,
device if device is not None else base_layer.weight.device,
)

lora_layer.reset_parameters(lora_tensor)

return lora_layer


class Linear(nn.Module):
def __init__(self, base_layer: nn.Module, lora_layer: Lora) -> None:
super().__init__()
self.base_layer_ = base_layer
self.lora_layer_ = lora_layer

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
result = self.base_layer_(hidden_states)
return self.lora_layer_(result, hidden_states)
residual = self.base_layer_(hidden_states)
return self.lora_forward(residual, hidden_states)
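
With the shape and device now derived from the wrapped layer itself, constructing the renamed `LoraLinear` reduces to passing a base layer and a `LoraConfig`. The sketch below is a rough illustration under stated assumptions: the config field values are placeholders, `LoraConfig` is assumed to be a dataclass accepting these fields as keyword arguments, and the import path follows the rename to `mixlora/lora_linear.py`.

```python
import torch
import torch.nn as nn

from mixlora.config import LoraConfig
from mixlora.lora_linear import LoraLinear

# Rough sketch: wrap a plain nn.Linear with the renamed LoraLinear.
# Field values are illustrative assumptions; only the field names
# (lora_init_, lora_r_, lora_alpha_, lora_dropout_, use_dora_, use_rslora_, dtype_)
# appear in the diff above.
config = LoraConfig(
    lora_init_="original",
    lora_r_=8,
    lora_alpha_=16,
    lora_dropout_=0.05,   # must be > 0.0 per the assert in __init__
    use_dora_=False,
    use_rslora_=False,
    dtype_=torch.float32,
)

base = nn.Linear(128, 128)
lora = LoraLinear(base, config)  # shape and device are now inferred from base_layer

x = torch.randn(2, 16, 128)
y = lora(x)                      # base output plus the scaled low-rank update
print(y.shape)                   # torch.Size([2, 16, 128])
```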