@@ -74,12 +74,6 @@ def dec(*args, **kwargs):
7474 return dec
7575
7676
77- def _is_hf_linear (layer : nn .Module , father_cls : type ) -> bool :
78- # Specify for `TransformersModel`
79- return layer .__class__ .__name__ == "HFCompatibleLinear" and isinstance (
80- layer , father_cls )
81-
82-
8377@dataclass
8478class LoRAMapping (AdapterMapping ):
8579 is_prefill : bool = False
@@ -407,6 +401,11 @@ def apply(self,
407401 self .output_slices )
408402 return output
409403
404+ @classmethod
405+ def get_source_layer (cls , source_layer : nn .Module ) -> nn .Module :
406+ # Check parent_cls in case source_layer is a HFCompatibleLinear.
407+ return getattr (source_layer , "parent_cls" , source_layer )
408+
410409
411410class ReplicatedLinearWithLoRA (BaseLinearLayerWithLoRA ):
412411
@@ -449,8 +448,8 @@ def can_replace_layer(
449448 packed_modules_list : List ,
450449 model_config : Optional [PretrainedConfig ],
451450 ) -> bool :
452- return type ( source_layer ) is ReplicatedLinear or _is_hf_linear (
453- source_layer , ReplicatedLinear )
451+ source_layer = cls . get_source_layer ( source_layer )
452+ return type ( source_layer ) is ReplicatedLinear
454453
455454
456455class ColumnParallelLinearWithLoRA (BaseLinearLayerWithLoRA ):
@@ -546,10 +545,10 @@ def can_replace_layer(
546545 packed_modules_list : List ,
547546 model_config : Optional [PretrainedConfig ],
548547 ) -> bool :
548+ source_layer = cls .get_source_layer (source_layer )
549549 return type (source_layer ) is ColumnParallelLinear or (
550550 type (source_layer ) is MergedColumnParallelLinear
551- and len (packed_modules_list ) == 1 ) or _is_hf_linear (
552- source_layer , ColumnParallelLinear )
551+ and len (packed_modules_list ) == 1 )
553552
554553
555554class MergedColumnParallelLinearWithLoRA (ColumnParallelLinearWithLoRA ):
@@ -690,6 +689,7 @@ def can_replace_layer(
690689 packed_modules_list : List ,
691690 model_config : Optional [PretrainedConfig ],
692691 ) -> bool :
692+ source_layer = cls .get_source_layer (source_layer )
693693 return (type (source_layer ) is MergedColumnParallelLinear
694694 and len (packed_modules_list ) == 2 )
695695
@@ -819,6 +819,7 @@ def can_replace_layer(
819819 packed_modules_list : List ,
820820 model_config : Optional [PretrainedConfig ],
821821 ) -> bool :
822+ source_layer = cls .get_source_layer (source_layer )
822823 return (type (source_layer ) is QKVParallelLinear
823824 and len (packed_modules_list ) == 3 )
824825
@@ -904,8 +905,8 @@ def can_replace_layer(
904905 packed_modules_list : List ,
905906 model_config : Optional [PretrainedConfig ],
906907 ) -> bool :
907- return type ( source_layer ) is RowParallelLinear or _is_hf_linear (
908- source_layer , RowParallelLinear )
908+ source_layer = cls . get_source_layer ( source_layer )
909+ return type ( source_layer ) is RowParallelLinear
909910
910911
911912class LogitsProcessorWithLoRA (BaseLayerWithLoRA ):
0 commit comments