[LoRA] ReplicatedLinear support LoRA (vllm-project#7081)
jeejeelee authored and sfc-gh-mkeralapura committed Aug 12, 2024
1 parent 3e77ed4 commit 6ca373a
Showing 3 changed files with 199 additions and 0 deletions.
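
For orientation, the change can be summarized in plain PyTorch: a ReplicatedLinear keeps the full, unsharded weight on every rank, so applying a LoRA adapter is just the base matmul plus the low-rank update, which is what the new test below checks numerically. The sketch uses made-up shapes, random tensors, and the default dtype (the test itself runs in float16 on CUDA); it is illustrative, not code from the commit.

import torch

# Minimal sketch of the math ReplicatedLinearWithLoRA implements.
hidden_size, rank = 4096, 8
x = torch.rand(1, hidden_size)
weight = torch.rand(hidden_size, hidden_size)   # full weight, identical on every rank
lora_a = torch.rand(hidden_size, rank)          # LoRA down-projection
lora_b = torch.rand(rank, hidden_size)          # LoRA up-projection
scaling = 1.0                                   # illustrative scaling factor

base_out = x @ weight.t()                       # what ReplicatedLinear computes (bias omitted)
output = base_out + x @ lora_a @ lora_b * scaling   # base output + low-rank LoRA delta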
103 changes: 103 additions & 0 deletions tests/lora/test_layers.py
@@ -22,6 +22,7 @@
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
@@ -31,6 +32,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -545,6 +547,107 @@ def _pretest():
atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:

torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)

def create_random_linear_replicated_layer():

linear = ReplicatedLinear(4096,
4096,
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ReplicatedLinearWithLoRA(linear)

lora_linear.create_lora_weights(max_loras, lora_config)

return linear, lora_linear

for i in range(10):
set_random_seed(i)

id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_replicated_layer()
lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_linear,
layer_weights=linear.weight,
)

inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
punica_wrapper.update_metadata(
lora_mapping,
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
)

lora_result = lora_linear(torch.cat(inputs))[0]

expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)

rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)

# Check that resetting the lora weights succeeds

for slot_idx in range(max_loras):
lora_linear.reset_lora(slot_idx)

inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)

punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)

lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0]

rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
94 changes: 94 additions & 0 deletions vllm/lora/layers.py
@@ -21,6 +21,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
@@ -262,6 +263,99 @@ def can_replace_layer(
return type(source_layer) is VocabParallelEmbedding


class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):

def __init__(self, base_layer: ReplicatedLinear) -> None:
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
self.output_size = self.base_layer.output_size
self.device = _get_lora_device(self.base_layer)

def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config
lora_a_output_size = lora_config.max_lora_rank
self.lora_a_stacked = torch.zeros(
max_loras,
1,
lora_a_output_size,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
max_loras,
1,
self.output_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
)

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0

def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)

self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
return output

def forward(self, input_):
"""Forward of ReplicatedLinearWithLoRA
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = (self.base_layer.bias
if not self.base_layer.skip_bias_add else None)

# Matrix multiply.
output = self.apply(input_, bias)

output_bias = (self.base_layer.bias
if self.base_layer.skip_bias_add else None)
return output, output_bias

@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ReplicatedLinear


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
"""
LoRA on top of ColumnParallelLinear layer.
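
A side note on the buffer layout above: create_lora_weights() allocates one slot per possible adapter, and set_lora() copies the adapter matrices in transposed (presumably to match the layout the punica kernels expect). The standalone sketch below reproduces that indexing with dummy tensors; it is not code from the commit.

import torch

# Per-slot LoRA buffers with the same shapes as in create_lora_weights().
max_loras, rank, input_size, output_size = 8, 8, 4096, 4096
lora_a_stacked = torch.zeros(max_loras, 1, rank, input_size, dtype=torch.float16)
lora_b_stacked = torch.zeros(max_loras, 1, output_size, rank, dtype=torch.float16)

lora_a = torch.rand(input_size, rank, dtype=torch.float16)   # adapter A: (in, rank)
lora_b = torch.rand(rank, output_size, dtype=torch.float16)  # adapter B: (rank, out)

slot = 3  # which of the max_loras slots this adapter occupies
lora_a_stacked[slot, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(lora_a.T)
lora_b_stacked[slot, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(lora_b.T)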
2 changes: 2 additions & 0 deletions vllm/lora/utils.py
@@ -23,6 +23,7 @@
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
@@ -38,6 +39,7 @@
QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLora,
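
Registering ReplicatedLinearWithLoRA in _all_lora_classes is what lets vLLM's LoRA machinery pick a wrapper for replicated layers instead of skipping them. The sketch below shows the general first-match dispatch pattern over such a registry; the helper name pick_lora_layer and its defaults are illustrative and not the exact vllm.lora.utils API.

from typing import List, Optional, Type

from torch import nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.lora.layers import BaseLayerWithLoRA


def pick_lora_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        lora_classes: List[Type[BaseLayerWithLoRA]],
        packed_modules_list: Optional[List] = None,
        model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Illustrative dispatch: wrap `layer` with the first matching LoRA class."""
    for lora_cls in lora_classes:
        # Keyword arguments mirror how the can_replace_layer hook is defined above.
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list or [],
                                      model_config=model_config):
            wrapped = lora_cls(layer)
            wrapped.create_lora_weights(max_loras, lora_config, model_config)
            return wrapped
    return layer  # no LoRA wrapper registered for this layer type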
