@@ -62,6 +62,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         self.comm_group = None
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
@@ -88,7 +89,8 @@ def __init__(
                          params_dtype,
                          quant_config,
                          prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)
 
         self.gather_output = gather_output
 
@@ -137,6 +139,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if prefix.find("down_proj") != -1 and mlp_tp_enable():
             comm_group = get_mlp_tp_group()
@@ -156,6 +159,7 @@ def __init__(
             self.forward_type = "normal"
         self.comm_group = comm_group
 
+        # TODO: check for disable_tp
         self.tp_size = self.comm_group.world_size
         self.tp_rank = self.comm_group.rank_in_group
 
@@ -171,7 +175,8 @@ def __init__(
                          params_dtype,
                          quant_config,
                          prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
@@ -392,6 +397,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
             comm_group = get_mlp_tp_group()
@@ -403,6 +409,7 @@ def __init__(
             comm_group = get_tp_group()
             self.forward_type = "normal_tp"
         self.comm_group = comm_group
+        # TODO: check for disable_tp
        self.tp_rank = comm_group.rank_in_group
        self.tp_size = comm_group.world_size
 
@@ -418,7 +425,8 @@ def __init__(
                                   params_dtype=params_dtype,
                                   quant_config=quant_config,
                                   prefix=prefix,
-                                  return_bias=return_bias)
+                                  return_bias=return_bias,
+                                  disable_tp=disable_tp)
 
     def forward(
         self,
@@ -498,6 +506,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if dense_optim_enable():
             self.forward_type = "dense_optim"
@@ -511,6 +520,7 @@ def __init__(
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
+        # TODO: check for disable_tp
         tp_size = self.comm_group.world_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
@@ -537,7 +547,8 @@ def __init__(
                                   params_dtype=params_dtype,
                                   quant_config=quant_config,
                                   prefix=prefix,
-                                  return_bias=return_bias)
+                                  return_bias=return_bias,
+                                  disable_tp=disable_tp)
 
     def forward(
         self,
@@ -611,4 +622,4 @@ def __init__(
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
         self.return_bias = return_bias
-        self.disable_tp = disable_tp
+        self.disable_tp = disable_tp
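The three "TODO: check for disable_tp" sites above all sit where a custom comm group (the MLP TP group or the default TP group) is consulted for sharding math. A minimal sketch of what those checks might look like, assuming the upstream vLLM convention that disable_tp=True makes a layer behave as replicated (tp_size == 1); resolve_tp_world is a hypothetical helper for illustration, not part of this PR:

from typing import Tuple

def resolve_tp_world(comm_group, disable_tp: bool) -> Tuple[int, int]:
    """Return (tp_size, tp_rank), honoring disable_tp.

    Assumption: when disable_tp is set, the layer is treated as
    replicated and the comm group's sharding info is ignored; otherwise
    the group's world size and rank drive the weight partitioning, as
    in the constructors above.
    """
    if disable_tp:
        return 1, 0
    return comm_group.world_size, comm_group.rank_in_group

Each TODO site could then become, e.g., self.tp_size, self.tp_rank = resolve_tp_world(self.comm_group, disable_tp).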