@@ -99,21 +99,21 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         if self.is_merged_col_linear:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size // 2
-            offset = lora_b.shape[-1] // 2
+            offset = lora_b.shape[0] // 2
 
-            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
-                                 shard_size]
-            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
-                                  (tp_rank + 1) * shard_size]
-            lora_b = torch.cat([left_weight, right_weight], dim=1)
+            left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
+                                 shard_size, :]
+            right_weight = lora_b[offset + tp_rank * shard_size:offset +
+                                  (tp_rank + 1) * shard_size, :]
+            lora_b = torch.cat([left_weight, right_weight], dim=0)
         # Applicable to cases where the base_layer is
         # ColumnParallelLinear.
         else:
             tensor_model_parallel_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size
             start_idx = tensor_model_parallel_rank * shard_size
             end_idx = (tensor_model_parallel_rank + 1) * shard_size
-            lora_b = lora_b[:, start_idx:end_idx]
+            lora_b = lora_b[start_idx:end_idx, :]
         return lora_b
 
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
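The hunk above changes slice_lora_b from slicing columns (lora_b[:, ...]) to slicing rows (lora_b[..., :]), i.e. lora_b is now expected in the [output_size, rank] layout rather than its transpose, so the tensor-parallel shard is taken from dim 0 and re-concatenated along dim 0. A minimal standalone sketch of the per-rank slice under that assumption, with the sizes and the TP rank as made-up constants (the real code reads them from the layer and the distributed runtime):

import torch

# Hypothetical sizes; the real values come from the model config and TP runtime.
lora_rank = 8
full_output_size = 1024   # total output dim of the ColumnParallelLinear
tp_size, tp_rank = 4, 1

lora_b = torch.randn(full_output_size, lora_rank)  # new layout: [output_size, rank]

shard_size = full_output_size // tp_size
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size

# Slice along dim 0 (the output dimension), matching the "+" lines above.
lora_b_shard = lora_b[start_idx:end_idx, :]
assert lora_b_shard.shape == (shard_size, lora_rank)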
@@ -251,9 +251,8 @@ def slice_lora_b(
         for i, (shard_id, shard_size) in enumerate(
                 zip(self.output_ids, self.output_slices)):
             if (lora_b_i := lora_b[i]) is not None:
-                sliced_lora_b[i] = lora_b_i[:,
-                                            shard_size * shard_id:shard_size *
-                                            (shard_id + 1)]
+                sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size *
+                                            (shard_id + 1), :]
         return sliced_lora_b
 
     def slice_bias(
@@ -285,12 +284,12 @@ def set_lora(
         for i in range(self.n_slices):
             if (lora_a_i := lora_a[i]) is not None:
                 self.lora_a_stacked[i][
-                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
-                        lora_a_i.T, non_blocking=True)
+                    index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_(
+                        lora_a_i, non_blocking=True)
             if (lora_b_i := lora_b[i]) is not None:
                 self.lora_b_stacked[i][
-                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
-                        lora_b_i.T, non_blocking=True)
+                    index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
+                        lora_b_i, non_blocking=True)
 
         if lora_bias is not None:
             self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
@@ -299,7 +298,7 @@ def set_lora(
                 if (lora_bias_i := lora_bias[i]) is not None:
                     self.lora_bias_stacked[i][index,
                                               0, :lora_bias_i.shape[0]].copy_(
-                                                  lora_bias_i.T,
+                                                  lora_bias_i,
                                                   non_blocking=True)
 
     @classmethod
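With the adapter matrices already stored in the orientation of the stacked buffers, set_lora can copy them in without the previous .T. A rough sketch of that copy in isolation, using invented buffer capacities (max_loras, padded output dim, max rank) in place of the layer's real ones:

import torch

# Invented capacities; the real buffers are sized from the LoRA config.
max_loras, padded_out, max_rank = 4, 256, 16
lora_b_stacked = torch.zeros(max_loras, 1, padded_out, max_rank)

# An adapter whose B matrix is smaller than the padded buffer slot.
out_dim, lora_rank = 128, 8
lora_b_i = torch.randn(out_dim, lora_rank)

index = 2  # slot assigned to this adapter
# Copy directly: lora_b_i is already [out_dim, rank], so no transpose is needed.
lora_b_stacked[index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
    lora_b_i, non_blocking=True)

assert torch.equal(lora_b_stacked[index, 0, :out_dim, :lora_rank], lora_b_i)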
@@ -345,18 +344,18 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
         self.q_shard_id = tp_rank
         self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
-        lora_b_q = lora_b[:, self.q_proj_shard_size *
+        lora_b_q = lora_b[self.q_proj_shard_size *
                           self.q_shard_id:self.q_proj_shard_size *
-                          (self.q_shard_id + 1)]
+                          (self.q_shard_id + 1), :]
         k_offset = self.q_proj_total_size
-        lora_b_k = lora_b[:, k_offset +
+        lora_b_k = lora_b[k_offset +
                           self.kv_proj_shard_size * self.kv_shard_id:k_offset +
-                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
         v_offset = k_offset + self.kv_proj_total_size
-        lora_b_v = lora_b[:, v_offset +
+        lora_b_v = lora_b[v_offset +
                           self.kv_proj_shard_size * self.kv_shard_id:v_offset +
-                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
+        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
         return lora_b
 
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
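For the fused QKV projection the same row-wise convention is applied three times: the q, k and v blocks are carved out of dim 0 of lora_b at their respective offsets and concatenated back along dim 0. A toy version with made-up projection sizes, assuming num_kv_head_replicas == 1 so the kv shard id equals the TP rank:

import torch

lora_rank = 8
q_total, kv_total = 512, 128          # made-up q/k/v output sizes
tp_size, tp_rank = 4, 1
q_shard = q_total // tp_size
kv_shard = kv_total // tp_size

lora_b = torch.randn(q_total + 2 * kv_total, lora_rank)  # fused [q|k|v] rows

q_shard_id = tp_rank
kv_shard_id = tp_rank                 # num_kv_head_replicas == 1 in this sketch

lora_b_q = lora_b[q_shard * q_shard_id:q_shard * (q_shard_id + 1), :]
k_offset = q_total
lora_b_k = lora_b[k_offset + kv_shard * kv_shard_id:
                  k_offset + kv_shard * (kv_shard_id + 1), :]
v_offset = k_offset + kv_total
lora_b_v = lora_b[v_offset + kv_shard * kv_shard_id:
                  v_offset + kv_shard * (kv_shard_id + 1), :]

sharded = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
assert sharded.shape == (q_shard + 2 * kv_shard, lora_rank)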
@@ -465,7 +464,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
         start_idx = tp_rank * shard_size
-        lora_a = lora_a[:, start_idx:start_idx + shard_size]
+        lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a
 
     def apply(self,
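The lora_a hunks are the mirror image of the lora_b ones: with the adapter matrices kept in their transposed layout, the shard that used to be taken from the last dimension is now taken from the first. A tiny check that the two conventions pick the same coefficients, with arbitrary shapes standing in for the real ones:

import torch

lora_rank, input_size = 16, 64
shard_size, tp_rank = 4, 2            # arbitrary shard of the dim split across ranks
start = tp_rank * shard_size

old_lora_a = torch.randn(input_size, lora_rank)  # old layout
new_lora_a = old_lora_a.T.contiguous()           # new layout is the transpose

old_shard = old_lora_a[:, start:start + shard_size]  # old: slice the last dim
new_shard = new_lora_a[start:start + shard_size, :]  # new: slice the first dim

# Same coefficients, just transposed.
assert torch.equal(old_shard.T, new_shard)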
@@ -508,10 +507,10 @@ def slice_lora_a(
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
         lora_a = [
-            lora_a[0][:, output_start_idx:output_start_idx +
-                      output_shard_size] if lora_a[0] is not None else None,
-            lora_a[1][:, output_start_idx:output_start_idx +
-                      output_shard_size] if lora_a[1] is not None else None,
+            lora_a[0][output_start_idx:output_start_idx +
+                      output_shard_size, :] if lora_a[0] is not None else None,
+            lora_a[1][output_start_idx:output_start_idx +
+                      output_shard_size, :] if lora_a[1] is not None else None,
         ]
         return lora_a
 
@@ -551,7 +550,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
         start_idx = tp_rank * shard_size
-        lora_a = lora_a[:, start_idx:start_idx + shard_size]
+        lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a
 
     def apply(self,
@@ -589,12 +588,12 @@ def slice_lora_a(
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
         lora_a = [
-            lora_a[0][:, start_idx[0]:start_idx[0] +
-                      shard_size[0]] if lora_a[0] is not None else None,
-            lora_a[1][:, start_idx[1]:start_idx[1] +
-                      shard_size[1]] if lora_a[1] is not None else None,
-            lora_a[2][:, start_idx[2]:start_idx[2] +
-                      shard_size[2]] if lora_a[2] is not None else None,
+            lora_a[0][start_idx[0]:start_idx[0] +
+                      shard_size[0], :] if lora_a[0] is not None else None,
+            lora_a[1][start_idx[1]:start_idx[1] +
+                      shard_size[1], :] if lora_a[1] is not None else None,
+            lora_a[2][start_idx[2]:start_idx[2] +
+                      shard_size[2], :] if lora_a[2] is not None else None,
         ]
         return lora_a
 
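The merged variants repeat the same first-dimension slice per sub-module, skipping slots whose adapter tensor is None. Purely as an illustration of that pattern (the diff itself keeps the explicit per-index list), a helper with hypothetical names:

import torch
from typing import Optional

def slice_merged_lora_a(
        lora_a: list[Optional[torch.Tensor]],
        shard_size: list[int],
        tp_rank: int) -> list[Optional[torch.Tensor]]:
    """Slice each per-slice lora_a along dim 0; keep None for missing adapters."""
    sliced: list[Optional[torch.Tensor]] = []
    for tensor, size in zip(lora_a, shard_size):
        if tensor is None:
            sliced.append(None)
            continue
        start = tp_rank * size
        sliced.append(tensor[start:start + size, :])
    return sliced

# Toy usage with invented shapes: three sub-modules, the middle one without LoRA.
print([t.shape if t is not None else None
       for t in slice_merged_lora_a(
           [torch.randn(16, 64), None, torch.randn(8, 64)],
           shard_size=[4, 4, 2], tp_rank=1)])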