
Commit 136049d

jrplatin authored and bzgoogle committed
[JAX][Quantization] Add Qwix support for SparseMatul (#740)
Signed-off-by: Jacob Platin <jacobplatin@google.com>
1 parent 7f1a038 commit 136049d

5 files changed: +336 -27 lines changed


tests/models/jax/common/moe/test_deepseek_moe.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-import os
 import unittest
 
 import jax

tests/test_quantization.py

Lines changed: 200 additions & 0 deletions
@@ -9,6 +9,7 @@
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
+from qwix._src.providers import ptq
 
 import tpu_commons.models.jax.utils.quantization.quantization_utils as quantize_qwix  # noqa: E402
 from tpu_commons.models.jax.model_loader import apply_qwix_quantization
@@ -631,5 +632,204 @@ def test_get_random_sharded_array_sharding_fallback(
         self.assertEqual(fallback_sharding, NamedSharding(self.mesh, P()))
 
 
+class TestManualQwixQuantization(unittest.TestCase):
+    """Tests for manual Qwix quantization functions."""
+
+    def setUp(self):
+        if not jax.devices():
+            self.skipTest(
+                "JAX device not found, skipping JAX-dependent tests.")
+        self.weight = jnp.ones((4, 4))
+        self.inputs = jnp.ones((8, 4))
+        self.qtype = jnp.int8
+        self.channelwise_axes = [0]
+        self.tiled_axes = {}
+        self.calibration_method = 'max'
+
+    @patch(
+        'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
+    )
+    def test_manually_quantize_qwix_weight(self, mock_create_param):
+        """Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
+        quantize_qwix.manually_quantize_qwix_weight(
+            weight=self.weight,
+            qtype=self.qtype,
+            channelwise_axes=self.channelwise_axes,
+            tiled_axes=self.tiled_axes,
+            calibration_method=self.calibration_method)
+
+        mock_create_param.assert_called_once()
+        args, _ = mock_create_param.call_args
+        passed_weight, passed_how_to_quantize = args
+
+        self.assertTrue(jnp.array_equal(passed_weight, self.weight))
+        self.assertIsInstance(passed_how_to_quantize, ptq.qarray.HowToQuantize)
+        self.assertEqual(passed_how_to_quantize.qtype, self.qtype)
+        self.assertEqual(passed_how_to_quantize.channelwise_axes,
+                         self.channelwise_axes)
+        self.assertEqual(passed_how_to_quantize.tiled_axes, self.tiled_axes)
+        self.assertEqual(passed_how_to_quantize.calibration_method,
+                         self.calibration_method)
+
+    @patch(
+        'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
+    )
+    @patch('qwix.pallas.get_current_rule')
+    def test_manually_quantize_qwix_activation(self, mock_get_rule,
+                                               mock_quantize_act):
+        """Test that manually_quantize_qwix_activation calls ptq.quantize_act correctly."""
+        mock_rule = MagicMock()
+        mock_rule.act_static_scale = False
+        mock_get_rule.return_value = mock_rule
+        rule_name = "test_rule"
+
+        quantize_qwix.manually_quantize_qwix_activation(
+            inputs=self.inputs,
+            rule_name=rule_name,
+            qtype=self.qtype,
+            channelwise_axes=self.channelwise_axes,
+            tiled_axes=self.tiled_axes,
+            calibration_method=self.calibration_method)
+
+        mock_get_rule.assert_called_once_with(rule_name)
+        mock_quantize_act.assert_called_once()
+
+        args, _ = mock_quantize_act.call_args
+        passed_inputs, passed_how, passed_rule, passed_act_name = args
+
+        self.assertTrue(jnp.array_equal(passed_inputs, self.inputs))
+        self.assertIsInstance(passed_how, ptq.qarray.HowToQuantize)
+        self.assertEqual(passed_how.qtype, self.qtype)
+        self.assertEqual(passed_how.channelwise_axes, self.channelwise_axes)
+        self.assertEqual(passed_how.tiled_axes, self.tiled_axes)
+        self.assertEqual(passed_how.calibration_method,
+                         self.calibration_method)
+        self.assertIs(passed_rule, mock_rule)
+        self.assertEqual(passed_act_name, "")  # act_name is hardcoded to ""
+
+    @patch('qwix.pallas.get_current_rule')
+    def test_manually_quantize_qwix_activation_static_scale_raises_error(
+            self, mock_get_rule):
+        """Test that an assertion is raised if the rule has static scale."""
+        mock_rule = MagicMock()
+        mock_rule.act_static_scale = True
+        mock_get_rule.return_value = mock_rule
+
+        with self.assertRaisesRegex(AssertionError,
+                                    "Static scale not supported right now"):
+            quantize_qwix.manually_quantize_qwix_activation(
+                inputs=self.inputs,
+                rule_name="any_rule",
+                qtype=self.qtype,
+                channelwise_axes=self.channelwise_axes,
+                tiled_axes=self.tiled_axes,
+                calibration_method=self.calibration_method)
+
+
+class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
+    """Tests for the get_quant_dtype_from_qwix_config function."""
+
+    def setUp(self):
+        self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.additional_config = {}
+
+    def test_get_quant_dtype_success(self):
+        """Test successful extraction of dtypes from a valid config."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "scale_dtype":
+                    "float16",
+                    "rules": [
+                        {
+                            "module_path": ".*mlp.*",
+                            "weight_qtype": "int4"
+                        },
+                        {
+                            "module_path": ".*",
+                            "weight_qtype": "int8"
+                        },
+                    ],
+                }
+            }
+        }
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.float16)
+        self.assertEqual(quant_dtype, jnp.int8)
+
+    def test_get_quant_dtype_default_scale(self):
+        """Test that scale_dtype defaults to bfloat16 when not specified."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": [{
+                        "module_path": ".*",
+                        "weight_qtype": "int8"
+                    }]
+                }
+            }
+        }
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.bfloat16)
+        self.assertEqual(quant_dtype, jnp.int8)
+
+    def test_no_quantization_config_returns_defaults(self):
+        """Test that default dtypes are returned when config is missing."""
+        self.mock_vllm_config.additional_config = {}
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.bfloat16)
+        self.assertIsNone(quant_dtype)
+
+    def test_get_quant_dtype_no_wildcard_rule_returns_none(self):
+        """Test that quant_dtype is None if no wildcard rule is found."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": [{
+                        "module_path": ".*mlp.*",
+                        "weight_qtype": "int4"
+                    }]
+                }
+            }
+        }
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.bfloat16)
+        self.assertIsNone(quant_dtype)
+
+    def test_get_quant_dtype_wildcard_rule_missing_qtype_raises_error(self):
+        """Test that an assertion is raised if the wildcard rule is missing weight_qtype."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": [{
+                        "module_path": ".*"
+                    }]
+                }
+            }
+        }
+        with self.assertRaisesRegex(AssertionError,
+                                    "Quantization dtype not found"):
+            quantize_qwix.get_quant_dtype_from_qwix_config(
+                self.mock_vllm_config)
+
+    def test_get_quant_dtype_no_rules_key_returns_none(self):
+        """Test that quant_dtype is None if 'rules' key is missing."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "scale_dtype": "float16",
+                }
+            }
+        }
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.float16)
+        self.assertIsNone(quant_dtype)
+
+
 if __name__ == '__main__':
     unittest.main()
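
For context, a minimal sketch of the config shape these tests exercise (values are illustrative, not repository defaults): get_quant_dtype_from_qwix_config reads the quantization.qwix section of vllm_config.additional_config, resolves scale_dtype (defaulting to bfloat16), and takes the weight_qtype of the wildcard ".*" rule as the quantization dtype.

from unittest.mock import MagicMock

import jax.numpy as jnp

import tpu_commons.models.jax.utils.quantization.quantization_utils as quantize_qwix

# Illustrative additional_config; only the nested shape matters here.
vllm_config = MagicMock()
vllm_config.additional_config = {
    "quantization": {
        "qwix": {
            "scale_dtype": "bfloat16",
            "rules": [
                {"module_path": ".*mlp.*", "weight_qtype": "int4"},
                # The wildcard rule must carry a weight_qtype; otherwise the
                # helper raises "Quantization dtype not found".
                {"module_path": ".*", "weight_qtype": "int8"},
            ],
        }
    }
}

# For the config above this returns (jnp.bfloat16, jnp.int8).
scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
    vllm_config)
assert scale_dtype == jnp.bfloat16 and quant_dtype == jnp.int8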

tpu_commons/models/jax/common/moe/deepseek_moe.py

Lines changed: 36 additions & 11 deletions
@@ -1,18 +1,22 @@
 import enum
 from dataclasses import InitVar, dataclass
 from functools import partial
-from typing import Tuple
+from typing import Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
 from jax.sharding import PartitionSpec
 from jaxtyping import Float
+from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
+from qwix._src.providers import ptq
 
 from tpu_commons.models.jax.common.base import create_param
 from tpu_commons.models.jax.common.layers import FlaxUtils
 from tpu_commons.models.jax.common.moe.moe import MoE
+from tpu_commons.models.jax.utils.quantization.quantization_utils import (
+    manually_quantize_qwix_activation, manually_quantize_qwix_weight)
 
 modeling_flax_utils = FlaxUtils()
 
@@ -141,6 +145,8 @@ class SparseMoE(MoE):
     tile_size: tuple[int, int, int] = (128, 64, 128)
     use_megablox: bool = False
     mesh: jax.sharding.Mesh
+    # This should be set if and only if you have quantized your model (via Qwix)
+    quantized_dtype: Optional[jnp.dtype] = None
 
     def __post_init__(self, rngs: nnx.Rngs):
         super().__post_init__(rngs)
@@ -348,7 +354,11 @@ def _gmm(self, inputs, kernel, group_sizes):
             raise NotImplementedError(
                 "MegaBlox kernel call is not implemented.")
         else:
-            output = jax.lax.ragged_dot(
+            inputs = manually_quantize_qwix_activation(
+                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
+                "absmax") if self.quantized_dtype else inputs
+            ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
+            output = ragged_dot_func(
                 lhs=inputs,
                 rhs=kernel,
                 group_sizes=group_sizes,
@@ -572,12 +582,27 @@ def __call__(self, x_TD: Float):
                                      check_rep=False)(
                                          SparseMoE._distributed_sparse_moe_fwd)
 
-        return mapped_moe_fwd(
-            self,
-            x_TD,
-            router_weights_TX,
-            selected_experts_TX,
-            self.kernel_gating_EDF.value,
-            self.kernel_up_proj_EDF.value,
-            self.kernel_down_proj_EFD.value,
-        )
+        kernel_gating_EDF = self.kernel_gating_EDF.value
+        kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
+        kernel_down_proj_EFD = self.kernel_down_proj_EFD.value
+
+        if self.quantized_dtype:
+            if not isinstance(kernel_gating_EDF, ptq.WithAux):
+                kernel_gating_EDF = manually_quantize_qwix_weight(
+                    kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
+                    "absmax")
+            if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
+                kernel_up_proj_EDF = manually_quantize_qwix_weight(
+                    kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
+                    "absmax")
+            if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
+                kernel_down_proj_EFD = manually_quantize_qwix_weight(
+                    kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
+                    "absmax")
+            kernel_gating_EDF = kernel_gating_EDF.array
+            kernel_up_proj_EDF = kernel_up_proj_EDF.array
+            kernel_down_proj_EFD = kernel_down_proj_EFD.array
+
+        return mapped_moe_fwd(self, x_TD, router_weights_TX,
+                              selected_experts_TX, kernel_gating_EDF,
+                              kernel_up_proj_EDF, kernel_down_proj_EFD)
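
Below is a hedged sketch of the calling convention of the new weight helper, mirroring the call sites above (argument order: weight, qtype, channelwise_axes, tiled_axes, calibration_method). The kernel shape and the fp8 dtype are invented for illustration, and the snippet assumes the helper can be invoked outside a Qwix-instrumented module. The activation helper is not shown standalone because it looks up the active rule via qwix.pallas.get_current_rule and therefore needs a Qwix context.

import jax.numpy as jnp
from qwix._src.providers import ptq
from tpu_commons.models.jax.utils.quantization.quantization_utils import \
    manually_quantize_qwix_weight

# Hypothetical EDF-shaped MoE kernel: (num_experts, d_model, d_ff).
kernel_EDF = jnp.ones((8, 128, 512), dtype=jnp.bfloat16)

# Channelwise axes [0, 2] mirror the gating/up-projection call sites above;
# the EFD-shaped down-projection uses [0, 1] instead.
if not isinstance(kernel_EDF, ptq.WithAux):
    kernel_EDF = manually_quantize_qwix_weight(kernel_EDF, jnp.float8_e4m3fn,
                                               [0, 2], {}, "absmax")

# The quantized container exposes the packed values via .array, which is what
# the ragged-dot path above consumes.
packed_kernel = kernel_EDF.array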

tpu_commons/models/jax/deepseek_v3.py

Lines changed: 21 additions & 15 deletions
@@ -24,6 +24,8 @@
 from tpu_commons.models.jax.common.moe.moe import MoE
 from tpu_commons.models.jax.common.transformer_block import (
     SharedExpertsTransformerBlock, TransformerBlock)
+from tpu_commons.models.jax.utils.quantization.quantization_utils import \
+    get_quant_dtype_from_qwix_config
 from tpu_commons.models.jax.utils.weight_utils import (get_param,
                                                        model_weights_generator,
                                                        print_param_info,
@@ -212,6 +214,8 @@ def _create_mla() -> MLA:
                 activation_ffw_ted=('data', None, None),
                 edf_sharding=('model', None, None),
                 efd_sharding=('model', None, None),
+                quantized_dtype=self.weight_loader.quant_dtype
+                if self.weight_loader.is_model_quantized else None,
                 router=router) if is_moe_layer else DenseFFW(
                     dtype=dtype,
                     hidden_act=hidden_act,
@@ -453,20 +457,12 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
         quantization_type = vllm_config.model_config.hf_config.quantization_config[
             "quant_method"]
         assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
-        # NOTE: this will only be used for loading in quantized weights (via Qwix)
-        qwix_config = vllm_config.additional_config.get(
-            "quantization", {}).get("qwix", {})
-        self.scale_dtype = getattr(
-            jnp, qwix_config.get("scale_dtype", "bfloat16"))
-        # TODO (jacobplatin): move this out of DeepSeek class to a utility function
-        for rule in qwix_config.get("rules", []):
-            if rule.get("module_path") == ".*":
-                quant_dtype_str = rule.get("weight_qtype", "")
-                assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
-                self.quant_dtype = getattr(jnp, quant_dtype_str)
-                logger.info(
-                    f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
-                )
+        self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
+            vllm_config)
+
+        logger.info(
+            f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
+        )
 
         quantization_block_sizes = vllm_config.model_config.hf_config.quantization_config[
             "weight_block_size"]
@@ -578,7 +574,17 @@ def _load_individual_weight(self,
 
         # Convert weights from torch into numpy
         cast_type = model_weight.value.dtype
-        weight_np = weight.to(torch.float32).numpy().astype(cast_type)
+
+        torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+
+        if torch_view_type:
+            # Avoid unnecessary upcasting and mem copy by viewing the tensor's
+            # raw data as integers before converting to a JAX array.
+            weight_np = jnp.array(
+                weight.view(torch_view_type).numpy()).view(cast_type)
+        else:
+            raise ValueError(
+                f"Unsupported dtype for tensor conversion: {cast_type}")
 
         if scale is not None:
             scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
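
The zero-copy conversion above relies on a DTYPE_VIEW_MAP defined elsewhere in the repository and not shown in this diff. A minimal sketch of how such a mapping and conversion could look, with the dtype pairs chosen purely for illustration (the real map may differ):

import jax.numpy as jnp
import torch

# Illustrative stand-in for DTYPE_VIEW_MAP: each JAX dtype is paired with a
# same-width torch integer dtype so the raw bits can be reinterpreted without
# an intermediate float32 copy.
_ILLUSTRATIVE_DTYPE_VIEW_MAP = {
    jnp.dtype(jnp.bfloat16): torch.int16,       # 2 bytes <-> 2 bytes
    jnp.dtype(jnp.float8_e4m3fn): torch.uint8,  # 1 byte  <-> 1 byte
}


def torch_weight_to_jax(weight: torch.Tensor, cast_type) -> jnp.ndarray:
    """Reinterpret a torch tensor's raw bytes as a JAX array of cast_type."""
    view_type = _ILLUSTRATIVE_DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
    if view_type is None:
        raise ValueError(f"Unsupported dtype for tensor conversion: {cast_type}")
    # View the bits as integers (no copy, no upcast), hand them to JAX as-is,
    # then reinterpret them as the target dtype.
    return jnp.array(weight.view(view_type).numpy()).view(cast_type)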
