From c0bd70858bd47fa9bba84d22a77f9f6e0e5c9876 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Tue, 4 Nov 2025 21:56:56 +0000
Subject: [PATCH 1/4] initial commit on compressed-tensors quantization support for fp8

Signed-off-by: Han Qi
---
 .../layers/vllm/quantization/__init__.py      |   5 +-
 .../compressed_tensors/compressed_tensors.py  |  21 +-
 .../compressed_tensors_moe.py                 | 348 ++++++++++++++++++
 3 files changed, 362 insertions(+), 12 deletions(-)
 create mode 100644 tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py

diff --git a/tpu_inference/layers/vllm/quantization/__init__.py b/tpu_inference/layers/vllm/quantization/__init__.py
index 0cc5ca575..cb6a365d2 100644
--- a/tpu_inference/layers/vllm/quantization/__init__.py
+++ b/tpu_inference/layers/vllm/quantization/__init__.py
@@ -22,9 +22,10 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
         "compressed-tensors": VllmCompressedTensorsConfig,
         "awq": VllmAWQConfig,
     }
-
     if model_config.quantization not in method_to_config:
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"{model_config.quantization} quantization method not supported."
+            f" Supported methods are {method_to_config.keys()}")
     quant_config = method_to_config[model_config.quantization]
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
index ca1a7c87f..33257b08e 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -14,10 +14,11 @@
     CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
     CompressedTensorsLinearMethod, CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
-    should_ignore_layer)
+    find_matched_target, should_ignore_layer)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -60,12 +61,12 @@ def get_scheme(self,
             layer_name=layer_name,
             module=layer,
             targets=self.target_scheme_map.keys(),
-            fused_mapping=self.packed_modules_mapping)
+            fused_mapping=self.packed_modules_mapping,
+        )
         scheme_dict = self.target_scheme_map[matched_target]
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
-        format = scheme_dict.get("format")
 
         if weight_quant is None:
             logger.warning_once("Acceleration for non-quantized schemes is "
@@ -74,10 +75,10 @@
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = is_activation_quantization_format(  # noqa: F841
-            format
-        ) if format is not None else is_activation_quantization_format(
-            self.quant_format)
+        # act_quant_format = (
+        #     is_activation_quantization_format(  # noqa: F841
+        #         format) if format is not None else
+        #     is_activation_quantization_format(self.quant_format))
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -114,8 +115,8 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            raise NotImplementedError(
-                "FusedMoE quantization is currently not supported.")
+            return CompressedTensorsW8A8Fp8MoEMethod(self, layer.quant_config,
+                                                     self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
new file mode 100644
index 000000000..ed823bc0c
--- /dev/null
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -0,0 +1,348 @@
+from typing import Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+import torch.nn.functional as F
+from compressed_tensors.quantization import QuantizationStrategy
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from torch.nn.parameter import Parameter
+from torchax.interop import call_jax, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                  FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
+from vllm.model_executor.layers.quantization.compressed_tensors import \
+    compressed_tensors_moe as vllm_ct_moe
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
+    CompressedTensorsConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
+    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.utils import set_weight_attrs
+
+logger = init_logger(__name__)
+
+
+class CompressedTensorsW8A8Fp8MoEMethod(vllm_ct_moe.CompressedTensorsMoEMethod
+                                        ):
+
+    def __init__(self, quant_config: "CompressedTensorsConfig",
+                 moe: FusedMoEConfig, mesh: Mesh):
+        super().__init__(moe)
+        self.mesh = mesh
+        self.quant_config = quant_config
+        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "weights")
+        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
+            assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
+            self.weight_block_size = self.weight_quant.block_structure
+            assert self.weight_quant.dynamic is not None
+        else:
+            self.weight_block_size = None
+        self.block_quant = self.weight_block_size is not None
+
+        self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
+
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = False
+        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
+
+        # cutlass path
+        self.is_fp8_w8a8_sm100 = False
+        self.use_cutlass = False
+        self.disable_expert_map = False
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        assert isinstance(layer, FusedMoE)
+
+        intermediate_size = layer.w13_weight.shape[1] // 2
+        w1_weight = layer.w13_weight[:, :intermediate_size]
+        w3_weight = layer.w13_weight[:, intermediate_size:]
+        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
+        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
+
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+        w1_weight = t2j(w1_weight, use_dlpack=False)
+        w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+        w3_weight = t2j(w3_weight, use_dlpack=False)
+        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+
+        if layer.use_ep:
+            format = Format(Layout((0, 1, 2)),
+                            NamedSharding(self.mesh, P("model", None, None)))
+            w1_weight = jax.device_put(w1_weight, format)
+            w1_weight_scale = jax.device_put(w1_weight_scale, format)
+            w3_weight = jax.device_put(w3_weight, format)
+            w3_weight_scale = jax.device_put(w3_weight_scale, format)
+            w2_weight = jax.device_put(w2_weight, format)
+            w2_weight_scale = jax.device_put(w2_weight_scale, format)
+        else:
+            assert intermediate_size == w2_weight.shape[-1]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+
+            # TODO: enable this if using fused weights
+            # output_sizes = [intermediate_size, intermediate_size]
+            # w13_weight = reorder_concatenated_tensor_for_sharding(
+            #     w13_weight, output_sizes, n_shards, dim=1
+            # )
+
+            w13_format = Format(
+                Layout((0, 1, 2)),
+                NamedSharding(self.mesh, P(None, "model", None)))
+            w1_weight = jax.device_put(w1_weight, w13_format)
+            w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
+            w3_weight = jax.device_put(w3_weight, w13_format)
+            w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))),
+            )
+            w2_weight_scale = jax.device_put(
+                w2_weight_scale,
+                Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
+            )  # replicate
+
+        w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
+        w1_weight_scale = Parameter(torch_view(w1_weight_scale),
+                                    requires_grad=False)
+        w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                    requires_grad=False)
+        w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
+        w3_weight_scale = Parameter(torch_view(w3_weight_scale),
+                                    requires_grad=False)
+
+        # TODO: don't reuse this variable
+        layer.w13_weight = w1_weight
+        layer.w13_weight_scale = w1_weight_scale
+        layer.w2_weight = w2_weight
+        layer.w2_weight_scale = w2_weight_scale
+        layer.w3_weight = w3_weight
+        layer.w3_weight_scale = w3_weight_scale
+
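+    # NOTE: apply() below runs the MoE with plain JAX dot products and the
+    # stored weight scales rather than a fused TPU MoE kernel (see the TODO
+    # in its body).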
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if activation != "silu":
+            raise NotImplementedError(
+                "Only silu is supported for activation function.")
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # TODO: Use MoE kernel when it supports fp8
+
+        seqlen = x.shape[0]
+
+        expert_weights = F.softmax(router_logits, dim=-1)
+        expert_weights, expert_indices = torch.topk(expert_weights,
+                                                    top_k,
+                                                    dim=-1)
+        if renormalize:
+            expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+        # cond ffn
+        # e = total num of experts = 160
+        # t = seqlen
+        # o = config.intermediate size
+        # i = config.dim
+        # torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * self.w13_weight_scale
+        ux1 = call_jax(jax.lax.dot,
+                       x,
+                       layer.w13_weight,
+                       dimension_numbers=(((1, ), (2, )), ((), ())),
+                       preferred_element_type=jnp.bfloat16.dtype)
+        x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
+
+        # x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * self.w3_weight_scale
+        x3 = call_jax(jax.lax.dot,
+                      x,
+                      layer.w3_weight,
+                      dimension_numbers=(((1, ), (2, )), ((), ())),
+                      preferred_element_type=jnp.bfloat16.dtype
+                      ) * layer.w3_weight_scale.squeeze(2)
+
+        # expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2_weight) * self.w2_weight_scale
+        expert_outs = call_jax(
+            jax.lax.dot,
+            x1 * x3,
+            layer.w2_weight,
+            dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
+            preferred_element_type=jnp.bfloat16.dtype).transpose(
+                0, 1) * layer.w2_weight_scale.squeeze(2)
+
+        seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+        expert_outs = expert_outs[seq_indexes, expert_indices]
+
+        # out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
+        out = call_jax(jax.lax.dot,
+                       expert_outs,
+                       expert_weights,
+                       dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
+                       preferred_element_type=jnp.bfloat16.dtype)
+
+        return out
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.hidden_size = hidden_size
+        layer.num_experts = num_experts
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
+
+        params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    1,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
+            raise AssertionError('Blockwise quant for MoE not supported yet')
+
+        # INPUT_SCALES
+        if self.static_input_scales:
+            w13_input_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                requires_grad=False)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def get_fused_moe_quant_config(
+            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+        if self.use_marlin:
+            return None
+
+        per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
+        per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+
+        return fp8_w8a8_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            per_act_token_quant=per_act_token,
+            per_out_ch_quant=per_channel_quant,
+            block_shape=layer.weight_block_size,
+        )

From a73e34ac1693de0ae08662e725c6ad26fce8ced9 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Fri, 7 Nov 2025 04:22:55 +0000
Subject: [PATCH 2/4] address comments

Signed-off-by: Han Qi
---
 .../compressed_tensors/compressed_tensors.py  |  10 +-
 .../compressed_tensors_moe.py                 | 143 ++----------------
 2 files changed, 16 insertions(+), 137 deletions(-)

diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
index 33257b08e..4f9d1127b 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -18,7 +18,7 @@
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-    CompressedTensorsW8A8Fp8MoEMethod
+    VllmCompressedTensorsW8A8Fp8MoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -75,10 +75,6 @@
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        # act_quant_format = (
-        #     is_activation_quantization_format(  # noqa: F841
-        #         format) if format is not None else
-        #     is_activation_quantization_format(self.quant_format))
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -115,8 +111,8 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsW8A8Fp8MoEMethod(self, layer.quant_config,
-                                                     self.mesh)
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                self, layer.quant_config, self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
index ed823bc0c..530595f98 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -12,27 +12,30 @@
 from torchax.interop import call_jax, torch_view
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
-                                                  FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
-from vllm.model_executor.layers.quantization.compressed_tensors import \
-    compressed_tensors_moe as vllm_ct_moe
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
     CompressedTensorsConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
-from vllm.model_executor.utils import set_weight_attrs
+
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 
 logger = init_logger(__name__)
 
 
-class CompressedTensorsW8A8Fp8MoEMethod(vllm_ct_moe.CompressedTensorsMoEMethod
-                                        ):
+class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
+                                            JaxCommonConfig):
 
     def __init__(self, quant_config: "CompressedTensorsConfig",
                  moe: FusedMoEConfig, mesh: Mesh):
-        super().__init__(moe)
+        super().__init__(quant_config, moe)
+
+        self.use_marlin = False
+        self.use_cutlass = False
+        self.is_fp8_w8a8_sm100 = False
+
         self.mesh = mesh
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -226,123 +229,3 @@ def apply(
                        preferred_element_type=jnp.bfloat16.dtype)
 
         return out
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        layer.intermediate_size_per_partition = intermediate_size_per_partition
-        layer.hidden_size = hidden_size
-        layer.num_experts = num_experts
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-
-        params_dtype = torch.float8_e4m3fn
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They are combined to a single scale after weight loading.
-            w13_weight_scale = torch.nn.Parameter(torch.ones(
-                num_experts, 2, dtype=torch.float32),
-                                                  requires_grad=False)
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            w2_weight_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                 requires_grad=False)
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
-            extra_weight_attrs.update(
-                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts,
-                    2 * intermediate_size_per_partition,
-                    1,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
-            extra_weight_attrs.update(
-                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
-            raise AssertionError('Blockwise quant for MoE not supported yet')
-
-        # INPUT_SCALES
-        if self.static_input_scales:
-            w13_input_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                 requires_grad=False)
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                requires_grad=False)
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def get_fused_moe_quant_config(
-            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-        if self.use_marlin:
-            return None
-
-        per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
-        per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
-
-        return fp8_w8a8_moe_quant_config(
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            per_act_token_quant=per_act_token,
-            per_out_ch_quant=per_channel_quant,
-            block_shape=layer.weight_block_size,
-        )

From 6c68136f4e305d7ee2d6c3197c82f8938849c014 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Mon, 10 Nov 2025 19:31:02 +0000
Subject: [PATCH 3/4] Add unit test

Signed-off-by: Han Qi
---
 .../vllm/test_compressed_tensors_moe.py       | 185 ++++++++++++++++++
 .../compressed_tensors_moe.py                 |   7 +
 2 files changed, 192 insertions(+)
 create mode 100644 tests/layers/vllm/test_compressed_tensors_moe.py

diff --git a/tests/layers/vllm/test_compressed_tensors_moe.py b/tests/layers/vllm/test_compressed_tensors_moe.py
new file mode 100644
index 000000000..cb6bfc1f0
--- /dev/null
+++ b/tests/layers/vllm/test_compressed_tensors_moe.py
@@ -0,0 +1,185 @@
+import os
+import tempfile
+
+import jax
+import jax.numpy as jnp
+import pytest
+import torch
+import torch.nn.functional as F
+import torchax
+import utils as test_utils
+from compressed_tensors.quantization import QuantizationArgs
+from jax.sharding import PartitionSpec
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.fused_moe import FusedMoE
+# yapf: disable
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEParallelConfig)
+
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    VllmCompressedTensorsW8A8Fp8MoEMethod
+
+# yapf: enable
+
+P = PartitionSpec
+
+os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+
+MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used for init dist env.
+    # RowParallelLinear needs dist env to be initialized.
+    engine_args = EngineArgs(
+        model=MODEL,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
+    seqlen = x.shape[0]
+    expert_weights = F.softmax(router_logits, dim=-1)
+    expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
+    expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+    # cond ffn
+    # e = total num of experts = 160
+    # t = seqlen
+    # o = config.intermediate size
+    # i = config.dim
+    x1 = torch.einsum("ti, eoi -> teo", x, w1)
+    x1 = F.silu(x1)
+    x3 = torch.einsum("ti, eoi -> teo", x, w3)
+    expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)
+
+    seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+    expert_outs = expert_outs[seq_indexes, expert_indices]
+    out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
+    return out
+
+
+def test_fused_moe_method():
+    mesh = test_utils.get_spmd_mesh(jax.local_device_count())
+
+    engine_args = EngineArgs(
+        model=MODEL,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+
+    num_experts = 8
+    top_k = 2
+    hidden_size = 128
+    intermediate_size = hidden_size * 2
+
+    with set_current_vllm_config(vllm_config):
+        layer = FusedMoE(num_experts=num_experts,
+                         top_k=top_k,
+                         hidden_size=hidden_size,
+                         intermediate_size=intermediate_size)
+    quant_config = VllmCompressedTensorsConfig(
+        target_scheme_map={
+            'Linear': {
+                'weights':
+                QuantizationArgs(num_bits=8,
+                                 type='float',
+                                 symmetric=True,
+                                 group_size=None,
+                                 strategy='channel',
+                                 block_structure=None,
+                                 dynamic=False,
+                                 actorder=None,
+                                 observer='minmax',
+                                 observer_kwargs={}),
+                'input_activations':
+                QuantizationArgs(num_bits=8,
+                                 type='float',
+                                 symmetric=True,
+                                 group_size=None,
+                                 strategy='token',
+                                 block_structure=None,
+                                 dynamic=True,
+                                 actorder=None,
+                                 observer=None,
+                                 observer_kwargs={}),
+                'format':
+                None
+            }
+        },
+        ignore=[],
+        quant_format='compressed-tensors',
+        sparsity_scheme_map={},
+        sparsity_ignore_list=[],
+    )
+    moe = FusedMoEConfig(
+        num_experts=8,
+        experts_per_token=2,
+        hidden_dim=hidden_size,
+        num_local_experts=8,
+        moe_parallel_config=FusedMoEParallelConfig(
+            tp_size=1,
+            dp_size=1,
+            ep_size=1,
+            tp_rank=0,
+            dp_rank=0,
+            ep_rank=0,
+            use_ep=False,
+            all2all_backend='',
+        ),
+        in_dtype=torch.bfloat16,
+    )
+    method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
+    method.create_weights(layer,
+                          num_experts,
+                          hidden_size,
+                          intermediate_size,
+                          params_dtype=torch.float8_e4m3fn)
+    method.process_weights_after_loading(layer)
+
+    seqlen = 10
+    with torchax.default_env():
+        x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
+        router_logits = torch.randn((seqlen, num_experts),
+                                    dtype=torch.bfloat16).to('jax')
+        result = method.apply(layer,
+                              x,
+                              router_logits,
+                              top_k=2,
+                              renormalize=True)
+
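+        # Reference result: dequantize the fp8 weights to bf16 with their
+        # scales and run the same MoE math in plain torch.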
+        result_reference = _ref_math_in_bf16(
+            layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
+            layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
+            layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
+            router_logits, top_k)
+
+        assert jnp.allclose(result.jax(), result_reference.jax())
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
index 530595f98..69f9ca803 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -38,6 +38,9 @@ def __init__(self, quant_config: "CompressedTensorsConfig",
 
         self.mesh = mesh
         self.quant_config = quant_config
+        # import sys
+        # sys.stdin = open(0)
+        # breakpoint()
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
             "weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -177,6 +180,10 @@ def apply(
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
+        # import sys
+        # sys.stdin = open(0)
+        # breakpoint()
+
         # TODO: Use MoE kernel when it supports fp8
 
         seqlen = x.shape[0]

From 0e136dc5281d9e89ec261d2e0fd32571ed617544 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Mon, 10 Nov 2025 19:59:07 +0000
Subject: [PATCH 4/4] remove init

Signed-off-by: Han Qi
---
 .../compressed_tensors_moe.py                 | 37 +------------------
 1 file changed, 1 insertion(+), 36 deletions(-)

diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
index 69f9ca803..3b04d623d 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -4,7 +4,6 @@
 import jax.numpy as jnp
 import torch
 import torch.nn.functional as F
-from compressed_tensors.quantization import QuantizationStrategy
 from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
@@ -31,46 +30,12 @@ class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
     def __init__(self, quant_config: "CompressedTensorsConfig",
                  moe: FusedMoEConfig, mesh: Mesh):
         super().__init__(quant_config, moe)
-
-        self.use_marlin = False
-        self.use_cutlass = False
-        self.is_fp8_w8a8_sm100 = False
-
         self.mesh = mesh
         self.quant_config = quant_config
-        # import sys
-        # sys.stdin = open(0)
-        # breakpoint()
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations")
-        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                      and self.input_quant.strategy
-                      == QuantizationStrategy.TENSOR)
-        per_channel = (
-            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
-            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
-        if not (per_tensor or per_channel):
-            assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
-            self.weight_block_size = self.weight_quant.block_structure
-            assert self.weight_quant.dynamic is not None
-        else:
-            self.weight_block_size = None
-        self.block_quant = self.weight_block_size is not None
-
-        self.static_input_scales = not self.input_quant.dynamic
-        if self.static_input_scales and per_channel:
-            raise ValueError(
-                "For FP8 Fused MoE layer, we require either per tensor or "
-                "channelwise, dynamic per token quantization.")
-
-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
+        # disable GPU paths
         self.use_marlin = False
         self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
-
-        # cutlass path
         self.is_fp8_w8a8_sm100 = False
         self.use_cutlass = False
         self.disable_expert_map = False