@@ -74,9 +74,10 @@ def wrapper(*args, **kwargs):
         global_num_experts = layer._lora["global_num_experts"]
         expert_map = layer._lora["expert_map"]

-        (token_lora_mapping, _, _, _, _,
-         _) = layer.punica_wrapper.token_mapping_meta.meta_args(
+        (token_lora_mapping, _, num_tokens_per_lora, _, _,
+         no_lora_flag_cpu) = layer.punica_wrapper.token_mapping_meta.meta_args(
             hidden_states.size(0))
+
         config_dtype = _get_config_dtype_str(use_fp8_w8a8=False,
                                              use_int8_w8a16=False,
                                              use_int4_w4a16=False,
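The wrapper now consumes two more of `meta_args`'s outputs. Below is a minimal sketch of what those values plausibly represent, assuming `token_lora_mapping` assigns each token a LoRA slot and slot `max_loras` is a "no adapter" sentinel; the helper name and the sentinel convention are assumptions, not the vLLM `meta_args` contract.

```python
import torch

def token_mapping_meta_sketch(token_lora_mapping: torch.Tensor, max_loras: int):
    # Per-slot token counts, including the sentinel slot at index max_loras.
    num_tokens_per_lora = torch.bincount(token_lora_mapping,
                                         minlength=max_loras + 1)
    # Host-side flag: True when no token in the batch maps to a real LoRA,
    # so the LoRA expert path can be skipped entirely.
    no_lora_flag_cpu = bool((num_tokens_per_lora[:max_loras] == 0).all())
    return num_tokens_per_lora, no_lora_flag_cpu

mapping = torch.tensor([0, 0, 2, 4, 4])  # two tokens hit the sentinel (max_loras=4)
counts, no_lora = token_mapping_meta_sketch(mapping, max_loras=4)
print(counts.tolist(), no_lora)  # [2, 0, 1, 0, 2] False
```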
@@ -99,7 +100,8 @@ def wrapper(*args, **kwargs):
         config = get_config_func(M)
         (sorted_token_ids_lora, expert_ids_lora,
          num_tokens_post_padded_lora) = (moe_lora_align_block_size(
-             curr_topk_ids, token_lora_mapping, config['BLOCK_SIZE_M'],
+             curr_topk_ids, token_lora_mapping, num_tokens_per_lora, no_lora_flag_cpu,
+             layer.adapter_enabled, config['BLOCK_SIZE_M'],
              global_num_experts, curr_topk_ids.shape[-1], expert_map))

         layer._lora["sorted_token_ids_lora"] = sorted_token_ids_lora
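With the extra metadata threaded through, `moe_lora_align_block_size` has enough information to bail out early when no LoRA work is needed. A hedged sketch of that kind of host-side guard, using hypothetical names (`maybe_align_for_lora`, `align_fn`) rather than the real kernel entry point:

```python
import torch

def maybe_align_for_lora(no_lora_flag_cpu: bool, adapter_enabled: torch.Tensor,
                         align_fn, *align_args):
    # Skip LoRA block alignment when nothing in the batch can use it: either
    # no token maps to a LoRA, or none of the loaded LoRAs carries MoE weights.
    if no_lora_flag_cpu or not bool(adapter_enabled.any()):
        return None, None, None  # caller falls back to the plain MoE path
    return align_fn(*align_args)

enabled = torch.tensor([0, 0, 0], dtype=torch.int)
print(maybe_align_for_lora(False, enabled, lambda: (1, 2, 3)))  # (None, None, None)
```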
@@ -132,6 +134,7 @@ def wrapper(*args, **kwargs):
             max_lora_rank,
             top_k,
             config,
+            layer.adapter_enabled,
         )

         result = func(*args, **kwargs)
@@ -191,7 +194,7 @@ def wrapper(*args, **kwargs):
             intermediate_cache3, intermediate_cache2,
             [w2_lora_a_stacked], [w2_lora_b_stacked], topk_weights,
             sorted_token_ids_lora, expert_ids_lora,
-            num_tokens_post_padded_lora, max_lora_rank, top_k, config,
+            num_tokens_post_padded_lora, max_lora_rank, top_k, config, layer.adapter_enabled,
             True)

         result = func(*args, **kwargs)
@@ -226,6 +229,8 @@ def create_lora_weights(
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         """Initializes lora matrices."""
+        self.adapter_enabled = torch.tensor([0] * (max_loras + 1), dtype=torch.int, device=self.device)
+
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -288,6 +293,9 @@ def create_lora_weights(
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+
+        # flags to track which LoRAs have MoE adapters
+        self.base_layer.adapter_enabled = self.adapter_enabled

         self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
         self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
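A short sketch of why the flag tensor is sized `max_loras + 1` and mirrored onto the base layer, under the assumption that the extra slot covers the "no adapter" sentinel index; `SimpleNamespace` stands in for the real base layer:

```python
import torch
from types import SimpleNamespace

max_loras = 4
# One slot per LoRA index, plus one assumed sentinel slot for "no adapter".
adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int)

# Sharing the same tensor object means the fused-MoE wrapper attached to the
# base layer observes flag updates without any extra copies.
base_layer = SimpleNamespace(adapter_enabled=adapter_enabled)
adapter_enabled[2] = 1                          # e.g. set_lora(index=2, ...)
assert base_layer.adapter_enabled[2].item() == 1
```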
@@ -324,6 +332,8 @@ def reset_lora(self, index: int):
         self.w3_lora_b_stacked[index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
+
+        self.adapter_enabled[index] = 0

     def set_lora(
         self,
@@ -334,6 +344,9 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         """Overwrites lora tensors at index."""
+
+        self.adapter_enabled[index] = 1
+
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
             w2_lora_a = lora_a[eid * 3 + 1]
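Taken together, the lifecycle is simple: `set_lora` marks a slot as carrying MoE adapter weights and `reset_lora` clears it, so the wrapped kernels can cheaply test whether any LoRA work is needed. A self-contained toy version of that lifecycle, with assumed semantics only:

```python
import torch

class MoELoraFlags:
    """Toy model of the adapter_enabled bookkeeping, not the vLLM class."""

    def __init__(self, max_loras: int):
        self.adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int)

    def set_lora(self, index: int) -> None:
        self.adapter_enabled[index] = 1  # slot now holds MoE adapter weights

    def reset_lora(self, index: int) -> None:
        self.adapter_enabled[index] = 0  # slot freed; kernels may skip it

    def any_active(self) -> bool:
        return bool(self.adapter_enabled.any())

flags = MoELoraFlags(max_loras=4)
flags.set_lora(1)
assert flags.any_active()
flags.reset_lora(1)
assert not flags.any_active()
```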