@@ -395,17 +395,20 @@ def apply(self,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        # In the Transformers backend, x and output have an extra batch
+        # dimension like (1, seq_len, hidden_dim), while punica expects
+        # (seq_len, hidden_dim), so we flatten the batch dimension first.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
         self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                             self.lora_b_stacked,
                                             self.lora_bias_stacked, 1.0,
                                             self.output_slices)
         return output
 
-    @classmethod
-    def get_source_layer(cls, source_layer: nn.Module) -> type:
-        # Check parent_cls in case source_layer is a HFCompatibleLinear.
-        return getattr(source_layer, "parent_cls", type(source_layer))
-
 
 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
 
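For context on the new flatten step, a standalone sketch (not part of the patch; the tensor shapes are illustrative only): `flatten(0, 1)` folds the Transformers-backend batch dimension into the token dimension that the punica kernels operate on.

```python
import torch

# Transformers-backend activations arrive with a leading batch dimension,
# e.g. (batch, seq_len, hidden_dim).
x = torch.randn(1, 8, 16)

# punica's LoRA kernels work on 2-D inputs of shape (num_tokens, hidden_dim),
# so the batch dimension is folded into the token dimension.
x_2d = x.flatten(0, 1)
assert x_2d.shape == (8, 16)

# For a contiguous tensor this is a view, so writes into the flattened
# tensor land in the same storage as the original 3-D tensor.
assert x_2d.data_ptr() == x.data_ptr()
```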
@@ -418,7 +421,7 @@ def __init__(self, base_layer: ReplicatedLinear) -> None:
 
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ReplicatedLinearWithLoRA
 
         Args:
@@ -436,6 +439,10 @@ def forward(
 
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias
 
     # ReplicatedLinear should always be replaced, regardless of the fully
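A minimal sketch of the new return-type contract, not part of the patch: `forward_tail` is a made-up helper that mirrors only the tail of the patched `forward()`, showing how `return_bias` selects between a bare tensor and the usual `(output, output_bias)` tuple declared by the new `Union[...]` annotation.

```python
from typing import Optional, Union

import torch


def forward_tail(
    output: torch.Tensor,
    bias: Optional[torch.Tensor],
    skip_bias_add: bool,
    return_bias: bool,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
    # Bias is only handed back to the caller when the base layer skips
    # adding it itself.
    output_bias = bias if skip_bias_add else None
    if not return_bias:
        # Transformers-backend base layers set return_bias=False and expect
        # a bare tensor rather than an (output, bias) tuple.
        return output
    return output, output_bias


out = forward_tail(torch.zeros(4, 8), None, skip_bias_add=False, return_bias=False)
assert isinstance(out, torch.Tensor)

out, out_bias = forward_tail(torch.zeros(4, 8), None, skip_bias_add=False, return_bias=True)
assert out_bias is None
```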
@@ -448,8 +455,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ReplicatedLinear
+        return type(source_layer) is ReplicatedLinear
 
 
 class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
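A short illustration of why the replacement checks can now use `type(source_layer) is ...` (the class names below are stand-ins, not vLLM classes): with `HFCompatibleLinear` and its `parent_cls` indirection removed, an exact-type check suffices, and unlike `isinstance()` it does not also match subclasses such as `MergedColumnParallelLinear` when checking for `ColumnParallelLinear`.

```python
class BaseLinear:                 # stand-in for e.g. ColumnParallelLinear
    pass


class MergedLinear(BaseLinear):   # stand-in for a subclass like MergedColumnParallelLinear
    pass


# type(...) is ... matches only the exact class, so each LoRA wrapper claims
# exactly one linear layer class; subclasses fall through to their own
# can_replace_layer() implementations.
assert type(BaseLinear()) is BaseLinear
assert type(MergedLinear()) is not BaseLinear

# isinstance() would also accept the subclass, which is why it is not used here.
assert isinstance(MergedLinear(), BaseLinear)
```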
@@ -512,7 +518,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
 
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ColumnParallelLinear
 
         Args:
@@ -532,6 +538,10 @@ def forward(
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
+
+        if not self.base_layer.return_bias:
+            return output
+
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
@@ -545,9 +555,8 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ColumnParallelLinear or (
-            source_layer is MergedColumnParallelLinear
+        return type(source_layer) is ColumnParallelLinear or (
+            type(source_layer) is MergedColumnParallelLinear
             and len(packed_modules_list) == 1)
 
 
@@ -689,8 +698,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is MergedColumnParallelLinear
+        return (type(source_layer) is MergedColumnParallelLinear
                 and len(packed_modules_list) == 2)
 
 
@@ -758,8 +766,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is QKVParallelLinear and len(
+        return type(source_layer) is QKVParallelLinear and len(
             packed_modules_list) == 1
 
 
@@ -820,8 +827,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is QKVParallelLinear
+        return (type(source_layer) is QKVParallelLinear
                 and len(packed_modules_list) == 3)
 
 
@@ -855,7 +861,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
 
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of RowParallelLinear
 
         Args:
@@ -890,6 +896,10 @@ def forward(
         else:
             output = output_
             output_bias = self.base_layer.bias
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias
 
     @property
@@ -906,8 +916,7 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is RowParallelLinear
+        return type(source_layer) is RowParallelLinear
 
 
 class LogitsProcessorWithLoRA(BaseLayerWithLoRA):