diff --git a/tests/layers/vllm/test_compressed_tensors_moe.py b/tests/layers/vllm/test_compressed_tensors_moe.py
new file mode 100644
index 000000000..cb6bfc1f0
--- /dev/null
+++ b/tests/layers/vllm/test_compressed_tensors_moe.py
@@ -0,0 +1,185 @@
+import os
+import tempfile
+
+import jax
+import jax.numpy as jnp
+import pytest
+import torch
+import torch.nn.functional as F
+import torchax
+import utils as test_utils
+from compressed_tensors.quantization import QuantizationArgs
+from jax.sharding import PartitionSpec
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.fused_moe import FusedMoE
+# yapf: disable
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEParallelConfig)
+
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    VllmCompressedTensorsW8A8Fp8MoEMethod
+
+# yapf: enable
+
+P = PartitionSpec
+
+os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+
+MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used to initialize the distributed environment.
+    # RowParallelLinear needs the distributed environment to be initialized.
+    engine_args = EngineArgs(
+        model=MODEL,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
+    seqlen = x.shape[0]
+    expert_weights = F.softmax(router_logits, dim=-1)
+    expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
+    expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+    # cond ffn
+    # e = total num of experts
+    # t = seqlen
+    # o = config.intermediate_size
+    # i = config.dim
+    x1 = torch.einsum("ti, eoi -> teo", x, w1)
+    x1 = F.silu(x1)
+    x3 = torch.einsum("ti, eoi -> teo", x, w3)
+    expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)
+
+    seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+    expert_outs = expert_outs[seq_indexes, expert_indices]
+    out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
+    return out
+
+
+def test_fused_moe_method():
+    mesh = test_utils.get_spmd_mesh(jax.local_device_count())
+
+    engine_args = EngineArgs(
+        model=MODEL,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+
+    num_experts = 8
+    top_k = 2
+    hidden_size = 128
+    intermediate_size = hidden_size * 2
+
+    with set_current_vllm_config(vllm_config):
+        layer = FusedMoE(num_experts=num_experts,
+                         top_k=top_k,
+                         hidden_size=hidden_size,
+                         intermediate_size=intermediate_size)
+
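+    # Scheme map mirroring the FP8-dynamic checkpoint config: static
+    # channelwise FP8 weights with dynamic per-token FP8 activations.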
+    quant_config = VllmCompressedTensorsConfig(
+        target_scheme_map={
+            'Linear': {
+                'weights':
+                QuantizationArgs(num_bits=8,
+                                 type='float',
+                                 symmetric=True,
+                                 group_size=None,
+                                 strategy='channel',
+                                 block_structure=None,
+                                 dynamic=False,
+                                 actorder=None,
+                                 observer='minmax',
+                                 observer_kwargs={}),
+                'input_activations':
+                QuantizationArgs(num_bits=8,
+                                 type='float',
+                                 symmetric=True,
+                                 group_size=None,
+                                 strategy='token',
+                                 block_structure=None,
+                                 dynamic=True,
+                                 actorder=None,
+                                 observer=None,
+                                 observer_kwargs={}),
+                'format':
+                None
+            }
+        },
+        ignore=[],
+        quant_format='compressed-tensors',
+        sparsity_scheme_map={},
+        sparsity_ignore_list=[],
+    )
+    moe = FusedMoEConfig(
+        num_experts=8,
+        experts_per_token=2,
+        hidden_dim=hidden_size,
+        num_local_experts=8,
+        moe_parallel_config=FusedMoEParallelConfig(
+            tp_size=1,
+            dp_size=1,
+            ep_size=1,
+            tp_rank=0,
+            dp_rank=0,
+            ep_rank=0,
+            use_ep=False,
+            all2all_backend='',
+        ),
+        in_dtype=torch.bfloat16,
+    )
+    method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
+    method.create_weights(layer,
+                          num_experts,
+                          hidden_size,
+                          intermediate_size,
+                          params_dtype=torch.float8_e4m3fn)
+    method.process_weights_after_loading(layer)
+
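+    # Compare the quantized MoE path against a bf16 einsum reference computed
+    # on the dequantized weights (weight * per-channel scale).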
+    seqlen = 10
+    with torchax.default_env():
+        x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
+        router_logits = torch.randn((seqlen, num_experts),
+                                    dtype=torch.bfloat16).to('jax')
+        result = method.apply(layer,
+                              x,
+                              router_logits,
+                              top_k=2,
+                              renormalize=True)
+
+        result_reference = _ref_math_in_bf16(
+            layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
+            layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
+            layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
+            router_logits, top_k)
+
+        assert jnp.allclose(result.jax(), result_reference.jax())
diff --git a/tpu_inference/layers/vllm/quantization/__init__.py b/tpu_inference/layers/vllm/quantization/__init__.py
index 0cc5ca575..cb6a365d2 100644
--- a/tpu_inference/layers/vllm/quantization/__init__.py
+++ b/tpu_inference/layers/vllm/quantization/__init__.py
@@ -22,9 +22,11 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
         "compressed-tensors": VllmCompressedTensorsConfig,
         "awq": VllmAWQConfig,
     }
     if model_config.quantization not in method_to_config:
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"{model_config.quantization} quantization method not supported."
+            f" Supported methods are {method_to_config.keys()}")
     quant_config = method_to_config[model_config.quantization]
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
index ca1a7c87f..4f9d1127b 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -14,10 +14,11 @@
     CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
     CompressedTensorsLinearMethod, CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
-    should_ignore_layer)
+    find_matched_target, should_ignore_layer)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    VllmCompressedTensorsW8A8Fp8MoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -60,12 +61,12 @@ def get_scheme(self,
             layer_name=layer_name,
             module=layer,
             targets=self.target_scheme_map.keys(),
-            fused_mapping=self.packed_modules_mapping)
+            fused_mapping=self.packed_modules_mapping,
+        )
 
         scheme_dict = self.target_scheme_map[matched_target]
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
-        format = scheme_dict.get("format")
 
         if weight_quant is None:
             logger.warning_once("Acceleration for non-quantized schemes is "
@@ -74,10 +75,6 @@ def get_scheme(self,
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = is_activation_quantization_format(  # noqa: F841
-            format
-        ) if format is not None else is_activation_quantization_format(
-            self.quant_format)
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -114,8 +111,8 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            raise NotImplementedError(
-                "FusedMoE quantization is currently not supported.")
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                self, layer.quant_config, self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
new file mode 100644
index 000000000..3b04d623d
--- /dev/null
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -0,0 +1,203 @@
+from typing import Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+import torch.nn.functional as F
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from torch.nn.parameter import Parameter
+from torchax.interop import call_jax, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
+    CompressedTensorsConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
+    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+
+logger = init_logger(__name__)
+
+
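+# W8A8 FP8 MoE method for TPU: expert weights are converted to sharded JAX
+# arrays at load time, and the expert FFN is evaluated with lax.dot_general
+# in bfloat16, applying the per-channel weight scales after each matmul.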
+class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
+                                            JaxCommonConfig):
+
+    def __init__(self, quant_config: "CompressedTensorsConfig",
+                 moe: FusedMoEConfig, mesh: Mesh):
+        super().__init__(quant_config, moe)
+        self.mesh = mesh
+        self.quant_config = quant_config
+
+        # disable GPU paths
+        self.use_marlin = False
+        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
+        self.is_fp8_w8a8_sm100 = False
+        self.use_cutlass = False
+        self.disable_expert_map = False
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        assert isinstance(layer, FusedMoE)
+
+        intermediate_size = layer.w13_weight.shape[1] // 2
+        w1_weight = layer.w13_weight[:, :intermediate_size]
+        w3_weight = layer.w13_weight[:, intermediate_size:]
+        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
+        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
+
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+        w1_weight = t2j(w1_weight, use_dlpack=False)
+        w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+        w3_weight = t2j(w3_weight, use_dlpack=False)
+        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+
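+        # Under expert parallelism, shard the expert dimension of every weight
+        # and scale across the "model" mesh axis; otherwise shard the
+        # intermediate dimension (tensor parallel) and replicate the w2 scale.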
+        if layer.use_ep:
+            format = Format(Layout((0, 1, 2)),
+                            NamedSharding(self.mesh, P("model", None, None)))
+            w1_weight = jax.device_put(w1_weight, format)
+            w1_weight_scale = jax.device_put(w1_weight_scale, format)
+            w3_weight = jax.device_put(w3_weight, format)
+            w3_weight_scale = jax.device_put(w3_weight_scale, format)
+            w2_weight = jax.device_put(w2_weight, format)
+            w2_weight_scale = jax.device_put(w2_weight_scale, format)
+        else:
+            assert intermediate_size == w2_weight.shape[-1]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+
+            # TODO: enable this if using fused weights
+            # output_sizes = [intermediate_size, intermediate_size]
+            # w13_weight = reorder_concatenated_tensor_for_sharding(
+            #     w13_weight, output_sizes, n_shards, dim=1
+            # )
+
+            w13_format = Format(
+                Layout((0, 1, 2)),
+                NamedSharding(self.mesh, P(None, "model", None)))
+            w1_weight = jax.device_put(w1_weight, w13_format)
+            w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
+            w3_weight = jax.device_put(w3_weight, w13_format)
+            w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))),
+            )
+            w2_weight_scale = jax.device_put(
+                w2_weight_scale,
+                Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
+            )  # replicate
+
+        w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
+        w1_weight_scale = Parameter(torch_view(w1_weight_scale),
+                                    requires_grad=False)
+        w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                    requires_grad=False)
+        w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
+        w3_weight_scale = Parameter(torch_view(w3_weight_scale),
+                                    requires_grad=False)
+
+        # TODO: don't reuse the w13_* attribute names; after the split they
+        # only hold the w1 portion of the fused weight.
+        layer.w13_weight = w1_weight
+        layer.w13_weight_scale = w1_weight_scale
+        layer.w2_weight = w2_weight
+        layer.w2_weight_scale = w2_weight_scale
+        layer.w3_weight = w3_weight
+        layer.w3_weight_scale = w3_weight_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if activation != "silu":
+            raise NotImplementedError(
+                "Only silu is supported for activation function.")
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # TODO: Use MoE kernel when it supports fp8
+
+        seqlen = x.shape[0]
+
+        expert_weights = F.softmax(router_logits, dim=-1)
+        expert_weights, expert_indices = torch.topk(expert_weights,
+                                                    top_k,
+                                                    dim=-1)
+        if renormalize:
+            expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+        # cond ffn
+        # e = total num of experts
+        # t = seqlen
+        # o = config.intermediate_size
+        # i = config.dim
+        # x1 = torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * layer.w13_weight_scale
+        ux1 = call_jax(jax.lax.dot_general,
+                       x,
+                       layer.w13_weight,
+                       dimension_numbers=(((1, ), (2, )), ((), ())),
+                       preferred_element_type=jnp.bfloat16.dtype)
+        x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
+
+        # x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * layer.w3_weight_scale
+        x3 = call_jax(jax.lax.dot_general,
+                      x,
+                      layer.w3_weight,
+                      dimension_numbers=(((1, ), (2, )), ((), ())),
+                      preferred_element_type=jnp.bfloat16.dtype
+                      ) * layer.w3_weight_scale.squeeze(2)
+
+        # expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), layer.w2_weight) * layer.w2_weight_scale
+        expert_outs = call_jax(
+            jax.lax.dot_general,
+            x1 * x3,
+            layer.w2_weight,
+            dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
+            preferred_element_type=jnp.bfloat16.dtype).transpose(
+                0, 1) * layer.w2_weight_scale.squeeze(2)
+
+        seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+        expert_outs = expert_outs[seq_indexes, expert_indices]
+
+        # out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
+        out = call_jax(jax.lax.dot_general,
+                       expert_outs,
+                       expert_weights,
+                       dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
+                       preferred_element_type=jnp.bfloat16.dtype)
+
+        return out