
Commit f447a37

Add unit test
1 parent a73e34a commit f447a37

2 files changed: +190 -0 lines changed
Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
import os
import tempfile

import jax
import jax.numpy as jnp
import pytest
import torch
import torch.nn.functional as F
import torchax
import utils as test_utils
from compressed_tensors.quantization import QuantizationArgs
from jax.sharding import PartitionSpec
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
                                                         FusedMoEParallelConfig)

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
    VllmCompressedTensorsConfig
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
    VllmCompressedTensorsW8A8Fp8MoEMethod

P = PartitionSpec

os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'

MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'


@pytest.fixture(autouse=True)
def setup_environment():
    # This is a fake config used to init the dist env.
    # RowParallelLinear needs the dist env to be initialized.
    engine_args = EngineArgs(
        model=MODEL,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )

    vllm_config = engine_args.create_engine_config()

    with set_current_vllm_config(vllm_config):
        temp_file = tempfile.mkstemp()[1]
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo")
        ensure_model_parallel_initialized(1, 1)

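For reference, a minimal standalone sketch of what this single-process setup amounts to, assuming vLLM's init_distributed_environment wraps torch.distributed.init_process_group (an assumption about vLLM internals; the fixture above is the authoritative path): a one-rank gloo group rendezvousing over a temp file.

# Hypothetical standalone equivalent (assumption: vLLM wraps
# torch.distributed.init_process_group under the hood).
import tempfile
import torch.distributed as dist

init_file = tempfile.mkstemp()[1]
dist.init_process_group(backend="gloo",
                        init_method=f"file://{init_file}",
                        world_size=1,
                        rank=0)
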
def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
    seqlen = x.shape[0]
    expert_weights = F.softmax(router_logits, dim=-1)
    expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
    expert_weights /= expert_weights.sum(dim=-1, keepdim=True)

    # cond ffn; einsum labels:
    # e = total num of experts
    # t = seqlen
    # o = config.intermediate_size
    # i = config.dim (hidden size)
    x1 = torch.einsum("ti, eoi -> teo", x, w1)
    x1 = F.silu(x1)
    x3 = torch.einsum("ti, eoi -> teo", x, w3)
    expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)

    # Gather each token's top-k expert outputs, then mix by routing weight.
    seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
    expert_outs = expert_outs[seq_indexes, expert_indices]
    out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
    return out

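The helper above is the bf16 reference for the fused MoE math: softmax routing, top-k selection with renormalization, a SiLU-gated expert FFN computed for all experts, then a gather of each token's top-k expert outputs. A minimal CPU sketch of the same math with plain torch tensors (toy shapes, no torchax; all names here are illustrative):

# Toy shapes: t=4 tokens, e=8 experts, i=16 hidden, o=32 intermediate.
import torch
import torch.nn.functional as F

t, e, i, o, top_k = 4, 8, 16, 32, 2
x = torch.randn(t, i)
w1 = torch.randn(e, o, i)            # gate projection
w3 = torch.randn(e, o, i)            # up projection
w2 = torch.randn(e, i, o)            # down projection
router_logits = torch.randn(t, e)

weights = F.softmax(router_logits, dim=-1)
weights, indices = torch.topk(weights, top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)  # renormalize over top-k

x1 = F.silu(torch.einsum("ti, eoi -> teo", x, w1))
x3 = torch.einsum("ti, eoi -> teo", x, w3)
expert_outs = torch.einsum("teo, eio -> tei", x1 * x3, w2)

rows = torch.arange(t).unsqueeze(1)   # (t, 1), broadcasts against indices
picked = expert_outs[rows, indices]   # (t, top_k, i)
out = torch.einsum("tai, ta -> ti", picked, weights)
assert out.shape == (t, i)
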
def test_fused_moe_method():
    mesh = test_utils.get_spmd_mesh(jax.local_device_count())

    engine_args = EngineArgs(
        model=MODEL,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False

    # Call tpu_inference code
    vllm_config.model_config.dtype = torch.bfloat16
    quant_config = get_tpu_quantization_config(vllm_config, mesh)

    num_experts = 8
    top_k = 2
    hidden_size = 128
    intermediate_size = hidden_size * 2

    with set_current_vllm_config(vllm_config):
        layer = FusedMoE(num_experts=num_experts,
                         top_k=top_k,
                         hidden_size=hidden_size,
                         intermediate_size=intermediate_size)
    quant_config = VllmCompressedTensorsConfig(
        target_scheme_map={
            'Linear': {
                'weights':
                QuantizationArgs(num_bits=8,
                                 type='float',
                                 symmetric=True,
                                 group_size=None,
                                 strategy='channel',
                                 block_structure=None,
                                 dynamic=False,
                                 actorder=None,
                                 observer='minmax',
                                 observer_kwargs={}),
                'input_activations':
                QuantizationArgs(num_bits=8,
                                 type='float',
                                 symmetric=True,
                                 group_size=None,
                                 strategy='token',
                                 block_structure=None,
                                 dynamic=True,
                                 actorder=None,
                                 observer=None,
                                 observer_kwargs={}),
                'format':
                None
            }
        },
        ignore=[],
        quant_format='compressed-tensors',
        sparsity_scheme_map={},
        sparsity_ignore_list=[],
    )
    moe = FusedMoEConfig(
        num_experts=8,
        experts_per_token=2,
        hidden_dim=hidden_size,
        num_local_experts=8,
        moe_parallel_config=FusedMoEParallelConfig(
            tp_size=1,
            dp_size=1,
            ep_size=1,
            tp_rank=0,
            dp_rank=0,
            ep_rank=0,
            use_ep=False,
            all2all_backend='',
        ),
        in_dtype=torch.bfloat16,
    )
    method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
    method.create_weights(layer,
                          num_experts,
                          hidden_size,
                          intermediate_size,
                          params_dtype=torch.float8_e4m3fn)
    method.process_weights_after_loading(layer)

    seqlen = 10
    with torchax.default_env():
        x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
        router_logits = torch.randn((seqlen, num_experts),
                                    dtype=torch.bfloat16).to('jax')
        result = method.apply(layer,
                              x,
                              router_logits,
                              top_k=2,
                              renormalize=True)

        # Dequantize the fp8 weights back to bf16 and compare against the
        # pure-bf16 reference math.
        result_reference = _ref_math_in_bf16(
            layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
            layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
            layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
            router_logits, top_k)

        assert jnp.allclose(result.jax(), result_reference.jax())
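The final comparison dequantizes the layer's fp8 weights channel-wise (weight times per-output-channel scale) before running the bf16 reference. Note that in upstream vLLM the w13_weight tensor fuses the w1 and w3 projections along the output dimension; whether this TPU method also exposes a separate w3_weight, as the call above assumes, is specific to this implementation. A toy sketch of the channel-wise fp8 round trip the test relies on (illustrative only, not this library's code):

import torch

# Toy channel-wise fp8 (e4m3) quantize/dequantize round trip. All names
# and shapes are illustrative.
w = torch.randn(32, 16, dtype=torch.float32)           # (out_channels, in_features)
fp8_max = torch.finfo(torch.float8_e4m3fn).max         # 448.0 for e4m3
w_scale = w.abs().amax(dim=1, keepdim=True) / fp8_max  # one scale per out channel
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)          # quantized storage
w_deq = w_fp8.to(torch.bfloat16) * w_scale.to(torch.bfloat16)  # as in the test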

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 7 additions & 0 deletions
@@ -38,6 +38,9 @@ def __init__(self, quant_config: "CompressedTensorsConfig",

         self.mesh = mesh
         self.quant_config = quant_config
+        # import sys
+        # sys.stdin = open(0)
+        # breakpoint()
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
             "weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -177,6 +180,10 @@ def apply(
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")

+        # import sys
+        # sys.stdin = open(0)
+        # breakpoint()
+
         # TODO: Use MoE kernel when it supports fp8

         seqlen = x.shape[0]
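The commented-out lines added in both hunks are a common pdb-under-pytest idiom, presumably left in for future debugging: reopening file descriptor 0 as stdin lets breakpoint() accept input even when the test runner has captured or redirected stdin. Uncommented, the idiom reads:

import sys

sys.stdin = open(0)  # reattach the process's fd 0 as stdin for pdb
breakpoint()         # drop into the debugger at this point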
