
Commit f945195

Done

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

1 parent 55a889a

File tree: 2 files changed (+17, -12 lines)


vllm/lora/layers.py (13 additions, 12 deletions)
@@ -74,12 +74,6 @@ def dec(*args, **kwargs):
     return dec
 
 
-def _is_hf_linear(layer: nn.Module, father_cls: type) -> bool:
-    # Specify for `TransformersModel`
-    return layer.__class__.__name__ == "HFCompatibleLinear" and isinstance(
-        layer, father_cls)
-
-
 @dataclass
 class LoRAMapping(AdapterMapping):
     is_prefill: bool = False
@@ -407,6 +401,11 @@ def apply(self,
                                             self.output_slices)
         return output
 
+    @classmethod
+    def get_source_layer(cls, source_layer: nn.Module) -> nn.Module:
+        # Check parent_cls in case source_layer is a HFCompatibleLinear.
+        return getattr(source_layer, "parent_cls", source_layer)
+
 
 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
 
@@ -449,8 +448,8 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        return type(source_layer) is ReplicatedLinear or _is_hf_linear(
-            source_layer, ReplicatedLinear)
+        source_layer = cls.get_source_layer(source_layer)
+        return type(source_layer) is ReplicatedLinear
 
 
 class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
@@ -546,10 +545,10 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
+        source_layer = cls.get_source_layer(source_layer)
         return type(source_layer) is ColumnParallelLinear or (
             type(source_layer) is MergedColumnParallelLinear
-            and len(packed_modules_list) == 1) or _is_hf_linear(
-                source_layer, ColumnParallelLinear)
+            and len(packed_modules_list) == 1)
 
 
 class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@@ -690,6 +689,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
+        source_layer = cls.get_source_layer(source_layer)
         return (type(source_layer) is MergedColumnParallelLinear
                 and len(packed_modules_list) == 2)
 
@@ -819,6 +819,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
+        source_layer = cls.get_source_layer(source_layer)
         return (type(source_layer) is QKVParallelLinear
                 and len(packed_modules_list) == 3)
 
@@ -904,8 +905,8 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        return type(source_layer) is RowParallelLinear or _is_hf_linear(
-            source_layer, RowParallelLinear)
+        source_layer = cls.get_source_layer(source_layer)
+        return type(source_layer) is RowParallelLinear
 
 
 class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
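
For context on the new helper: `get_source_layer` relies on `getattr` falling back to the layer itself when no `parent_cls` attribute exists, so plain vLLM layers pass through untouched while `HFCompatibleLinear` wrappers resolve to the class they subclass. A minimal sketch of that dispatch, using hypothetical `BaseLinear`/`WrappedLinear` stand-ins rather than vLLM's real layer classes:

    import torch.nn as nn

    class BaseLinear(nn.Linear):
        """Stand-in for a vLLM linear layer such as ReplicatedLinear."""

    class WrappedLinear(BaseLinear):
        """Stand-in for HFCompatibleLinear: exposes its vLLM parent class."""

        @property
        def parent_cls(self) -> type:
            return BaseLinear

    def get_source_layer(source_layer: nn.Module) -> nn.Module:
        # No `parent_cls` attribute -> return the layer unchanged;
        # wrapper present -> return the class it wraps.
        return getattr(source_layer, "parent_cls", source_layer)

    plain = BaseLinear(8, 8)
    wrapped = WrappedLinear(8, 8)
    assert get_source_layer(plain) is plain          # pass-through
    assert get_source_layer(wrapped) is BaseLinear   # unwrapped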

vllm/model_executor/models/transformers.py (4 additions, 0 deletions)
@@ -101,6 +101,10 @@ class HFCompatibleLinear(vllm_linear_cls):
     """
     Wrapper class that removes `output_bias` from returned output.
     """
+    # NOTE: The LoRA layer needs to use `parent_cls`.
+    @property
+    def parent_cls(self):
+        return vllm_linear_cls
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return super().forward(input)[0]
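
The `parent_cls` property works because `vllm_linear_cls` is captured by closure when `HFCompatibleLinear` is defined inside its enclosing factory function, so each generated wrapper reports exactly the vLLM linear class it subclasses. A minimal sketch of that factory pattern, with a hypothetical `make_wrapper` standing in for vLLM's actual factory:

    import torch.nn as nn

    def make_wrapper(base_cls: type) -> type:

        class Wrapper(base_cls):
            # The property closes over `base_cls`, so instances report
            # the exact class this wrapper subclasses.
            @property
            def parent_cls(self) -> type:
                return base_cls

        return Wrapper

    LinearWrapper = make_wrapper(nn.Linear)
    layer = LinearWrapper(4, 4)
    assert layer.parent_cls is nn.Linear
    assert isinstance(layer, nn.Linear)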
