fix: refactor adapter weight loading and mapping (#2193)
* fix: refactor adapter weight loading and mapping

* feat: enable lora load from directory

* fix: adjust launcher for local lora adapters

* feat: improve weight loading and add tests

* fix: improve logging and rebase syntax issue

* fix: improve adapter merge comments and remove unused conditional

* fix: improve get_model_with_lora_adapters naming

* fix: comment typo
drbh authored and ErikKaum committed Jul 26, 2024
1 parent 8ed01b1 commit ec40544
Showing 13 changed files with 498 additions and 449 deletions.
4 changes: 4 additions & 0 deletions launcher/src/main.rs
@@ -1616,6 +1616,10 @@ fn main() -> Result<(), LauncherError> {
    // Download and convert lora adapters if any
    if let Some(lora_adapters) = &args.lora_adapters {
        for adapter in lora_adapters.split(',') {
            // skip download if a path is provided
            if adapter.contains('=') {
                continue;
            }
            download_convert_model(
                adapter,
                None,
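The new check above implies that a lora adapter entry may be given either as a plain hub id or as `adapter_id=/local/path`, with the launcher skipping download for the latter; that exact spec format is an assumption inferred from the `contains('=')` test, not spelled out in this hunk. A minimal, hypothetical Python sketch of parsing such a comma-separated spec:

from typing import List, Optional, Tuple

def parse_lora_adapters(spec: str) -> List[Tuple[str, Optional[str]]]:
    # hypothetical helper, not part of the commit: split a comma-separated
    # adapter spec into (adapter_id, optional_local_path) pairs
    adapters: List[Tuple[str, Optional[str]]] = []
    for entry in spec.split(","):
        entry = entry.strip()
        if not entry:
            continue
        if "=" in entry:
            adapter_id, path = entry.split("=", 1)
            adapters.append((adapter_id, path))  # local path: no download needed
        else:
            adapters.append((entry, None))  # hub id: launcher downloads/converts it
    return adapters

print(parse_lora_adapters("predibase/dbpedia,myorg/my-lora=/data/adapters/my-lora"))
# [('predibase/dbpedia', None), ('myorg/my-lora', '/data/adapters/my-lora')]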
187 changes: 187 additions & 0 deletions server/tests/utils/test_adapter.py
@@ -0,0 +1,187 @@
import pytest
from unittest.mock import Mock
from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights


def test_get_attn_weights():
    # create a mock layer
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    # call the function
    result = get_attn_weights(2, mock_layer)

    # assert the result
    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_with_gate_up_proj():
    # create a mock layer with gate_up_proj
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    # call the function
    result = get_mlp_weights(3, mock_layer)

    # assert the result
    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected


def test_get_mlp_weights_without_gate_up_proj():
    # create a mock layer without gate_up_proj
    mock_layer = Mock()
    mock_layer.mlp = Mock(spec=[])

    # call the function
    result = get_mlp_weights(1, mock_layer)

    # assert the result
    assert result == {}


@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_attn_weights_different_layers(layer_index):
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(layer_index, mock_layer)

    for k in ["q", "k", "v"]:
        assert (layer_index, f"{k}_proj") in result
        assert (
            result[(layer_index, f"{k}_proj")][0]
            == f"model.layers.{layer_index}.self_attn.{k}_proj"
        )

    assert (layer_index, "o_proj") in result
    assert (
        result[(layer_index, "o_proj")][0]
        == f"model.layers.{layer_index}.self_attn.o_proj"
    )


@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_mlp_weights_different_layers(layer_index):
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    result = get_mlp_weights(layer_index, mock_layer)

    for k in ["gate", "up", "down"]:
        assert (layer_index, f"{k}_proj") in result
        assert (
            result[(layer_index, f"{k}_proj")][0]
            == f"model.layers.{layer_index}.mlp.{k}_proj"
        )


def test_get_attn_weights_llama_compatibility():
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(2, mock_layer)

    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_llama_compatibility():
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    result = get_mlp_weights(3, mock_layer)

    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected


def test_get_attn_weights_gemma_compatibility():
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(2, mock_layer)

    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_gemma_compatibility():
    mock_layer = Mock()
    mock_layer.mlp.gate_proj = Mock()
    mock_layer.mlp.up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
    # This is necessary because the use of `Mock` automatically creates any
    # attributes that are accessed, even if they don't exist in the actual
    # implementation. If `gate_up_proj` were created, `get_mlp_weights` might
    # follow the wrong execution path and return an incorrect result.
    del mock_layer.mlp.gate_up_proj

    result = get_mlp_weights(3, mock_layer)

    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected
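
The `del mock_layer.mlp.gate_up_proj` step above is needed because a bare `Mock()` fabricates any attribute the first time it is accessed, so attribute-based feature detection always succeeds. A small, self-contained sketch of that pitfall and two ways to constrain it, separate from the committed tests:

from unittest.mock import Mock

loose = Mock()
# auto-created on access, even though it was never assigned
print(hasattr(loose.mlp, "gate_up_proj"))  # True

constrained = Mock()
constrained.mlp = Mock(spec=["gate_proj", "up_proj", "down_proj"])
# spec restricts the mock to the listed attributes
print(hasattr(constrained.mlp, "gate_up_proj"))  # False

deleted = Mock()
del deleted.mlp.gate_up_proj  # same technique as the test above
print(hasattr(deleted.mlp, "gate_up_proj"))  # False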
13 changes: 1 addition & 12 deletions server/text_generation_server/adapters/config.py
@@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Set, Tuple

import torch

@@ -31,14 +31,3 @@ def map_weights_for_model(
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: ModuleMap,
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        pass
60 changes: 26 additions & 34 deletions server/text_generation_server/adapters/lora.py
@@ -102,22 +102,6 @@ def map_weights_for_model(
            adapter_weight_names.add(lora_b_name)
        return module_map, adapter_weight_names

    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        return LoraWeights.load(
            self,
            model,
            module_map,
            layer_type,
            unused_weight_names,
        )

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
@@ -192,22 +176,38 @@ def _transpose_weights(self):
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        return [BatchLoraWeights]

    # prepare pre-loaded lora weights for use in the model.
    #
    # this method processes and organizes lora weights for a specific layer type across all layers:
    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
    # - retrieves weights from `module_map` based on the `layer_type`.
    # - processes `nlayers` number of layers.
    # - converts weights to the specified `dtype`.
    # - shards weights across `world_size` number of processes using the `process_group`.
    # - maps weights to specific layers using `target_to_layer`.
    # - tracks `unused_weight_names` to identify any unused weights.
    #
    # the method handles weight transposition, scaling, and padding to ensure compatibility
    # with SGMV or BGMV operations.
    @classmethod
    def load(
    def prepare_weights(
        cls,
        config: LoraConfig,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        nlayers: int,
        dtype: torch.dtype,
        world_size: int,
        process_group: ProcessGroup,
        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
    ) -> Optional[AdapterWeights]:
        nlayers = model.get_num_layers_for_type(layer_type)
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            weight_name, layer = model.target_to_layer[key]
            weight_name, layer = target_to_layer[key]
            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

@@ -216,10 +216,10 @@ def load(
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, model.dtype)
            lora_a = lora_a.to(base_device, dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, model.dtype)
            lora_b = lora_b.to(base_device, dtype)

            scale = get_scaling_factor(
                config.lora_alpha,
@@ -236,12 +236,8 @@ def load(
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [
            pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
        ]
        lora_b_list = [
            pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
        ]
        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded
@@ -252,8 +248,8 @@ def load(
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
                split_dim=0 if model.is_row_parallel(layer_type) else 1,
                process_group=model.process_group,
                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
                process_group=process_group,
            ),
            config,
        )
@@ -293,10 +289,6 @@ def can_vectorize(self, pg: ProcessGroup) -> bool:
            for rank_data in self.rank_data.values()
        )

    @classmethod
    def key(cls) -> str:
        return "lora"

    @classmethod
    def load(
        self,
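
The comment block added above `prepare_weights` describes three per-layer steps: scaling by the LoRA factor, transposing the A and B matrices, and padding the rank so the weights shard cleanly. The following self-contained sketch illustrates those steps in isolation, under simplifying assumptions: the scaling factor is taken as the classic `lora_alpha / r`, and `pad_rank_multiple` is a simplified stand-in for `pad_rank` that only rounds the rank up to a multiple of `world_size` (the real helper also enforces kernel-specific rank constraints not reproduced here):

import torch

def scale_and_transpose(lora_a, lora_b, lora_alpha, r):
    # classic LoRA scaling; the real code obtains this from get_scaling_factor
    scale = lora_alpha / r
    # keep A as (in_features, r) and B as (r, out_features), folding scale into B
    return lora_a.transpose(0, 1), lora_b.transpose(0, 1) * scale

def pad_rank_multiple(w, dim, world_size):
    # zero-pad the rank dimension so it divides evenly across world_size shards
    rank = w.size(dim)
    target = ((rank + world_size - 1) // world_size) * world_size
    if target == rank:
        return w
    pad = [0, 0] * w.dim()  # F.pad orders pads from the last dim backwards
    pad[2 * (w.dim() - 1 - dim) + 1] = target - rank
    return torch.nn.functional.pad(w, pad)

r, in_f, out_f, world_size = 12, 64, 128, 8
lora_a = torch.randn(r, in_f)    # A: (r, in_features), as saved by PEFT
lora_b = torch.randn(out_f, r)   # B: (out_features, r)
a_t, b_t = scale_and_transpose(lora_a, lora_b, lora_alpha=16, r=r)
a_t = pad_rank_multiple(a_t, dim=1, world_size=world_size)
b_t = pad_rank_multiple(b_t, dim=0, world_size=world_size)
print(a_t.shape, b_t.shape)  # rank padded from 12 to 16

With r=12 and world_size=8 the padded rank becomes 16, so each shard receives an equal slice of the rank dimension.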
16 changes: 2 additions & 14 deletions server/text_generation_server/adapters/weights.py
@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
    def has_adapter(self, adapter_index: int) -> bool:
        pass

    @abstractclassmethod
    def key(cls) -> str:
        pass

    @abstractclassmethod
    def load(
        cls,
@@ -71,13 +67,6 @@ def remove_adapter(self, adapter_idx: int):
            return
        del self.adapter_weights[adapter_idx]

    @property
    def max_speculative_tokens(self) -> int:
        return max(
            adapter_weights.speculative_tokens
            for adapter_weights in self.adapter_weights.values()
        )

    def is_empty(self) -> bool:
        return len(self.adapter_weights) == 0

@@ -101,7 +90,7 @@ def get_data(
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                batch_data[batch_type.key()] = batched_weights
                batch_data = batched_weights
        return batch_data


@@ -133,8 +122,7 @@ def from_meta(
    def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for layer_data in self.data.values():
            lora_data = layer_data.get("lora")
        for lora_data in self.data.values():
            if lora_data is None:
                continue
