
Commit 01b47f9

Isotr0py authored and lulmer committed
[LoRA] Remove linear hack outside transformers backend (vllm-project#14177)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 69adf3b commit 01b47f9

File tree

4 files changed: +142 -105 lines changed

vllm/lora/layers.py

Lines changed: 30 additions & 21 deletions
@@ -395,17 +395,20 @@ def apply(self,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        # In transformers backend, x and output have extra batch dimension like
+        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
+        # therefore we need to flatten the batch dimensions.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
         self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                             self.lora_b_stacked,
                                             self.lora_bias_stacked, 1.0,
                                             self.output_slices)
         return output
 
-    @classmethod
-    def get_source_layer(cls, source_layer: nn.Module) -> type:
-        # Check parent_cls in case source_layer is a HFCompatibleLinear.
-        return getattr(source_layer, "parent_cls", type(source_layer))
-
 
 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
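
The shape normalization added to `apply` can be tried in isolation. Below is a minimal sketch of just the flattening step, with plain tensors standing in for the real activations; `normalize_for_punica` is a hypothetical helper name, not part of vLLM.

```python
import torch

# Minimal sketch of the shape handling added to apply(): the transformers
# backend passes (1, seq_len, hidden_dim) activations, while the punica
# kernels expect a flat (seq_len, hidden_dim) layout.
def normalize_for_punica(
        x: torch.Tensor,
        output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    if x.ndim == 3 and output.ndim == 3:
        # flatten(0, 1) merges the leading batch dim into the sequence
        # dim; for a contiguous tensor this is a view, not a copy.
        output = output.flatten(0, 1)
        x = x.flatten(0, 1)
    return x, output

x = torch.randn(1, 8, 16)    # transformers-backend layout
out = torch.randn(1, 8, 16)
x_flat, out_flat = normalize_for_punica(x, out)
assert x_flat.shape == (8, 16) and out_flat.shape == (8, 16)
```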

@@ -418,7 +421,7 @@ def __init__(self, base_layer: ReplicatedLinear) -> None:
 
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ReplicatedLinearWithLoRA
 
         Args:
@@ -436,6 +439,10 @@ def forward(
 
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias
 
     # ReplicatedLinear should always be replaced, regardless of the fully
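
With the early return, `forward` now yields a bare tensor when the base layer was constructed with `return_bias=False`, and the `(output, bias)` tuple otherwise, which is what the widened `Union` annotation captures. A toy sketch of the two call shapes; `ToyLinear` is a stand-in, not the real vLLM layer:

```python
import torch
from typing import Optional, Union

class ToyLinear:
    """Stand-in mimicking the new forward contract of the LoRA layers."""

    def __init__(self, return_bias: bool) -> None:
        self.return_bias = return_bias
        self.bias: Optional[torch.Tensor] = None
        self.skip_bias_add = False

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        output = input_ * 2.0  # placeholder for the matmul + LoRA path
        if not self.return_bias:
            return output      # bare tensor, as the transformers backend expects
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

x = torch.ones(4)
assert isinstance(ToyLinear(return_bias=False).forward(x), torch.Tensor)
out, bias = ToyLinear(return_bias=True).forward(x)  # tuple, as before
```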
@@ -448,8 +455,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ReplicatedLinear
+        return type(source_layer) is ReplicatedLinear
 
 
 class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
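
Replacing `cls.get_source_layer(...)` with a direct `type(source_layer) is ...` check works because the only reason for the indirection was the `HFCompatibleLinear` wrapper, which carried its wrapped class in a `parent_cls` attribute. A toy illustration of the difference (both classes here are sketched stand-ins, not the real vLLM types):

```python
class ReplicatedLinear:                      # stand-in for the vLLM layer
    pass

class HFCompatibleLinear(ReplicatedLinear):  # sketch of the removed wrapper
    parent_cls = ReplicatedLinear

plain, wrapped = ReplicatedLinear(), HFCompatibleLinear()

# The removed helper looked through parent_cls, so the wrapper matched too:
assert getattr(wrapped, "parent_cls", type(wrapped)) is ReplicatedLinear

# The exact type check matches only the plain layer, which is all that is
# left now that the wrapper lives inside the transformers backend:
assert type(plain) is ReplicatedLinear
assert type(wrapped) is not ReplicatedLinear
```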
@@ -512,7 +518,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
 
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ColumnParallelLinear
 
         Args:
@@ -532,6 +538,10 @@ def forward(
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
+
+        if not self.base_layer.return_bias:
+            return output
+
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
@@ -545,9 +555,8 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ColumnParallelLinear or (
-            source_layer is MergedColumnParallelLinear
+        return type(source_layer) is ColumnParallelLinear or (
+            type(source_layer) is MergedColumnParallelLinear
             and len(packed_modules_list) == 1)
 
 
@@ -689,8 +698,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is MergedColumnParallelLinear
+        return (type(source_layer) is MergedColumnParallelLinear
                 and len(packed_modules_list) == 2)
 
 
@@ -758,8 +766,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is QKVParallelLinear and len(
+        return type(source_layer) is QKVParallelLinear and len(
             packed_modules_list) == 1
 
 
@@ -820,8 +827,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is QKVParallelLinear
+        return (type(source_layer) is QKVParallelLinear
                 and len(packed_modules_list) == 3)
 
 
@@ -855,7 +861,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
 
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of RowParallelLinear
 
         Args:
@@ -890,6 +896,10 @@ def forward(
         else:
             output = output_
         output_bias = self.base_layer.bias
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias
 
     @property
@@ -906,8 +916,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is RowParallelLinear
+        return type(source_layer) is RowParallelLinear
 
 
 class LogitsProcessorWithLoRA(BaseLayerWithLoRA):

vllm/lora/utils.py

Lines changed: 0 additions & 10 deletions
@@ -67,16 +67,6 @@ def from_layer(layer: nn.Module,
                                       packed_modules_list=packed_modules_list,
                                       model_config=model_config):
             instance_layer = lora_cls(layer)
-            if layer.__class__.__name__ == "HFCompatibleLinear":
-                # HACK: Make the forward method compatible with the original
-                # forward method of the instance_layer.
-                original_forward = instance_layer.forward
-
-                def new_forward(input):
-                    input = input.squeeze(0)
-                    return original_forward(input)[0]  # noqa: B023
-
-                instance_layer.forward = new_forward
             instance_layer.create_lora_weights(max_loras, lora_config,
                                                model_config)
             return instance_layer
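
The deleted block monkey-patched each freshly created LoRA layer: it squeezed the leading batch dimension off the input and indexed `[0]` to discard the bias element of the returned tuple. Both behaviours are now native, via the `flatten(0, 1)` in `apply` and the `return_bias` early return in `forward`, so no wrapper is needed. A runnable sketch of the equivalence; `ToyLoRALinear` is a stand-in, not the real class:

```python
import torch

class ToyLoRALinear:
    """Mimics the post-commit behaviour of the LoRA linear layers."""
    return_bias = False  # the transformers backend builds layers this way

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim == 3:          # flatten the batch dim, as apply() now does
            x = x.flatten(0, 1)
        return x * 2.0           # bare tensor: the return_bias early return

layer = ToyLoRALinear()
x = torch.randn(1, 8, 16)        # (1, seq_len, hidden_dim) layout

# Old: from_layer wrapped forward as `original_forward(x.squeeze(0))[0]`.
# New: the layer accepts the 3-D input and returns a tensor directly.
assert layer.forward(x).shape == (8, 16)
```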
