|
| 1 | +import os |
| 2 | +import tempfile |
| 3 | + |
| 4 | +import jax |
| 5 | +import jax.numpy as jnp |
| 6 | +import pytest |
| 7 | +import torch |
| 8 | +import torch.nn.functional as F |
| 9 | +import torchax |
| 10 | +import utils as test_utils |
| 11 | +from compressed_tensors.quantization import QuantizationArgs |
| 12 | +from jax.sharding import PartitionSpec |
| 13 | +from vllm.config import set_current_vllm_config |
| 14 | +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, |
| 15 | + init_distributed_environment) |
| 16 | +from vllm.engine.arg_utils import EngineArgs |
| 17 | +from vllm.model_executor.layers.fused_moe import FusedMoE |
| 18 | +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, |
| 19 | + FusedMoEParallelConfig |
| 20 | + ) |
| 21 | + |
| 22 | +from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config |
| 23 | +from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \ |
| 24 | + VllmCompressedTensorsConfig |
| 25 | +from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \ |
| 26 | + VllmCompressedTensorsW8A8Fp8MoEMethod |
| 27 | + |
| 28 | +P = PartitionSpec |
| 29 | + |
| 30 | +os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1' |
| 31 | + |
| 32 | +MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic' |
| 33 | + |
| 34 | + |
| 35 | +@pytest.fixture(autouse=True) |
| 36 | +def setup_environment(): |
| 37 | + # This is a fake config used for init dist env. |
| 38 | + # RowParallelLinear needs dist env to be initialized. |
| 39 | + engine_args = EngineArgs( |
| 40 | + model=MODEL, |
| 41 | + max_model_len=64, |
| 42 | + max_num_batched_tokens=64, |
| 43 | + max_num_seqs=4, |
| 44 | + ) |
| 45 | + |
| 46 | + vllm_config = engine_args.create_engine_config() |
| 47 | + |
| 48 | + with set_current_vllm_config(vllm_config): |
| 49 | + temp_file = tempfile.mkstemp()[1] |
| 50 | + init_distributed_environment( |
| 51 | + 1, |
| 52 | + 0, |
| 53 | + local_rank=0, |
| 54 | + distributed_init_method=f"file://{temp_file}", |
| 55 | + backend="gloo") |
| 56 | + ensure_model_parallel_initialized(1, 1) |
| 57 | + |
| 58 | + |
| 59 | +def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k): |
| 60 | + seqlen = x.shape[0] |
| 61 | + expert_weights = F.softmax(router_logits, dim=-1) |
| 62 | + expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1) |
| 63 | + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) |
| 64 | + |
| 65 | + # cond ffn |
| 66 | + # e = total num of exp = 160 |
| 67 | + # t = seqlen |
| 68 | + # o = config.imtermediate size |
| 69 | + # i = config.dim |
| 70 | + x1 = torch.einsum("ti, eoi -> teo", x, w1) |
| 71 | + x1 = F.silu(x1) |
| 72 | + x3 = torch.einsum("ti, eoi -> teo", x, w3) |
| 73 | + expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2) |
| 74 | + |
| 75 | + seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1) |
| 76 | + expert_outs = expert_outs[seq_indexes, expert_indices] |
| 77 | + out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights) |
| 78 | + return out |
| 79 | + |
| 80 | + |
| 81 | +def test_fused_moe_method(): |
| 82 | + mesh = test_utils.get_spmd_mesh(jax.local_device_count()) |
| 83 | + |
| 84 | + engine_args = EngineArgs( |
| 85 | + model=MODEL, |
| 86 | + max_model_len=64, |
| 87 | + max_num_batched_tokens=64, |
| 88 | + max_num_seqs=4, |
| 89 | + ) |
| 90 | + vllm_config = engine_args.create_engine_config() |
| 91 | + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False |
| 92 | + |
| 93 | + # Call tpu_inference code |
| 94 | + vllm_config.model_config.dtype = torch.bfloat16 |
| 95 | + quant_config = get_tpu_quantization_config(vllm_config, mesh) |
| 96 | + |
| 97 | + num_experts = 8 |
| 98 | + top_k = 2 |
| 99 | + hidden_size = 128 |
| 100 | + intermediate_size = hidden_size * 2 |
| 101 | + |
| 102 | + with set_current_vllm_config(vllm_config): |
| 103 | + layer = FusedMoE(num_experts=num_experts, |
| 104 | + top_k=top_k, |
| 105 | + hidden_size=hidden_size, |
| 106 | + intermediate_size=intermediate_size) |
| 107 | + quant_config = VllmCompressedTensorsConfig( |
| 108 | + target_scheme_map={ |
| 109 | + 'Linear': { |
| 110 | + 'weights': |
| 111 | + QuantizationArgs(num_bits=8, |
| 112 | + type='float', |
| 113 | + symmetric=True, |
| 114 | + group_size=None, |
| 115 | + strategy='channel', |
| 116 | + block_structure=None, |
| 117 | + dynamic=False, |
| 118 | + actorder=None, |
| 119 | + observer='minmax', |
| 120 | + observer_kwargs={}), |
| 121 | + 'input_activations': |
| 122 | + QuantizationArgs(num_bits=8, |
| 123 | + type='float', |
| 124 | + symmetric=True, |
| 125 | + group_size=None, |
| 126 | + strategy='token', |
| 127 | + block_structure=None, |
| 128 | + dynamic=True, |
| 129 | + actorder=None, |
| 130 | + observer=None, |
| 131 | + observer_kwargs={}), |
| 132 | + 'format': |
| 133 | + None |
| 134 | + } |
| 135 | + }, |
| 136 | + ignore=[], |
| 137 | + quant_format='compressed-tensors', |
| 138 | + sparsity_scheme_map={}, |
| 139 | + sparsity_ignore_list=[], |
| 140 | + ) |
| 141 | + moe = FusedMoEConfig( |
| 142 | + num_experts=8, |
| 143 | + experts_per_token=2, |
| 144 | + hidden_dim=hidden_size, |
| 145 | + num_local_experts=8, |
| 146 | + moe_parallel_config=FusedMoEParallelConfig( |
| 147 | + tp_size=1, |
| 148 | + dp_size=1, |
| 149 | + ep_size=1, |
| 150 | + tp_rank=0, |
| 151 | + dp_rank=0, |
| 152 | + ep_rank=0, |
| 153 | + use_ep=False, |
| 154 | + all2all_backend='', |
| 155 | + ), |
| 156 | + in_dtype=torch.bfloat16, |
| 157 | + ) |
| 158 | + method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh) |
| 159 | + method.create_weights(layer, |
| 160 | + num_experts, |
| 161 | + hidden_size, |
| 162 | + intermediate_size, |
| 163 | + params_dtype=torch.float8_e4m3fn) |
| 164 | + method.process_weights_after_loading(layer) |
| 165 | + |
| 166 | + seqlen = 10 |
| 167 | + with torchax.default_env(): |
| 168 | + x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax') |
| 169 | + router_logits = torch.randn((seqlen, num_experts), |
| 170 | + dtype=torch.bfloat16).to('jax') |
| 171 | + result = method.apply(layer, |
| 172 | + x, |
| 173 | + router_logits, |
| 174 | + top_k=2, |
| 175 | + renormalize=True) |
| 176 | + |
| 177 | + result_reference = _ref_math_in_bf16( |
| 178 | + layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale, |
| 179 | + layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale, |
| 180 | + layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x, |
| 181 | + router_logits, top_k) |
| 182 | + |
| 183 | + assert jnp.allclose(result.jax(), result_reference.jax()) |
0 commit comments