12 | 12 | from torchax.interop import call_jax, torch_view |
13 | 13 | from torchax.ops.mappings import t2j |
14 | 14 | from vllm.logger import init_logger |
15 | | -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, |
16 | | - FusedMoeWeightScaleSupported) |
17 | | -from vllm.model_executor.layers.fused_moe.config import ( |
18 | | - FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) |
19 | | -from vllm.model_executor.layers.quantization.compressed_tensors import \ |
20 | | - compressed_tensors_moe as vllm_ct_moe |
| 15 | +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig |
21 | 16 | from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \ |
22 | 17 | CompressedTensorsConfig |
| 18 | +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \ |
| 19 | + CompressedTensorsW8A8Fp8MoEMethod |
23 | 20 | from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa |
24 | 21 | WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) |
25 | | -from vllm.model_executor.utils import set_weight_attrs |
| 22 | + |
| 23 | +from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig |
26 | 24 |
27 | 25 | logger = init_logger(__name__) |
28 | 26 |
29 | 27 |
30 | | -class CompressedTensorsW8A8Fp8MoEMethod(vllm_ct_moe.CompressedTensorsMoEMethod |
31 | | - ): |
| 28 | +class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod, |
| 29 | + JaxCommonConfig): |
32 | 30 |
33 | 31 | def __init__(self, quant_config: "CompressedTensorsConfig", |
34 | 32 | moe: FusedMoEConfig, mesh: Mesh): |
35 | | - super().__init__(moe) |
| 33 | + super().__init__(quant_config, moe) |
| 34 | + |
| 35 | + self.use_marlin = False |
| 36 | + self.use_cutlass = False |
| 37 | + self.is_fp8_w8a8_sm100 = False |
| 38 | + |
36 | 39 | self.mesh = mesh |
37 | 40 | self.quant_config = quant_config |
38 | 41 | self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( |
@@ -226,123 +229,3 @@ def apply( |
226 | 229 | preferred_element_type=jnp.bfloat16.dtype) |
227 | 230 |
228 | 231 | return out |
229 | | - |
230 | | - def create_weights( |
231 | | - self, |
232 | | - layer: torch.nn.Module, |
233 | | - num_experts: int, |
234 | | - hidden_size: int, |
235 | | - intermediate_size_per_partition: int, |
236 | | - params_dtype: torch.dtype, |
237 | | - **extra_weight_attrs, |
238 | | - ): |
239 | | - layer.intermediate_size_per_partition = intermediate_size_per_partition |
240 | | - layer.hidden_size = hidden_size |
241 | | - layer.num_experts = num_experts |
242 | | - layer.orig_dtype = params_dtype |
243 | | - layer.weight_block_size = None |
244 | | - |
245 | | - params_dtype = torch.float8_e4m3fn |
246 | | - |
247 | | - # WEIGHTS |
248 | | - w13_weight = torch.nn.Parameter( |
249 | | - torch.empty( |
250 | | - num_experts, |
251 | | - 2 * intermediate_size_per_partition, |
252 | | - hidden_size, |
253 | | - dtype=params_dtype, |
254 | | - ), |
255 | | - requires_grad=False, |
256 | | - ) |
257 | | - layer.register_parameter("w13_weight", w13_weight) |
258 | | - set_weight_attrs(w13_weight, extra_weight_attrs) |
259 | | - |
260 | | - w2_weight = torch.nn.Parameter( |
261 | | - torch.empty( |
262 | | - num_experts, |
263 | | - hidden_size, |
264 | | - intermediate_size_per_partition, |
265 | | - dtype=params_dtype, |
266 | | - ), |
267 | | - requires_grad=False, |
268 | | - ) |
269 | | - layer.register_parameter("w2_weight", w2_weight) |
270 | | - set_weight_attrs(w2_weight, extra_weight_attrs) |
271 | | - |
272 | | - # WEIGHT_SCALES |
273 | | - if self.weight_quant.strategy == QuantizationStrategy.TENSOR: |
274 | | - # Allocate 2 scales for w1 and w3 respectively. |
275 | | - # They are combined to a single scale after weight loading. |
276 | | - w13_weight_scale = torch.nn.Parameter(torch.ones( |
277 | | - num_experts, 2, dtype=torch.float32), |
278 | | - requires_grad=False) |
279 | | - layer.register_parameter("w13_weight_scale", w13_weight_scale) |
280 | | - w2_weight_scale = torch.nn.Parameter(torch.ones( |
281 | | - num_experts, dtype=torch.float32), |
282 | | - requires_grad=False) |
283 | | - layer.register_parameter("w2_weight_scale", w2_weight_scale) |
284 | | - # Add PER-TENSOR quantization for FusedMoE.weight_loader. |
285 | | - extra_weight_attrs.update( |
286 | | - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) |
287 | | - set_weight_attrs(w13_weight_scale, extra_weight_attrs) |
288 | | - set_weight_attrs(w2_weight_scale, extra_weight_attrs) |
289 | | - |
290 | | - elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: |
291 | | - w13_weight_scale = torch.nn.Parameter( |
292 | | - torch.ones( |
293 | | - num_experts, |
294 | | - 2 * intermediate_size_per_partition, |
295 | | - 1, |
296 | | - dtype=torch.float32, |
297 | | - ), |
298 | | - requires_grad=False, |
299 | | - ) |
300 | | - layer.register_parameter("w13_weight_scale", w13_weight_scale) |
301 | | - w2_weight_scale = torch.nn.Parameter( |
302 | | - torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), |
303 | | - requires_grad=False, |
304 | | - ) |
305 | | - layer.register_parameter("w2_weight_scale", w2_weight_scale) |
306 | | - # Add PER-CHANNEL quantization for FusedMoE.weight_loader. |
307 | | - extra_weight_attrs.update( |
308 | | - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) |
309 | | - set_weight_attrs(w13_weight_scale, extra_weight_attrs) |
310 | | - set_weight_attrs(w2_weight_scale, extra_weight_attrs) |
311 | | - |
312 | | - elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: |
313 | | - raise AssertionError('Blockwise quant for MoE not supported yet') |
314 | | - |
315 | | - # INPUT_SCALES |
316 | | - if self.static_input_scales: |
317 | | - w13_input_scale = torch.nn.Parameter(torch.ones( |
318 | | - num_experts, dtype=torch.float32), |
319 | | - requires_grad=False) |
320 | | - layer.register_parameter("w13_input_scale", w13_input_scale) |
321 | | - set_weight_attrs(w13_input_scale, extra_weight_attrs) |
322 | | - |
323 | | - w2_input_scale = torch.nn.Parameter(torch.ones( |
324 | | - num_experts, dtype=torch.float32), |
325 | | - requires_grad=False) |
326 | | - layer.register_parameter("w2_input_scale", w2_input_scale) |
327 | | - set_weight_attrs(w2_input_scale, extra_weight_attrs) |
328 | | - else: |
329 | | - layer.w13_input_scale = None |
330 | | - layer.w2_input_scale = None |
331 | | - |
332 | | - def get_fused_moe_quant_config( |
333 | | - self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None: |
334 | | - if self.use_marlin: |
335 | | - return None |
336 | | - |
337 | | - per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN |
338 | | - per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL |
339 | | - |
340 | | - return fp8_w8a8_moe_quant_config( |
341 | | - w1_scale=layer.w13_weight_scale, |
342 | | - w2_scale=layer.w2_weight_scale, |
343 | | - a1_scale=layer.w13_input_scale, |
344 | | - a2_scale=layer.w2_input_scale, |
345 | | - per_act_token_quant=per_act_token, |
346 | | - per_out_ch_quant=per_channel_quant, |
347 | | - block_shape=layer.weight_block_size, |
348 | | - ) |
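
For context when reading the diff above: the change drops the locally maintained `create_weights` and `get_fused_moe_quant_config` copies and instead subclasses vLLM's `CompressedTensorsW8A8Fp8MoEMethod` (mixed with `JaxCommonConfig`), keeping only the TPU-specific pieces: the JAX mesh and the flags that switch off CUDA-only backends. The skeleton below is a minimal sketch of that pattern, not the file's actual contents; the class name `SketchW8A8Fp8MoEMethod` and the stubbed `apply` body are placeholders.

```python
# Minimal sketch of the inheritance pattern introduced by this diff.
# Only the import paths and the __init__ flags come from the diff itself;
# the class name and the apply() stub are invented for illustration.
import torch
from jax.sharding import Mesh
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
    CompressedTensorsConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
    CompressedTensorsW8A8Fp8MoEMethod


class SketchW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod):
    """Placeholder name; the real class is VllmCompressedTensorsW8A8Fp8MoEMethod
    and additionally mixes in JaxCommonConfig."""

    def __init__(self, quant_config: CompressedTensorsConfig,
                 moe: FusedMoEConfig, mesh: Mesh):
        # The vLLM parent now receives the quant config as well as the MoE
        # config, so weight creation no longer needs to be duplicated here.
        super().__init__(quant_config, moe)
        # CUDA-only fast paths are explicitly disabled for the TPU path.
        self.use_marlin = False
        self.use_cutlass = False
        self.is_fp8_w8a8_sm100 = False
        self.mesh = mesh

    def apply(self, layer: torch.nn.Module, *args, **kwargs):
        # The real implementation dispatches to a JAX fp8 MoE kernel and
        # returns a bfloat16 tensor; omitted in this sketch.
        raise NotImplementedError
```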