Commit dc66730

Initial commit

Signed-off-by: Jacob Platin <jacobplatin@google.com>
1 parent 280bbb2 commit dc66730

File tree: 4 files changed, +113 −25 lines

  tpu_commons/models/jax/common/moe/deepseek_moe.py
  tpu_commons/models/jax/deepseek_v3.py
  tpu_commons/models/jax/utils/quantization/quantization_utils.py
  tpu_commons/models/jax/utils/weight_utils.py

tpu_commons/models/jax/common/moe/deepseek_moe.py
25 additions, 10 deletions

@@ -1,7 +1,7 @@
 import enum
 from dataclasses import InitVar, dataclass
 from functools import partial
-from typing import Tuple
+from typing import Optional, Tuple

 import jax
 import jax.numpy as jnp
@@ -13,6 +13,8 @@
 from tpu_commons.models.jax.common.base import create_param
 from tpu_commons.models.jax.common.layers import FlaxUtils
 from tpu_commons.models.jax.common.moe.moe import MoE
+from tpu_commons.models.jax.utils.quantization.quantization_utils import (
+    manually_quantize_qwix_activation, manually_quantize_qwix_weight)

 modeling_flax_utils = FlaxUtils()

@@ -141,6 +143,8 @@ class SparseMoE(MoE):
     tile_size: tuple[int, int, int] = (128, 64, 128)
     use_megablox: bool = False
     mesh: jax.sharding.Mesh
+    # This should be set if and only if you have quantized your model (via Qwix)
+    quantized_dtype: Optional[jnp.dtype] = None

     def __post_init__(self, rngs: nnx.Rngs):
         super().__post_init__(rngs)
@@ -356,6 +360,10 @@ def _gmm(self, inputs, kernel, group_sizes):
             raise NotImplementedError(
                 "MegaBlox kernel call is not implemented.")
         else:
+            inputs = manually_quantize_qwix_activation(
+                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
+                "absmax") if self.quantized_dtype else inputs
+            # TODO: make qwix.ragged_dot
             output = jax.lax.ragged_dot(
                 lhs=inputs,
                 rhs=kernel,
@@ -583,12 +591,19 @@ def __call__(self, x_TD: Float):
             check_rep=False)(
                 SparseMoE._distributed_sparse_moe_fwd)

-        return mapped_moe_fwd(
-            self,
-            x_TD,
-            router_weights_TX,
-            selected_experts_TX,
-            self.kernel_gating_EDF.value,
-            self.kernel_up_proj_EDF.value,
-            self.kernel_down_proj_EFD.value,
-        )
+        quantized_kernel_gating_EDF = manually_quantize_qwix_weight(
+            self.kernel_gating_EDF.value, self.quantized_dtype, [0, 2], {},
+            "absmax") if self.quantized_dtype else self.kernel_gating_EDF.value
+        quantized_kernel_up_proj_EDF = manually_quantize_qwix_weight(
+            self.kernel_up_proj_EDF.value, self.quantized_dtype, [0, 2], {},
+            "absmax"
+        ) if self.quantized_dtype else self.kernel_up_proj_EDF.value
+        quantized_kernel_down_proj_EFD = manually_quantize_qwix_weight(
+            self.kernel_down_proj_EFD.value, self.quantized_dtype, [0, 2], {},
+            "absmax"
+        ) if self.quantized_dtype else self.kernel_down_proj_EFD.value
+
+        return mapped_moe_fwd(self, x_TD, router_weights_TX,
+                              selected_experts_TX, quantized_kernel_gating_EDF,
+                              quantized_kernel_up_proj_EDF,
+                              quantized_kernel_down_proj_EFD)
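
For context on the quantization axes used in the `_gmm` and `__call__` hunks above: `jax.lax.ragged_dot` contracts a `[m, k]` activation against a `[g, k, n]` stack of per-expert kernels, so the non-contraction axes are `[0]` for the activation and `[0, 2]` for the kernel, which is what the `channelwise_axes` arguments encode. A minimal shape sketch (the dimension sizes are illustrative, not taken from this commit):

    import jax
    import jax.numpy as jnp

    m, k, n, g = 8, 16, 32, 4                         # tokens, in-dim, out-dim, experts
    lhs = jnp.ones((m, k), jnp.bfloat16)              # activations; channelwise axis [0]
    rhs = jnp.ones((g, k, n), jnp.bfloat16)           # expert kernels; channelwise axes [0, 2]
    group_sizes = jnp.array([2, 2, 2, 2], jnp.int32)  # rows of lhs per expert, sums to m

    # Each contiguous group of lhs rows is multiplied by its expert's [k, n] slice;
    # k is contracted, so per-channel scales live on the remaining axes.
    out = jax.lax.ragged_dot(lhs, rhs, group_sizes)
    print(out.shape)  # (8, 32)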

tpu_commons/models/jax/deepseek_v3.py
11 additions, 15 deletions

@@ -24,6 +24,8 @@
 from tpu_commons.models.jax.common.moe.moe import MoE
 from tpu_commons.models.jax.common.transformer_block import (
     SharedExpertsTransformerBlock, TransformerBlock)
+from tpu_commons.models.jax.utils.quantization.quantization_utils import \
+    get_quant_dtype_from_qwix_config
 from tpu_commons.models.jax.utils.weight_utils import (get_param,
                                                        model_weights_generator,
                                                        print_param_info,
@@ -45,7 +47,7 @@ def __init__(self,
         self.vllm_config = vllm_config
         self.rng = nnx.Rngs(rng)

-        num_layers: int = 61
+        num_layers: int = 4
         num_local_experts: int = 256

         vocab_size: int = 129280
@@ -211,6 +213,8 @@ def _create_mla() -> MLA:
             activation_ffw_ted=('data', None, None),
             edf_sharding=('model', None, None),
             efd_sharding=('model', None, None),
+            quantized_dtype=self.weight_loader.quant_dtype
+            if self.weight_loader.is_model_quantized else None,
             router=router) if is_moe_layer else DenseFFW(
                 dtype=dtype,
                 hidden_act=hidden_act,
@@ -452,20 +456,12 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
         quantization_type = vllm_config.model_config.hf_config.quantization_config[
             "quant_method"]
         assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
-        # NOTE: this will only be used for loading in quantized weights (via Qwix)
-        qwix_config = vllm_config.additional_config.get(
-            "quantization", {}).get("qwix", {})
-        self.scale_dtype = getattr(
-            jnp, qwix_config.get("scale_dtype", "bfloat16"))
-        # TODO (jacobplatin): move this out of DeepSeek class to a utility function
-        for rule in qwix_config.get("rules", []):
-            if rule.get("module_path") == ".*":
-                quant_dtype_str = rule.get("weight_qtype", "")
-                assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
-                self.quant_dtype = getattr(jnp, quant_dtype_str)
-        logger.info(
-            f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
-        )
+        self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
+            vllm_config)
+
+        logger.info(
+            f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
+        )

         quantization_block_sizes = vllm_config.model_config.hf_config.quantization_config[
             "weight_block_size"]

tpu_commons/models/jax/utils/quantization/quantization_utils.py
76 additions, 0 deletions

@@ -7,11 +7,14 @@
 import jax
 import jax.numpy as jnp
 import qwix
+import qwix.pallas
 import yaml
 from flax import nnx
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
+from qwix._src.core.qarray import QArray
+from qwix._src.providers import ptq

 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -507,3 +510,76 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
     if hasattr(model, 'initialize_cache'):
         model.initialize_cache()
     logger.info("Done initializing Qwix-quantized model with random weights")
+
+
+def manually_quantize_qwix_weight(weight: jax.Array, qtype: jnp.dtype,
+                                  channelwise_axes: List[int],
+                                  tiled_axes: dict,
+                                  calibration_method: str) -> QArray:
+    """Manually quantizes a weight tensor using Qwix; the weight-side
+    counterpart of `manually_quantize_qwix_activation` below."""
+    # TODO (jacobplatin): clean this up; this is needed because of issues with Qwix quantizing the `shard_map` in SparseMatmul
+    how_to_quantize = ptq.qarray.HowToQuantize(
+        qtype=qtype,
+        channelwise_axes=channelwise_axes,
+        tiled_axes=tiled_axes,
+        calibration_method=calibration_method)
+
+    return ptq._create_quantized_param(weight, how_to_quantize)
+
+
+def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
+                                      qtype: jnp.dtype,
+                                      channelwise_axes: List[int],
+                                      tiled_axes: dict,
+                                      calibration_method: str) -> QArray:
+    """
+    Manually quantizes an activation tensor using Qwix. Needed for the SparseMatmul
+    DeepSeek MoE case currently.
+
+    Args:
+        inputs: The activation tensor to quantize.
+        rule_name: The name of the quantization rule to use.
+        qtype: The quantization type.
+        channelwise_axes: The channelwise axes to quantize.
+        tiled_axes: The tiled axes to quantize.
+        calibration_method: The calibration method to use.
+
+    Returns:
+        The quantized activation tensor.
+    """
+    rule = qwix.pallas.get_current_rule(rule_name)
+    lhs_how = ptq.qarray.HowToQuantize(qtype=qtype,
+                                       channelwise_axes=channelwise_axes,
+                                       tiled_axes=tiled_axes,
+                                       calibration_method=calibration_method)
+    assert not rule.act_static_scale, "Static scale not supported right now"
+
+    # channelwise_axes should be set to (a subset of) non-contraction axes, e.g.
+    # for ragged_dot [m, k] x [g, k, n], they are [0] and [0, 2]
+    # TODO (jacobplatin): add support for `act_name`
+    return ptq._quantize_act(inputs, lhs_how, rule, "")
+
+
+def get_quant_dtype_from_qwix_config(
+        vllm_config: "VllmConfig") -> tuple[jnp.dtype, jnp.dtype]:
+    """
+    Gets the quantization dtype from the Qwix config.
+
+    Args:
+        vllm_config: The VllmConfig object.
+
+    Returns:
+        A tuple of the scale dtype and quant dtype.
+    """
+    qwix_config = vllm_config.additional_config.get("quantization",
+                                                    {}).get("qwix", {})
+    scale_dtype = getattr(jnp, qwix_config.get("scale_dtype", "bfloat16"))
+    quant_dtype = None
+    # TODO (jacobplatin): this needs to be much more robust
+    for rule in qwix_config.get("rules", []):
+        if rule.get("module_path") == ".*":
+            quant_dtype_str = rule.get("weight_qtype", "")
+            assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
+            quant_dtype = getattr(jnp, quant_dtype_str)
+    return scale_dtype, quant_dtype
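
To tie the new helpers back to the MoE change, the weight path in `SparseMoE.__call__` reduces to a call like the following. The array shape is made up for illustration; the dtype, axes, and calibration method mirror the call site in `deepseek_moe.py`:

    import jax.numpy as jnp

    from tpu_commons.models.jax.utils.quantization.quantization_utils import (
        manually_quantize_qwix_weight)

    # [E, D, F] expert kernel; D is the ragged_dot contraction axis, so the
    # channelwise (per-scale) axes are the expert axis 0 and output axis 2.
    kernel_up_proj_EDF = jnp.ones((8, 64, 128), dtype=jnp.bfloat16)

    quantized_up_proj = manually_quantize_qwix_weight(
        kernel_up_proj_EDF,
        jnp.float8_e4m3fn,  # quantized_dtype configured on SparseMoE
        [0, 2],             # channelwise_axes
        {},                 # tiled_axes
        "absmax")           # calibration_method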

tpu_commons/models/jax/utils/weight_utils.py
1 addition, 0 deletions

@@ -111,6 +111,7 @@ def get_model_weights_files(
     )

     weights_files.sort()
+    weights_files = weights_files[:4] + [weights_files[-4]]
     return weights_files
