[Misc] Allow LoRA to adaptively increase rank and remove possible_max_ranks #10623

Open · wants to merge 1 commit into main
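This PR drops the fixed whitelist of allowed LoRA ranks and lets the LoRA layers grow their stacked weights on demand when an adapter with a larger rank is loaded, which also makes the worker-level rank check unnecessary. A usage sketch of what that enables follows; the model name, adapter path, and rank values are illustrative placeholders, not taken from this PR.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Start the engine with a modest max_lora_rank; with this change the value
# is a starting point rather than a hard ceiling, and it no longer has to
# be one of (8, 16, 32, 64, 128, 256).
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",   # placeholder model
    enable_lora=True,
    max_lora_rank=8,
)

# Assume this adapter was trained with rank 64 (> 8). Previously the worker
# manager would raise; now the layers resize their stacked buffers.
lora_request = LoRARequest("my_adapter", 1, "/path/to/rank64-adapter")

outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(max_tokens=32),
    lora_request=lora_request,
)
print(outputs[0].outputs[0].text)
```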
8 changes: 2 additions & 6 deletions vllm/config.py
@@ -1678,14 +1678,10 @@ class LoRAConfig:
bias_enabled: bool = False

def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast
# majority of applications.
possible_max_ranks = (8, 16, 32, 64, 128, 256)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
if self.max_lora_rank < 1:
raise ValueError(
f"max_lora_rank ({self.max_lora_rank}) must be one of "
f"{possible_max_ranks}.")
f"max_lora_rank ({self.max_lora_rank}) must be >= 1.")
if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
raise ValueError(
f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
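To make the new rule concrete, here is a standalone sketch of the validation after this change; it is an illustrative stand-in, not the actual vllm.config.LoRAConfig class.

```python
from dataclasses import dataclass


@dataclass
class LoRAConfigSketch:
    """Illustrative stand-in mirroring the new __post_init__ check."""
    max_lora_rank: int
    lora_extra_vocab_size: int = 256

    def __post_init__(self):
        possible_lora_extra_vocab_size = (0, 256, 512)
        # Any rank >= 1 is accepted; the old whitelist
        # (8, 16, 32, 64, 128, 256) is gone.
        if self.max_lora_rank < 1:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be >= 1.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")


LoRAConfigSketch(max_lora_rank=24)      # accepted: not a power of two
try:
    LoRAConfigSketch(max_lora_rank=0)   # rejected: must be >= 1
except ValueError as e:
    print(e)
```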
146 changes: 137 additions & 9 deletions vllm/lora/layers.py
@@ -1,4 +1,5 @@
# pylint: disable=unused-argument
import copy
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -130,6 +131,13 @@ class LoRAMapping(AdapterMapping):


class BaseLayerWithLoRA(nn.Module):
# Default placeholder so the attribute is defined for static typing.
_create_lora_weights_args: Tuple[int, LoRAConfig,
Optional[PretrainedConfig]] = (
0,
LoRAConfig(1, 1),
None,
)

def slice_lora_a(
self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
@@ -156,11 +164,18 @@ def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
...

def update_max_lora_rank(
self,
lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]],
):
"""Updates max lora rank if larger lora matrices are given."""
...

def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]],
lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]],
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
@@ -194,11 +209,14 @@ def __init__(self, base_layer: VocabParallelEmbedding) -> None:
self.embeddings_weights: Optional[torch.Tensor]

def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:

self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self._create_lora_weights_args = (max_loras,
copy.deepcopy(lora_config),
copy.deepcopy(model_config))
if self.base_layer.num_added_embeddings_per_partition > 0:
# We can start adding lora weights
self.embeddings_weights = self.base_layer.weight.data[
@@ -255,6 +273,14 @@ def reset_lora(self, index: int):
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = 0

def update_max_lora_rank(
self,
lora_a: torch.Tensor,
):
if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
self.create_lora_weights(*self._create_lora_weights_args)

def set_lora(
self,
index: int,
@@ -264,6 +290,8 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.update_max_lora_rank(lora_a)

self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
@@ -340,6 +368,9 @@ def create_lora_weights(
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self._create_lora_weights_args = (max_loras,
copy.deepcopy(lora_config),
copy.deepcopy(model_config))
self.lora_config = lora_config
lora_a_output_size = lora_config.max_lora_rank
self.lora_a_stacked = torch.zeros(
@@ -375,6 +406,14 @@ def reset_lora(self, index: int):
if self.lora_config.bias_enabled:
self.bias_stacked[index] = 0

def update_max_lora_rank(
self,
lora_a: torch.Tensor,
):
if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
self.create_lora_weights(*self._create_lora_weights_args)

def set_lora(
self,
index: int,
@@ -384,6 +423,7 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.update_max_lora_rank(lora_a)

self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -469,6 +509,9 @@ def create_lora_weights(
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self._create_lora_weights_args = (max_loras,
copy.deepcopy(lora_config),
copy.deepcopy(model_config))
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
lora_a_output_size_per_partition = (
@@ -547,6 +590,21 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
bias = bias[start_idx:end_idx]
return bias

def update_max_lora_rank(
self,
lora_a: torch.Tensor,
) -> None:
if (self.lora_config.fully_sharded_loras
and lora_a.shape[1] * self.tp_size >
self.lora_config.max_lora_rank):
self._create_lora_weights_args[1].max_lora_rank = (
lora_a.shape[1] * self.tp_size)
self.create_lora_weights(*self._create_lora_weights_args)
elif (not self.lora_config.fully_sharded_loras
and lora_a.shape[1] > self.lora_config.max_lora_rank):
self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
self.create_lora_weights(*self._create_lora_weights_args)

def set_lora(
self,
index: int,
@@ -556,6 +614,7 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.update_max_lora_rank(lora_a)

if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
@@ -643,6 +702,9 @@ def create_lora_weights(
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self._create_lora_weights_args = (max_loras,
copy.deepcopy(lora_config),
copy.deepcopy(model_config))
self.lora_config = lora_config
n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices
@@ -730,6 +792,23 @@ def slice_bias(
]
return bias

def update_max_lora_rank(self, lora_a: List[Union[torch.Tensor,
None]]) -> None:
for tensor in lora_a:
if tensor is None:
continue
if (self.lora_config.fully_sharded_loras
and tensor.shape[1] * self.tp_size >
self.lora_config.max_lora_rank):
self._create_lora_weights_args[1].max_lora_rank = (
tensor.shape[1] * self.tp_size)
self.create_lora_weights(*self._create_lora_weights_args)
elif (not self.lora_config.fully_sharded_loras
and tensor.shape[1] > self.lora_config.max_lora_rank):
self._create_lora_weights_args[1].max_lora_rank = (
tensor.shape[1])
self.create_lora_weights(*self._create_lora_weights_args)

def set_lora(
self,
index: int,
@@ -739,6 +818,7 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.update_max_lora_rank(lora_a)

if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
@@ -865,6 +945,8 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.update_max_lora_rank(lora_a)

if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
@@ -911,6 +993,9 @@ def create_lora_weights(
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self._create_lora_weights_args = (max_loras,
copy.deepcopy(lora_config),
copy.deepcopy(model_config))
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
@@ -1070,15 +1155,33 @@ def slice_bias(
bias = [bias_q, bias_k, bias_v]
return bias

def update_max_lora_rank(self, lora_a: List[Union[torch.Tensor,
None]]) -> None:
for tensor in lora_a:
if tensor is None:
continue
if (self.lora_config.fully_sharded_loras
and tensor.shape[1] * self.tp_size >
self.lora_config.max_lora_rank):
self._create_lora_weights_args[1].max_lora_rank = (
tensor.shape[1] * self.tp_size)
self.create_lora_weights(*self._create_lora_weights_args)
elif (not self.lora_config.fully_sharded_loras
and tensor.shape[1] > self.lora_config.max_lora_rank):
self._create_lora_weights_args[1].max_lora_rank = (
tensor.shape[1])
self.create_lora_weights(*self._create_lora_weights_args)

def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_a: List[Union[torch.Tensor, None]],
lora_b: List[Union[torch.Tensor, None]],
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.update_max_lora_rank(lora_a)

if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
@@ -1171,6 +1274,9 @@ def create_lora_weights(
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self._create_lora_weights_args = (max_loras,
copy.deepcopy(lora_config),
copy.deepcopy(model_config))
self.lora_config = lora_config
self.tp_rank = get_tensor_model_parallel_rank()
self.lora_a_stacked = torch.zeros(
@@ -1235,6 +1341,14 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias

def update_max_lora_rank(
self,
lora_a: torch.Tensor,
):
if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
self.create_lora_weights(*self._create_lora_weights_args)

def set_lora(
self,
index: int,
@@ -1244,6 +1358,7 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.update_max_lora_rank(lora_a)

if self.base_layer.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
@@ -1399,6 +1514,9 @@ def create_lora_weights(
if 32000 < self.base_layer.vocab_size > 257024:
raise ValueError("When using LoRA, vocab size must be "
"32000 >= vocab_size <= 257024")
self._create_lora_weights_args = (max_loras,
copy.deepcopy(lora_config),
copy.deepcopy(model_config))
self.lora_a_stacked = torch.zeros(
(
max_loras,
@@ -1441,6 +1559,14 @@ def reset_lora(self, index: int):
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = float("-inf")

def update_max_lora_rank(
self,
lora_a: torch.Tensor,
):
if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
self.create_lora_weights(*self._create_lora_weights_args)

def set_lora(
self,
index: int,
Expand All @@ -1450,6 +1576,8 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.update_max_lora_rank(lora_a)

self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
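Every layer class follows the same pattern: create_lora_weights caches its arguments, and when set_lora receives a lora_a whose rank (its second dimension) exceeds the currently configured maximum, update_max_lora_rank bumps max_lora_rank and re-creates the stacked buffers before copying; the tensor-parallel variants additionally scale the incoming rank by tp_size when fully_sharded_loras is enabled. A condensed, self-contained sketch of that pattern (single weight, no tensor parallelism, not the actual vllm classes):

```python
import copy
from dataclasses import dataclass

import torch


@dataclass
class Cfg:
    max_lora_rank: int


class AdaptiveRankLayerSketch:
    """Minimal stand-in showing the adaptive-rank mechanism."""

    def __init__(self, in_features: int = 128):
        self.in_features = in_features

    def create_lora_weights(self, max_loras: int, cfg: Cfg) -> None:
        # Cache the arguments so the buffers can be rebuilt later with a
        # larger rank, mirroring _create_lora_weights_args in the PR.
        self._create_lora_weights_args = (max_loras, copy.deepcopy(cfg))
        self.lora_a_stacked = torch.zeros(
            max_loras, cfg.max_lora_rank, self.in_features)

    def update_max_lora_rank(self, lora_a: torch.Tensor) -> None:
        # lora_a has shape (in_features, rank); grow the buffers when the
        # incoming rank exceeds the cached maximum.
        if lora_a.shape[1] > self._create_lora_weights_args[1].max_lora_rank:
            self._create_lora_weights_args[1].max_lora_rank = lora_a.shape[1]
            self.create_lora_weights(*self._create_lora_weights_args)

    def set_lora(self, index: int, lora_a: torch.Tensor) -> None:
        self.update_max_lora_rank(lora_a)
        self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
            lora_a.T, non_blocking=True)


layer = AdaptiveRankLayerSketch()
layer.create_lora_weights(max_loras=2, cfg=Cfg(max_lora_rank=8))
layer.set_lora(0, torch.randn(128, 64))   # rank 64 > 8: buffers are rebuilt
print(layer.lora_a_stacked.shape)         # torch.Size([2, 64, 128])
```

As the sketch makes visible, re-creating the buffers allocates fresh zero-initialized tensors sized for the new maximum rank.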
4 changes: 0 additions & 4 deletions vllm/lora/worker_manager.py
@@ -106,10 +106,6 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
)
except Exception as e:
raise RuntimeError(f"Loading lora {lora_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
raise ValueError(
f"LoRA rank {lora.rank} is greater than max_lora_rank "
f"{self.lora_config.max_lora_rank}.")
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
f"is greater than lora_extra_vocab_size "
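With the layers resizing themselves, the up-front rank check in the worker manager becomes redundant and is removed; only the extra-vocab-size guard stays. A simplified sketch of the validation that remains in _load_adapter after this change (names are stand-ins for the real objects):

```python
def validate_loaded_lora(lora, lora_config) -> None:
    # The old "lora.rank > lora_config.max_lora_rank" check is gone:
    # layers now grow their stacked weights when a larger rank arrives.
    if lora.extra_vocab_size > lora_config.lora_extra_vocab_size:
        raise ValueError(
            f"LoRA added vocab size {lora.extra_vocab_size} is greater "
            f"than lora_extra_vocab_size "
            f"{lora_config.lora_extra_vocab_size}.")
```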