Commit c1c5ace

Initial commit

Signed-off-by: Jacob Platin <jacobplatin@google.com>
1 parent 280bbb2 commit c1c5ace

File tree

tpu_commons/models/jax/common/moe/deepseek_moe.py
tpu_commons/models/jax/deepseek_v3.py
tpu_commons/models/jax/utils/quantization/quantization_utils.py
tpu_commons/models/jax/utils/weight_utils.py

4 files changed: +124 −26 lines changed
tpu_commons/models/jax/common/moe/deepseek_moe.py

Lines changed: 36 additions & 11 deletions
@@ -1,18 +1,22 @@
 import enum
 from dataclasses import InitVar, dataclass
 from functools import partial
-from typing import Tuple
+from typing import Optional, Tuple

 import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
 from jax.sharding import PartitionSpec
 from jaxtyping import Float
+from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
+from qwix._src.providers import ptq

 from tpu_commons.models.jax.common.base import create_param
 from tpu_commons.models.jax.common.layers import FlaxUtils
 from tpu_commons.models.jax.common.moe.moe import MoE
+from tpu_commons.models.jax.utils.quantization.quantization_utils import (
+    manually_quantize_qwix_activation, manually_quantize_qwix_weight)

 modeling_flax_utils = FlaxUtils()

@@ -141,6 +145,8 @@ class SparseMoE(MoE):
     tile_size: tuple[int, int, int] = (128, 64, 128)
     use_megablox: bool = False
     mesh: jax.sharding.Mesh
+    # This should be set if and only if you have quantized your model (via Qwix)
+    quantized_dtype: Optional[jnp.dtype] = None

     def __post_init__(self, rngs: nnx.Rngs):
         super().__post_init__(rngs)
@@ -356,7 +362,11 @@ def _gmm(self, inputs, kernel, group_sizes):
             raise NotImplementedError(
                 "MegaBlox kernel call is not implemented.")
         else:
-            output = jax.lax.ragged_dot(
+            inputs = manually_quantize_qwix_activation(
+                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
+                "absmax") if self.quantized_dtype else inputs
+            ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
+            output = ragged_dot_func(
                 lhs=inputs,
                 rhs=kernel,
                 group_sizes=group_sizes,
@@ -583,12 +593,27 @@ def __call__(self, x_TD: Float):
             check_rep=False)(
                 SparseMoE._distributed_sparse_moe_fwd)

-        return mapped_moe_fwd(
-            self,
-            x_TD,
-            router_weights_TX,
-            selected_experts_TX,
-            self.kernel_gating_EDF.value,
-            self.kernel_up_proj_EDF.value,
-            self.kernel_down_proj_EFD.value,
-        )
+        kernel_gating_EDF = self.kernel_gating_EDF.value
+        kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
+        kernel_down_proj_EFD = self.kernel_down_proj_EFD.value
+
+        if self.quantized_dtype:
+            if not isinstance(kernel_gating_EDF, ptq.WithAux):
+                kernel_gating_EDF = manually_quantize_qwix_weight(
+                    kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
+                    "absmax")
+            if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
+                kernel_up_proj_EDF = manually_quantize_qwix_weight(
+                    kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
+                    "absmax")
+            if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
+                kernel_down_proj_EFD = manually_quantize_qwix_weight(
+                    kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
+                    "absmax")
+            kernel_gating_EDF = kernel_gating_EDF.array
+            kernel_up_proj_EDF = kernel_up_proj_EDF.array
+            kernel_down_proj_EFD = kernel_down_proj_EFD.array
+
+        return mapped_moe_fwd(self, x_TD, router_weights_TX,
+                              selected_experts_TX, kernel_gating_EDF,
+                              kernel_up_proj_EDF, kernel_down_proj_EFD)

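For context, a minimal standalone sketch (not part of this commit) of the grouped matmul that the unquantized `_gmm` branch performs via `jax.lax.ragged_dot`; the quantized path substitutes `qwix_ragged_dot` and pre-quantizes `inputs`. Shapes follow the comment in `manually_quantize_qwix_activation`: lhs is [m, k], rhs is [g, k, n], and `group_sizes` assigns contiguous rows of lhs to each expert group. The toy sizes below are illustrative only.

import jax
import jax.numpy as jnp

m, k, n, g = 8, 4, 5, 2                           # tokens, in dim, out dim, expert groups
lhs = jnp.ones((m, k), dtype=jnp.bfloat16)        # token activations, rows sorted by expert
rhs = jnp.ones((g, k, n), dtype=jnp.bfloat16)     # one [k, n] kernel per expert group
group_sizes = jnp.array([5, 3], dtype=jnp.int32)  # rows 0-4 use group 0, rows 5-7 use group 1

out = jax.lax.ragged_dot(lhs=lhs, rhs=rhs, group_sizes=group_sizes)
print(out.shape)  # (8, 5): each row is multiplied by its group's kernel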
tpu_commons/models/jax/deepseek_v3.py

Lines changed: 11 additions & 15 deletions
@@ -24,6 +24,8 @@
 from tpu_commons.models.jax.common.moe.moe import MoE
 from tpu_commons.models.jax.common.transformer_block import (
     SharedExpertsTransformerBlock, TransformerBlock)
+from tpu_commons.models.jax.utils.quantization.quantization_utils import \
+    get_quant_dtype_from_qwix_config
 from tpu_commons.models.jax.utils.weight_utils import (get_param,
                                                        model_weights_generator,
                                                        print_param_info,
@@ -45,7 +47,7 @@ def __init__(self,
         self.vllm_config = vllm_config
         self.rng = nnx.Rngs(rng)

-        num_layers: int = 61
+        num_layers: int = 4
         num_local_experts: int = 256

         vocab_size: int = 129280
@@ -211,6 +213,8 @@ def _create_mla() -> MLA:
             activation_ffw_ted=('data', None, None),
             edf_sharding=('model', None, None),
             efd_sharding=('model', None, None),
+            quantized_dtype=self.weight_loader.quant_dtype
+            if self.weight_loader.is_model_quantized else None,
             router=router) if is_moe_layer else DenseFFW(
                 dtype=dtype,
                 hidden_act=hidden_act,
@@ -452,20 +456,12 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
         quantization_type = vllm_config.model_config.hf_config.quantization_config[
             "quant_method"]
         assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
-        # NOTE: this will only be used for loading in quantized weights (via Qwix)
-        qwix_config = vllm_config.additional_config.get(
-            "quantization", {}).get("qwix", {})
-        self.scale_dtype = getattr(
-            jnp, qwix_config.get("scale_dtype", "bfloat16"))
-        # TODO (jacobplatin): move this out of DeepSeek class to a utility function
-        for rule in qwix_config.get("rules", []):
-            if rule.get("module_path") == ".*":
-                quant_dtype_str = rule.get("weight_qtype", "")
-                assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
-                self.quant_dtype = getattr(jnp, quant_dtype_str)
-        logger.info(
-            f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
-        )
+        self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
+            vllm_config)
+
+        logger.info(
+            f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
+        )

         quantization_block_sizes = vllm_config.model_config.hf_config.quantization_config[
             "weight_block_size"]

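For reference, the refactored parsing above expects the Qwix settings to live under `vllm_config.additional_config["quantization"]["qwix"]`. Below is a hypothetical config of that shape, inferred from the key names the code reads (`scale_dtype`, `rules`, `module_path`, `weight_qtype`); the concrete dtype strings are illustrative assumptions consistent with the fp8 assert and the e4m3 activation dtype used in the MoE path:

additional_config = {
    "quantization": {
        "qwix": {
            "scale_dtype": "bfloat16",            # defaults to bfloat16 if omitted
            "rules": [
                {
                    "module_path": ".*",          # the catch-all rule the loader looks for
                    "weight_qtype": "float8_e4m3fn",  # resolved via getattr(jnp, ...)
                },
            ],
        },
    },
}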
tpu_commons/models/jax/utils/quantization/quantization_utils.py

Lines changed: 76 additions & 0 deletions
@@ -7,11 +7,14 @@
 import jax
 import jax.numpy as jnp
 import qwix
+import qwix.pallas
 import yaml
 from flax import nnx
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
+from qwix._src.core.qarray import QArray
+from qwix._src.providers import ptq

 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -507,3 +510,76 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
     if hasattr(model, 'initialize_cache'):
         model.initialize_cache()
     logger.info("Done initializing Qwix-quantized model with random weights")
+
+
+def manually_quantize_qwix_weight(weight: jax.Array, qtype: jnp.dtype,
+                                  channelwise_axes: List[int],
+                                  tiled_axes: dict,
+                                  calibration_method: str) -> QArray:
+    """Manually quantizes a weight tensor using Qwix. Needed for the
+    SparseMatmul DeepSeek MoE case currently."""
+    # TODO (jacobplatin): clean this up; this is needed because of issues with Qwix quantizing the `shard_map` in SparseMatmul
+    how_to_quantize = ptq.qarray.HowToQuantize(
+        qtype=qtype,
+        channelwise_axes=channelwise_axes,
+        tiled_axes=tiled_axes,
+        calibration_method=calibration_method)
+
+    return ptq._create_quantized_param(weight, how_to_quantize)
+
+
+def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
+                                      qtype: jnp.dtype,
+                                      channelwise_axes: List[int],
+                                      tiled_axes: dict,
+                                      calibration_method: str) -> QArray:
+    """
+    Manually quantizes an activation tensor using Qwix. Needed for the SparseMatmul
+    DeepSeek MoE case currently.
+
+    Args:
+        inputs: The activation tensor to quantize.
+        rule_name: The name of the quantization rule to use.
+        qtype: The quantization type.
+        channelwise_axes: The channelwise axes to quantize.
+        tiled_axes: The tiled axes to quantize.
+        calibration_method: The calibration method to use.
+
+    Returns:
+        The quantized activation tensor.
+    """
+    rule = qwix.pallas.get_current_rule(rule_name)
+    lhs_how = ptq.qarray.HowToQuantize(qtype=qtype,
+                                       channelwise_axes=channelwise_axes,
+                                       tiled_axes=tiled_axes,
+                                       calibration_method=calibration_method)
+    assert not rule.act_static_scale, "Static scale not supported right now"
+
+    # channelwise_axes should be set to (a subset of) non-contraction axes. e.g.
+    # for ragged_dot [m, k] x [g, k, n], they are [0] and [0, 2]
+    # TODO (jacobplatin): add support for `act_name`
+    return ptq._quantize_act(inputs, lhs_how, rule, "")
+
+
+def get_quant_dtype_from_qwix_config(
+        vllm_config: "VllmConfig") -> tuple[jnp.dtype, jnp.dtype]:
+    """
+    Gets the quantization dtype from the Qwix config.
+
+    Args:
+        vllm_config: The VllmConfig object.
+
+    Returns:
+        A tuple of the scale dtype and quant dtype.
+    """
+    qwix_config = vllm_config.additional_config.get("quantization",
+                                                    {}).get("qwix", {})
+    scale_dtype = getattr(jnp, qwix_config.get("scale_dtype", "bfloat16"))
+    quant_dtype = None
+    # TODO (jacobplatin): this needs to be much more robust
+    for rule in qwix_config.get("rules", []):
+        if rule.get("module_path") == ".*":
+            quant_dtype_str = rule.get("weight_qtype", "")
+            assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
+            quant_dtype = getattr(jnp, quant_dtype_str)
+    return scale_dtype, quant_dtype

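A quick usage sketch for the new helper, assuming the config shape shown earlier; the `SimpleNamespace` stand-in is hypothetical and only mimics the single `additional_config` attribute the function reads (real callers pass a vLLM `VllmConfig`):

from types import SimpleNamespace

import jax.numpy as jnp

fake_vllm_config = SimpleNamespace(additional_config={
    "quantization": {
        "qwix": {
            "scale_dtype": "bfloat16",
            "rules": [{"module_path": ".*", "weight_qtype": "float8_e4m3fn"}],
        }
    }
})

# Resolves the dtype strings via getattr(jnp, ...) as in the function body above.
scale_dtype, quant_dtype = get_quant_dtype_from_qwix_config(fake_vllm_config)
assert scale_dtype == jnp.bfloat16 and quant_dtype == jnp.float8_e4m3fn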
tpu_commons/models/jax/utils/weight_utils.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ def get_model_weights_files(
     )

     weights_files.sort()
+    weights_files = weights_files[:4] + [weights_files[-4]]
     return weights_files
