 from transformers import PretrainedConfig

 from vllm.config.lora import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_gather)
+from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -85,7 +83,6 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None:
         # inconsistent when TP is greater than 1.
         self.is_merged_col_linear = type(
             base_layer) is MergedColumnParallelLinear
-        self.tp_size = get_tensor_model_parallel_world_size()
         self.output_size = self.base_layer.output_size_per_partition
         # There is only one LoRA layer
         self.n_slices = 1
@@ -97,33 +94,30 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         # Applicable to cases where the base_layer is
         # MergedColumnParallelLinear.
         if self.is_merged_col_linear:
-            tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size // 2
             offset = lora_b.shape[0] // 2

-            left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
+            left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) *
                                  shard_size, :]
-            right_weight = lora_b[offset + tp_rank * shard_size:offset +
-                                  (tp_rank + 1) * shard_size, :]
+            right_weight = lora_b[offset + self.tp_rank * shard_size:offset +
+                                  (self.tp_rank + 1) * shard_size, :]
             lora_b = torch.cat([left_weight, right_weight], dim=0)
         # Applicable to cases where the base_layer is
         # ColumnParallelLinear.
         else:
-            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size
-            start_idx = tensor_model_parallel_rank * shard_size
-            end_idx = (tensor_model_parallel_rank + 1) * shard_size
+            start_idx = self.tp_rank * shard_size
+            end_idx = (self.tp_rank + 1) * shard_size
             lora_b = lora_b[start_idx:end_idx, :]
         return lora_b

     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         # TODO: Fix the slicing logic of bias.
         if bias is None:
             return bias
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         shard_size = self.output_size
-        start_idx = tensor_model_parallel_rank * shard_size
-        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
         bias = bias[start_idx:end_idx]
         return bias

@@ -144,7 +138,7 @@ def forward(

         # Matrix multiply.
         output_parallel = self.apply(input_, bias)
-        if self.base_layer.gather_output:
+        if self.base_layer.gather_output and self.tp_size > 1:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
@@ -185,8 +179,6 @@ def __init__(
                                 QKVParallelLinear]) -> None:
         super().__init__(base_layer)
         # There are two LoRA layers
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
         # the output_sizes in MergedColumnParallelLinear is not sharded by tp
         # we need to divide it by the tp_size to get correct slices size
         output_sizes = self.base_layer.output_sizes
@@ -341,9 +333,9 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         self.n_slices = 1

     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
-        self.q_shard_id = tp_rank
-        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+
+        self.q_shard_id = self.tp_rank
+        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
         lora_b_q = lora_b[self.q_proj_shard_size *
                           self.q_shard_id:self.q_proj_shard_size *
                           (self.q_shard_id + 1), :]
@@ -397,8 +389,6 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         super().__init__(base_layer)
         # There are three LoRA layer.
         self.n_slices = len(self.base_layer.output_sizes)
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()

         self.q_proj_shard_size = (self.base_layer.num_heads *
                                   self.base_layer.head_size)
@@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
     # Therefore, the sharding of `lora_a` only needs to correspond with the
     # gather operation.
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
-        start_idx = tp_rank * shard_size
+        start_idx = self.tp_rank * shard_size
         lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a

@@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
     """

     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
-        start_idx = tp_rank * shard_size
+        start_idx = self.tp_rank * shard_size
         lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a
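The change above follows a single pattern: the tensor-parallel rank and world size are looked up once and cached as `self.tp_rank` / `self.tp_size` (presumably in a shared base class outside these hunks), instead of calling `get_tensor_model_parallel_rank()` and `get_tensor_model_parallel_world_size()` inside every `slice_*` method. Below is a minimal, self-contained sketch of that pattern; the `BaseLayerWithLoRA` name, the stub lookup functions, and the tensor shapes are illustrative assumptions, not vLLM's actual definitions.

```python
import torch


# Stand-ins for the vllm.distributed helpers; in vLLM these read the state
# of the initialized tensor-parallel process group.
def get_tensor_model_parallel_rank() -> int:
    return 0  # assumption: single-rank toy setup


def get_tensor_model_parallel_world_size() -> int:
    return 1  # assumption: single-rank toy setup


class BaseLayerWithLoRA:
    """Hypothetical base class: cache the TP topology once at construction."""

    def __init__(self) -> None:
        # Cached once; subclasses use self.tp_rank / self.tp_size instead of
        # re-querying the distributed state on every slicing call.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()


class ColumnSlicedLoRA(BaseLayerWithLoRA):
    """Illustrative column-parallel slicing that uses the cached attributes."""

    def __init__(self, output_size_per_partition: int) -> None:
        super().__init__()
        self.output_size = output_size_per_partition

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Keep only this rank's shard of the LoRA B matrix along dim 0.
        start_idx = self.tp_rank * self.output_size
        end_idx = (self.tp_rank + 1) * self.output_size
        return lora_b[start_idx:end_idx, :]


layer = ColumnSlicedLoRA(output_size_per_partition=16)
full_lora_b = torch.zeros(16 * layer.tp_size, 8)
assert layer.slice_lora_b(full_lora_b).shape == (16, 8)
```

The cached `self.tp_size` also enables the new `gather_output and self.tp_size > 1` guard in `forward`, which skips the all-gather entirely when there is only one tensor-parallel partition.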