185 changes: 185 additions & 0 deletions tests/layers/vllm/test_compressed_tensors_moe.py
@@ -0,0 +1,185 @@
import os
import tempfile

import jax
import jax.numpy as jnp
import pytest
import torch
import torch.nn.functional as F
import torchax
import utils as test_utils
from compressed_tensors.quantization import QuantizationArgs
from jax.sharding import PartitionSpec
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.fused_moe import FusedMoE
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
VllmCompressedTensorsConfig
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
VllmCompressedTensorsW8A8Fp8MoEMethod

# yapf: enable

P = PartitionSpec

os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'

MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'


@pytest.fixture(autouse=True)
def setup_environment():
# This is a fake config used to initialize the distributed environment.
# RowParallelLinear needs the distributed environment to be initialized.
engine_args = EngineArgs(
model=MODEL,
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
)

vllm_config = engine_args.create_engine_config()

with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo")
ensure_model_parallel_initialized(1, 1)


def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
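# Reference MoE forward in bf16: softmax over the router logits, top-k expert
# selection with renormalized weights, then a SwiGLU-style expert FFN
# (w1/w3 gate and up projections, w2 down projection).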
seqlen = x.shape[0]
expert_weights = F.softmax(router_logits, dim=-1)
expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
expert_weights /= expert_weights.sum(dim=-1, keepdim=True)

# conditional (expert) FFN
# e = total number of experts
# t = seqlen
# o = intermediate size
# i = hidden dim
x1 = torch.einsum("ti, eoi -> teo", x, w1)
x1 = F.silu(x1)
x3 = torch.einsum("ti, eoi -> teo", x, w3)
expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)

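# Pick each token's top-k expert outputs and combine them with the normalized routing weights.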
seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
expert_outs = expert_outs[seq_indexes, expert_indices]
out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
return out


def test_fused_moe_method():
mesh = test_utils.get_spmd_mesh(jax.local_device_count())

engine_args = EngineArgs(
model=MODEL,
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
quant_config = get_tpu_quantization_config(vllm_config, mesh)

num_experts = 8
top_k = 2
hidden_size = 128
intermediate_size = hidden_size * 2

with set_current_vllm_config(vllm_config):
layer = FusedMoE(num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size)
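# Target scheme mirroring an FP8-dynamic checkpoint: static per-channel fp8
# weights and dynamic per-token fp8 activations.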
quant_config = VllmCompressedTensorsConfig(
target_scheme_map={
'Linear': {
'weights':
QuantizationArgs(num_bits=8,
type='float',
symmetric=True,
group_size=None,
strategy='channel',
block_structure=None,
dynamic=False,
actorder=None,
observer='minmax',
observer_kwargs={}),
'input_activations':
QuantizationArgs(num_bits=8,
type='float',
symmetric=True,
group_size=None,
strategy='token',
block_structure=None,
dynamic=True,
actorder=None,
observer=None,
observer_kwargs={}),
'format':
None
}
},
ignore=[],
quant_format='compressed-tensors',
sparsity_scheme_map={},
sparsity_ignore_list=[],
)
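# Single-device MoE parallel config (tp_size=dp_size=ep_size=1) for the toy layer above.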
moe = FusedMoEConfig(
num_experts=8,
experts_per_token=2,
hidden_dim=hidden_size,
num_local_experts=8,
moe_parallel_config=FusedMoEParallelConfig(
tp_size=1,
dp_size=1,
ep_size=1,
tp_rank=0,
dp_rank=0,
ep_rank=0,
use_ep=False,
all2all_backend='',
),
in_dtype=torch.bfloat16,
)
method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
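# Create the fp8 expert weights and their scales on the layer, then let the
# method post-process them for inference.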
method.create_weights(layer,
num_experts,
hidden_size,
intermediate_size,
params_dtype=torch.float8_e4m3fn)
method.process_weights_after_loading(layer)

seqlen = 10
with torchax.default_env():
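# Dummy bf16 activations and router logits on the torchax 'jax' device.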
x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
router_logits = torch.randn((seqlen, num_experts),
dtype=torch.bfloat16).to('jax')
result = method.apply(layer,
x,
router_logits,
top_k=2,
renormalize=True)

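# Dequantize the fp8 weights with their per-channel scales and compare the
# quantized MoE output against the bf16 reference.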
result_reference = _ref_math_in_bf16(
layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
router_logits, top_k)

assert jnp.allclose(result.jax(), result_reference.jax())
5 changes: 3 additions & 2 deletions tpu_inference/layers/vllm/quantization/__init__.py
@@ -22,9 +22,10 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
"compressed-tensors": VllmCompressedTensorsConfig,
"awq": VllmAWQConfig,
}

if model_config.quantization not in method_to_config:
raise NotImplementedError
raise NotImplementedError(
f"{model_config.quantization} quantization method not supported."
f" Supported methods are {method_to_config.keys()}")
quant_config = method_to_config[model_config.quantization]
assert issubclass(quant_config, JaxCommonConfig)
quant_config.set_configs(vllm_config, mesh)
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -14,10 +14,11 @@
CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
CompressedTensorsLinearMethod, CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
find_matched_target, should_ignore_layer)

from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
VllmCompressedTensorsW8A8Fp8MoEMethod
from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
VllmCompressedTensorsW8A8Fp8
from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -60,12 +61,12 @@ def get_scheme(self,
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping)
fused_mapping=self.packed_modules_mapping,
)

scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
format = scheme_dict.get("format")

if weight_quant is None:
logger.warning_once("Acceleration for non-quantized schemes is "
@@ -74,10 +75,6 @@ return None
return None

# TODO(kyuyeunk): Add support for different act_quant_format
act_quant_format = is_activation_quantization_format( # noqa: F841
format
) if format is not None else is_activation_quantization_format(
self.quant_format)

linear_config = self.get_linear_config(layer)
if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -114,8 +111,8 @@ def get_quant_method(
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
if isinstance(layer, FusedMoE):
raise NotImplementedError(
"FusedMoE quantization is currently not supported.")
return VllmCompressedTensorsW8A8Fp8MoEMethod(
self, layer.quant_config, self.mesh)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
return None