1 change: 0 additions & 1 deletion tests/models/jax/common/moe/test_deepseek_moe.py
@@ -1,4 +1,3 @@
import os
import unittest

import jax
200 changes: 200 additions & 0 deletions tests/test_quantization.py
@@ -9,6 +9,7 @@
from flax import nnx
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from qwix._src.providers import ptq

import tpu_commons.models.jax.utils.quantization.quantization_utils as quantize_qwix # noqa: E402
from tpu_commons.models.jax.model_loader import apply_qwix_quantization
@@ -631,5 +632,204 @@ def test_get_random_sharded_array_sharding_fallback(
self.assertEqual(fallback_sharding, NamedSharding(self.mesh, P()))


class TestManualQwixQuantization(unittest.TestCase):
"""Tests for manual Qwix quantization functions."""

def setUp(self):
if not jax.devices():
self.skipTest(
"JAX device not found, skipping JAX-dependent tests.")
self.weight = jnp.ones((4, 4))
self.inputs = jnp.ones((8, 4))
self.qtype = jnp.int8
self.channelwise_axes = [0]
self.tiled_axes = {}
self.calibration_method = 'max'

@patch(
'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
)
def test_manually_quantize_qwix_weight(self, mock_create_param):
"""Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
quantize_qwix.manually_quantize_qwix_weight(
weight=self.weight,
qtype=self.qtype,
channelwise_axes=self.channelwise_axes,
tiled_axes=self.tiled_axes,
calibration_method=self.calibration_method)

mock_create_param.assert_called_once()
args, _ = mock_create_param.call_args
passed_weight, passed_how_to_quantize = args

self.assertTrue(jnp.array_equal(passed_weight, self.weight))
self.assertIsInstance(passed_how_to_quantize, ptq.qarray.HowToQuantize)
self.assertEqual(passed_how_to_quantize.qtype, self.qtype)
self.assertEqual(passed_how_to_quantize.channelwise_axes,
self.channelwise_axes)
self.assertEqual(passed_how_to_quantize.tiled_axes, self.tiled_axes)
self.assertEqual(passed_how_to_quantize.calibration_method,
self.calibration_method)

@patch(
'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
)
@patch('qwix.pallas.get_current_rule')
def test_manually_quantize_qwix_activation(self, mock_get_rule,
mock_quantize_act):
"""Test that manually_quantize_qwix_activation calls ptq.quantize_act correctly."""
mock_rule = MagicMock()
mock_rule.act_static_scale = False
mock_get_rule.return_value = mock_rule
rule_name = "test_rule"

quantize_qwix.manually_quantize_qwix_activation(
inputs=self.inputs,
rule_name=rule_name,
qtype=self.qtype,
channelwise_axes=self.channelwise_axes,
tiled_axes=self.tiled_axes,
calibration_method=self.calibration_method)

mock_get_rule.assert_called_once_with(rule_name)
mock_quantize_act.assert_called_once()

args, _ = mock_quantize_act.call_args
passed_inputs, passed_how, passed_rule, passed_act_name = args

self.assertTrue(jnp.array_equal(passed_inputs, self.inputs))
self.assertIsInstance(passed_how, ptq.qarray.HowToQuantize)
self.assertEqual(passed_how.qtype, self.qtype)
self.assertEqual(passed_how.channelwise_axes, self.channelwise_axes)
self.assertEqual(passed_how.tiled_axes, self.tiled_axes)
self.assertEqual(passed_how.calibration_method,
self.calibration_method)
self.assertIs(passed_rule, mock_rule)
self.assertEqual(passed_act_name, "") # act_name is hardcoded to ""

@patch('qwix.pallas.get_current_rule')
def test_manually_quantize_qwix_activation_static_scale_raises_error(
self, mock_get_rule):
"""Test that an assertion is raised if the rule has static scale."""
mock_rule = MagicMock()
mock_rule.act_static_scale = True
mock_get_rule.return_value = mock_rule

with self.assertRaisesRegex(AssertionError,
"Static scale not supported right now"):
quantize_qwix.manually_quantize_qwix_activation(
inputs=self.inputs,
rule_name="any_rule",
qtype=self.qtype,
channelwise_axes=self.channelwise_axes,
tiled_axes=self.tiled_axes,
calibration_method=self.calibration_method)


class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
"""Tests for the get_quant_dtype_from_qwix_config function."""

def setUp(self):
self.mock_vllm_config = MagicMock()
self.mock_vllm_config.additional_config = {}

def test_get_quant_dtype_success(self):
"""Test successful extraction of dtypes from a valid config."""
self.mock_vllm_config.additional_config = {
"quantization": {
"qwix": {
"scale_dtype":
"float16",
"rules": [
{
"module_path": ".*mlp.*",
"weight_qtype": "int4"
},
{
"module_path": ".*",
"weight_qtype": "int8"
},
],
}
}
}
scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
self.mock_vllm_config)
self.assertEqual(scale_dtype, jnp.float16)
self.assertEqual(quant_dtype, jnp.int8)

def test_get_quant_dtype_default_scale(self):
"""Test that scale_dtype defaults to bfloat16 when not specified."""
self.mock_vllm_config.additional_config = {
"quantization": {
"qwix": {
"rules": [{
"module_path": ".*",
"weight_qtype": "int8"
}]
}
}
}
scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
self.mock_vllm_config)
self.assertEqual(scale_dtype, jnp.bfloat16)
self.assertEqual(quant_dtype, jnp.int8)

def test_no_quantization_config_returns_defaults(self):
"""Test that default dtypes are returned when config is missing."""
self.mock_vllm_config.additional_config = {}
scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
self.mock_vllm_config)
self.assertEqual(scale_dtype, jnp.bfloat16)
self.assertIsNone(quant_dtype)

def test_get_quant_dtype_no_wildcard_rule_returns_none(self):
"""Test that quant_dtype is None if no wildcard rule is found."""
self.mock_vllm_config.additional_config = {
"quantization": {
"qwix": {
"rules": [{
"module_path": ".*mlp.*",
"weight_qtype": "int4"
}]
}
}
}
scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
self.mock_vllm_config)
self.assertEqual(scale_dtype, jnp.bfloat16)
self.assertIsNone(quant_dtype)

def test_get_quant_dtype_wildcard_rule_missing_qtype_raises_error(self):
"""Test that an assertion is raised if the wildcard rule is missing weight_qtype."""
self.mock_vllm_config.additional_config = {
"quantization": {
"qwix": {
"rules": [{
"module_path": ".*"
}]
}
}
}
with self.assertRaisesRegex(AssertionError,
"Quantization dtype not found"):
quantize_qwix.get_quant_dtype_from_qwix_config(
self.mock_vllm_config)

def test_get_quant_dtype_no_rules_key_returns_none(self):
"""Test that quant_dtype is None if 'rules' key is missing."""
self.mock_vllm_config.additional_config = {
"quantization": {
"qwix": {
"scale_dtype": "float16",
}
}
}
scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
self.mock_vllm_config)
self.assertEqual(scale_dtype, jnp.float16)
self.assertIsNone(quant_dtype)


if __name__ == '__main__':
unittest.main()
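For anyone exercising `get_quant_dtype_from_qwix_config` outside the test suite, here is a minimal sketch of the config shape the fixtures above encode; the `MagicMock` stands in for a real `VllmConfig`, and the rule values are illustrative:

```python
import jax.numpy as jnp
from unittest.mock import MagicMock

import tpu_commons.models.jax.utils.quantization.quantization_utils as quantize_qwix

# Illustrative VllmConfig stand-in; only additional_config is read by the helper.
vllm_config = MagicMock()
vllm_config.additional_config = {
    "quantization": {
        "qwix": {
            "scale_dtype": "float16",  # defaults to "bfloat16" when omitted
            "rules": [
                # Narrower rules may precede the wildcard rule.
                {"module_path": ".*mlp.*", "weight_qtype": "int4"},
                # The ".*" wildcard rule supplies the model-wide quant dtype.
                {"module_path": ".*", "weight_qtype": "int8"},
            ],
        }
    }
}

scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(vllm_config)
assert scale_dtype == jnp.float16 and quant_dtype == jnp.int8
```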
47 changes: 36 additions & 11 deletions tpu_commons/models/jax/common/moe/deepseek_moe.py
@@ -1,18 +1,22 @@
import enum
from dataclasses import InitVar, dataclass
from functools import partial
from typing import Tuple
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jax.sharding import PartitionSpec
from jaxtyping import Float
from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
from qwix._src.providers import ptq

from tpu_commons.models.jax.common.base import create_param
from tpu_commons.models.jax.common.layers import FlaxUtils
from tpu_commons.models.jax.common.moe.moe import MoE
from tpu_commons.models.jax.utils.quantization.quantization_utils import (
manually_quantize_qwix_activation, manually_quantize_qwix_weight)

modeling_flax_utils = FlaxUtils()

@@ -141,6 +145,8 @@ class SparseMoE(MoE):
tile_size: tuple[int, int, int] = (128, 64, 128)
use_megablox: bool = False
mesh: jax.sharding.Mesh
# This should be set if and only if you have quantized your model (via Qwix)
quantized_dtype: Optional[jnp.dtype] = None

def __post_init__(self, rngs: nnx.Rngs):
super().__post_init__(rngs)
@@ -348,7 +354,11 @@ def _gmm(self, inputs, kernel, group_sizes):
raise NotImplementedError(
"MegaBlox kernel call is not implemented.")
else:
output = jax.lax.ragged_dot(
inputs = manually_quantize_qwix_activation(
inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
"absmax") if self.quantized_dtype else inputs
ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
output = ragged_dot_func(
lhs=inputs,
rhs=kernel,
group_sizes=group_sizes,
@@ -572,12 +582,27 @@ def __call__(self, x_TD: Float):
check_rep=False)(
SparseMoE._distributed_sparse_moe_fwd)

return mapped_moe_fwd(
self,
x_TD,
router_weights_TX,
selected_experts_TX,
self.kernel_gating_EDF.value,
self.kernel_up_proj_EDF.value,
self.kernel_down_proj_EFD.value,
)
kernel_gating_EDF = self.kernel_gating_EDF.value
kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
kernel_down_proj_EFD = self.kernel_down_proj_EFD.value

if self.quantized_dtype:
if not isinstance(kernel_gating_EDF, ptq.WithAux):
kernel_gating_EDF = manually_quantize_qwix_weight(
kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
"absmax")
if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
kernel_up_proj_EDF = manually_quantize_qwix_weight(
kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
"absmax")
if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
kernel_down_proj_EFD = manually_quantize_qwix_weight(
kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
"absmax")
kernel_gating_EDF = kernel_gating_EDF.array
kernel_up_proj_EDF = kernel_up_proj_EDF.array
kernel_down_proj_EFD = kernel_down_proj_EFD.array

return mapped_moe_fwd(self, x_TD, router_weights_TX,
selected_experts_TX, kernel_gating_EDF,
kernel_up_proj_EDF, kernel_down_proj_EFD)
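The quantized branch in `_gmm` swaps `jax.lax.ragged_dot` for `qwix_ragged_dot`, but both follow the same grouping contract. A minimal plain-JAX sketch of that contract, with made-up shapes:

```python
import jax
import jax.numpy as jnp

# Tokens are pre-sorted by expert; group_sizes[i] rows of lhs are multiplied
# by expert i's kernel rhs[i]. Shapes here are illustrative only.
num_experts, d_in, d_out = 3, 4, 8
lhs = jnp.ones((6, d_in))                            # 6 tokens, grouped by expert
rhs = jnp.ones((num_experts, d_in, d_out))           # one kernel per expert
group_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)  # rows per expert; sums to 6

out = jax.lax.ragged_dot(lhs=lhs, rhs=rhs, group_sizes=group_sizes)
print(out.shape)  # (6, 8): row t is lhs[t] @ rhs[expert that owns row t]
```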
31 changes: 14 additions & 17 deletions tpu_commons/models/jax/deepseek_v3.py
@@ -24,6 +24,8 @@
from tpu_commons.models.jax.common.moe.moe import MoE
from tpu_commons.models.jax.common.transformer_block import (
SharedExpertsTransformerBlock, TransformerBlock)
from tpu_commons.models.jax.utils.quantization.quantization_utils import \
get_quant_dtype_from_qwix_config
from tpu_commons.models.jax.utils.weight_utils import (get_param,
model_weights_generator,
print_param_info,
@@ -219,6 +221,8 @@ def _create_mla() -> MLA:
activation_ffw_ted=('data', None, None),
edf_sharding=('model', None, None),
efd_sharding=('model', None, None),
quantized_dtype=self.weight_loader.quant_dtype
if self.weight_loader.is_model_quantized else None,
router=router) if is_moe_layer else DenseFFW(
dtype=dtype,
hidden_act=hidden_act,
@@ -460,20 +464,12 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
quantization_type = vllm_config.model_config.hf_config.quantization_config[
"quant_method"]
assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
# NOTE: this will only be used for loading in quantized weights (via Qwix)
qwix_config = vllm_config.additional_config.get(
"quantization", {}).get("qwix", {})
self.scale_dtype = getattr(
jnp, qwix_config.get("scale_dtype", "bfloat16"))
# TODO (jacobplatin): move this out of DeepSeek class to a utility function
for rule in qwix_config.get("rules", []):
if rule.get("module_path") == ".*":
quant_dtype_str = rule.get("weight_qtype", "")
assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
self.quant_dtype = getattr(jnp, quant_dtype_str)
logger.info(
f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
)
self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
vllm_config)

logger.info(
f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
)

quantization_block_sizes = vllm_config.model_config.hf_config.quantization_config[
"weight_block_size"]
@@ -586,15 +582,16 @@ def _load_individual_weight(self,
# Convert weights from torch into numpy
cast_type = model_weight.value.dtype


torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))

if torch_view_type:
# Avoid unnecessary upcasting and mem copy by viewing the tensor's
# raw data as integers before converting to a JAX array.
weight_np = jnp.array(weight.view(torch_view_type).numpy()).view(cast_type)
weight_np = jnp.array(
weight.view(torch_view_type).numpy()).view(cast_type)
else:
raise ValueError(f"Unsupported dtype for tensor conversion: {cast_type}")
raise ValueError(
f"Unsupported dtype for tensor conversion: {cast_type}")

if scale is not None:
scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
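The `DTYPE_VIEW_MAP` round-trip above avoids an upcast and memory copy when converting torch weights that NumPy cannot represent. A standalone sketch for the bfloat16 case; the int16 pairing is an assumption about the map's contents:

```python
import jax.numpy as jnp
import torch

# NumPy has no bfloat16, so bitcast the torch tensor to int16, cross into JAX,
# then bitcast back. No value conversion or upcast happens along the way.
weight = torch.randn(4, 4, dtype=torch.bfloat16)
weight_np = jnp.array(weight.view(torch.int16).numpy()).view(jnp.bfloat16)
assert weight_np.dtype == jnp.bfloat16
```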