@@ -88,14 +88,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
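The new check accepts exactly two (weight, activation) pairings: per-tensor weights with per-tensor activations, or channelwise weights with dynamic per-token activations. A minimal standalone sketch of the same validation logic, using a stand-in enum rather than importing the real `QuantizationStrategy` from compressed-tensors:

```python
from enum import Enum


class QuantizationStrategy(str, Enum):
    # Stand-in for compressed_tensors.quantization.QuantizationStrategy.
    TENSOR = "tensor"
    CHANNEL = "channel"
    TOKEN = "token"


def validate_moe_scheme(weight_strategy, input_strategy, input_dynamic):
    """Mirror the __init__ check: per-tensor/per-tensor, or
    channelwise weights + dynamic per-token activations."""
    per_tensor = (weight_strategy == QuantizationStrategy.TENSOR
                  and input_strategy == QuantizationStrategy.TENSOR)
    per_channel = (weight_strategy == QuantizationStrategy.CHANNEL
                   and input_strategy == QuantizationStrategy.TOKEN)
    if not (per_tensor or per_channel):
        raise ValueError("unsupported FP8 MoE scheme")
    if not input_dynamic and per_channel:
        raise ValueError("channelwise weights require dynamic "
                         "per-token activation scales")


validate_moe_scheme(QuantizationStrategy.TENSOR,
                    QuantizationStrategy.TENSOR, input_dynamic=False)  # ok
validate_moe_scheme(QuantizationStrategy.CHANNEL,
                    QuantizationStrategy.TOKEN, input_dynamic=True)  # ok
```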
@@ -123,24 +132,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
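For reference, with E experts, intermediate size I, and hidden size H, the scale shapes this hunk allocates are: per-tensor `(E, 2)` for w13 (one scale each for w1 and w3, merged after loading) and `(E,)` for w2; channelwise one scale per output channel, `(E, 2*I, 1)` for w13 (w1 and w3 stacked) and `(E, H, 1)` for w2. A quick torch sketch with toy sizes:

```python
import torch

E, I, H = 4, 128, 256  # toy num_experts / intermediate / hidden sizes

# Per-TENSOR: one scale per (expert, shard); merged after loading.
w13_scale_tensor = torch.ones(E, 2, dtype=torch.float32)
w2_scale_tensor = torch.ones(E, dtype=torch.float32)

# Per-CHANNEL: one scale per output channel of each expert weight.
# w13 stacks w1 and w3, hence 2 * I output channels.
w13_scale_channel = torch.ones(E, 2 * I, 1, dtype=torch.float32)
w2_scale_channel = torch.ones(E, H, 1, dtype=torch.float32)

print(w13_scale_channel.shape)  # torch.Size([4, 256, 1])
```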
@@ -163,6 +188,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -204,24 +230,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                       requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For the per-TENSOR case, the Fp8 moe kernel needs a single weight
+        # scale for w13 per expert: take the max, then dequant and requant.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
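The requantization step relies on fp8 storage being a scaled encoding: dequantize each shard with its own scale, then requantize against the shared max. A standalone sketch of the arithmetic, with plain float tensors standing in for fp8 (`per_tensor_dequantize` and `ops.scaled_fp8_quant` are vLLM internals not reproduced here; real fp8 also rounds, which this sketch omits):

```python
import torch

FP8_MAX = 448.0  # max magnitude representable in float8_e4m3fn


def fake_quant(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for ops.scaled_fp8_quant: divide by scale, clamp to range.
    return torch.clamp(w / scale, -FP8_MAX, FP8_MAX)


# Two shards (w1, w3) quantized with different per-tensor scales.
w1, w3 = torch.randn(8, 16), torch.randn(8, 16)
s1, s3 = w1.abs().max() / FP8_MAX, w3.abs().max() / FP8_MAX
q1, q3 = fake_quant(w1, s1), fake_quant(w3, s3)

# Merge: dequantize with the original scale, requantize with the max.
s_max = torch.max(s1, s3)
q1_merged = fake_quant(q1 * s1, s_max)  # dequant -> requant
q3_merged = fake_quant(q3 * s3, s_max)

# A single shared scale now reconstructs both shards.
assert torch.allclose(q1_merged * s_max, w1, atol=1e-5)
assert torch.allclose(q3_merged * s_max, w3, atol=1e-5)
```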
@@ -265,6 +292,8 @@ def apply(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy ==
+            QuantizationStrategy.CHANNEL,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,
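With `per_channel_quant=True` the fused kernel pairs a per-token activation scale with a per-channel weight scale. Both scales factor out of each dot product, so the quantized GEMM result is rescaled by an outer product of the two scale vectors. A small torch sketch of that identity, again with plain floats standing in for fp8 and 448 as the e4m3 max:

```python
import torch

T, K, N = 5, 64, 32  # tokens, hidden dim, output channels
a, w = torch.randn(T, K), torch.randn(N, K)

a_scale = a.abs().amax(dim=1, keepdim=True) / 448.0  # (T, 1) per token
w_scale = w.abs().amax(dim=1, keepdim=True) / 448.0  # (N, 1) per channel
qa, qw = a / a_scale, w / w_scale                    # "quantized" operands

# Scales factor out of each dot product:
# y[t, n] = (sum_k qa[t, k] * qw[n, k]) * a_scale[t] * w_scale[n]
y = (qa @ qw.t()) * (a_scale * w_scale.t())          # (T, N)

assert torch.allclose(y, a @ w.t(), atol=1e-3)
```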