4 changes: 3 additions & 1 deletion docker/Dockerfile
@@ -108,9 +108,11 @@ fi

# ====================================== Install main package ============================================

# TODO may improve
ARG SLIME_COMMIT=main
RUN git clone https://github.com/THUDM/slime.git /root/slime && \
cd /root/slime && \
git checkout ${SLIME_COMMIT} && \
pip install -e . --no-deps

RUN cd /root/slime/slime/backends/megatron_utils/kernels/int4_qat && \
pip install . --no-build-isolation
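
Note: pinning the slime checkout behind the SLIME_COMMIT build arg makes the image reproducible; the default of main still floats, so a release build would presumably pass a fixed SHA at build time, e.g. docker build --build-arg SLIME_COMMIT=<sha> ... (invocation illustrative).
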
137 changes: 128 additions & 9 deletions docker/patch/latest/sglang.patch
@@ -654,6 +654,19 @@ index a1885fade..14d692365 100644
moe_sum_reduce_torch_compile(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 839463518..7948779aa 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -647,7 +647,7 @@ class FusedMoE(torch.nn.Module):
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
]
- )
+ ) and "zero" not in weight_name
else loaded_weight
)

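Note: the added guard, and "zero" not in weight_name, keeps the WNA16 weight transformation from touching the new *_weight_zero_point tensors, which must be loaded verbatim. A condensed sketch of the resulting load-time selection (method names are from the hunk above; the helper name select_weight and the condensed transformed value are illustrative):

    # Sketch only: condensed from the conditional expression patched above.
    WNA16_METHODS = (
        "CompressedTensorsWNA16MarlinMoEMethod",
        "CompressedTensorsWNA16MoEMethod",
    )

    def select_weight(quant_method_name, weight_name, transformed, loaded_weight):
        # Zero-point tensors bypass the WNA16 transform and load as-is.
        if quant_method_name in WNA16_METHODS and "zero" not in weight_name:
            return transformed
        return loaded_weight
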
diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
index 00bd68755..5a3ca8a67 100644
--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -725,11 +738,87 @@ index 00bd68755..5a3ca8a67 100644
self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)

def get_routed_experts(
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
index b4bdc41b3..3b895ff6a 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -442,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig):
)
is_static = not weight_quant.dynamic

- return is_channel_group and input_quant_none and is_symmetric and is_static
+ return is_channel_group and input_quant_none and is_static

def _get_scheme_from_parts(
self, weight_quant: BaseModel, input_quant: BaseModel
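
Note: dropping is_symmetric from the conjunction widens this support check to asymmetric (zero-point) WNA16 schemes, while is_static still rejects dynamic weight quantization. This is the config-level switch that enables the zero-point handling added below.
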
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index c5e5a11fc..dd321fa13 100644
index c5e5a11fc..c46526ecc 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1016,13 +1016,37 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -30,7 +30,10 @@ from sglang.srt.layers.quantization.fp8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
-from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
+from sglang.srt.layers.quantization.marlin_utils import (
+ marlin_moe_permute_scales,
+ moe_awq_to_marlin_zero_points
+)
from sglang.srt.layers.quantization.utils import (
all_close_1d,
per_tensor_dequantize,
@@ -865,7 +868,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
- assert config.symmetric, "Only symmetric quantization is supported for MoE"
+ self.sym = config.symmetric

if not (
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
@@ -920,7 +923,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
- load_full_w2 = self.actorder and self.group_size != -1
+ load_full_w2 = (self.actorder != 'static') and self.group_size != -1

if load_full_w2:
w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
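
Note: actorder is string-valued in compressed-tensors configs (values such as "group", "weight", or "static"; the exact set is assumed here). The old truthiness test loaded full w2 scales whenever act-order was set at all; the new comparison also evaluates True when actorder is None, so every grouped scheme except static act-order now takes the full-w2 path. Worth confirming that this is intended for checkpoints without act-order.
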
@@ -968,6 +971,32 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)

+ # Add zero-point parameters for asymmetric quantization
+ if not self.sym:
+ w13_qzeros = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ num_groups_w13,
+ 2 * intermediate_size_per_partition // self.packed_factor,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_weight_zero_point", w13_qzeros)
+ set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+ w2_qzeros = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ num_groups_w2,
+ hidden_size // self.packed_factor,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_weight_zero_point", w2_qzeros)
+ set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
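
Note: the new w13_weight_zero_point / w2_weight_zero_point parameters mirror the group-quantization layout: one packed zero point per (expert, K-group), with the N dimension packing packed_factor values per int32. A quick shape sanity check under assumed example sizes (num_groups_w13 / num_groups_w2 are defined earlier in create_weights and not shown in this hunk; per-group quantization implies num_groups = K // group_size):

    # Illustrative shape check for the zero-point parameters created above.
    num_experts = 8
    hidden_size = 4096
    intermediate_size_per_partition = 1408
    group_size, num_bits = 128, 4
    packed_factor = 32 // num_bits  # 8 4-bit values per int32

    num_groups_w13 = hidden_size // group_size                     # w13: K = hidden_size
    num_groups_w2 = intermediate_size_per_partition // group_size  # w2: K = intermediate

    w13_zp_shape = (num_experts, num_groups_w13,
                    2 * intermediate_size_per_partition // packed_factor)
    w2_zp_shape = (num_experts, num_groups_w2, hidden_size // packed_factor)
    print(w13_zp_shape, w2_zp_shape)  # -> (8, 32, 352) (8, 11, 512)
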
@@ -1016,13 +1045,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.a2_scale = None
layer.marlin_state = GPTQMarlinState.REPACK

@@ -738,11 +827,14 @@ index c5e5a11fc..dd321fa13 100644
+
+ # Force record: these are the target GPTQ shapes for rollback.
+ layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape)
+ layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape)
+
+ # Also record the shapes of the scales (and zero points, if asymmetric).
+ layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape)
+ layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape)
+ if not self.sym:
+ layer._original_shapes["w13_weight_zero_point"] = tuple(w13_qzeros.shape)
+ layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape)
+
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ # Skip if the layer is already converted to Marlin format to prevent double-packing.
@@ -769,7 +861,7 @@ index c5e5a11fc..dd321fa13 100644
del new_t

num_experts = layer.w13_weight_g_idx.shape[0]
@@ -1078,7 +1102,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -1078,7 +1134,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_packed.shape[2],
self.num_bits,
)
@@ -778,7 +870,7 @@ index c5e5a11fc..dd321fa13 100644
marlin_w2_qweight = gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
@@ -1086,7 +1110,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -1086,7 +1142,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed.shape[2],
self.num_bits,
)
@@ -787,7 +879,7 @@ index c5e5a11fc..dd321fa13 100644
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
layer.w13_weight_scale,
@@ -1094,7 +1118,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -1094,7 +1150,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale.shape[2],
self.group_size,
)
@@ -796,13 +888,31 @@ index c5e5a11fc..dd321fa13 100644

marlin_w2_scales = marlin_moe_permute_scales(
layer.w2_weight_scale,
@@ -1103,7 +1127,22 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -1103,7 +1159,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale.shape[2],
self.group_size,
)
- replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
+ replace_tensor("w2_weight_scale", marlin_w2_scales)
+
+ # Repack zero
+ if not self.sym:
+ marlin_w13_zp = moe_awq_to_marlin_zero_points(
+ layer.w13_weight_zero_point,
+ size_k=layer.w13_weight_zero_point.shape[1],
+ size_n=layer.w13_weight_zero_point.shape[2] * self.packed_factor,
+ num_bits=self.num_bits,
+ )
+ replace_tensor("w13_weight_zero_point", marlin_w13_zp)
+
+ marlin_w2_zp = moe_awq_to_marlin_zero_points(
+ layer.w2_weight_zero_point,
+ size_k=layer.w2_weight_zero_point.shape[1],
+ size_n=layer.w2_weight_zero_point.shape[2] * self.packed_factor,
+ num_bits=self.num_bits,
+ )
+ replace_tensor("w2_weight_zero_point", marlin_w2_zp)
+
+ layer.is_marlin_converted = True
+
+ def restore_weights_before_loading(self, layer: torch.nn.Module):
@@ -820,6 +930,15 @@ index c5e5a11fc..dd321fa13 100644

def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
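
Note: is_marlin_converted, _original_shapes, and restore_weights_before_loading together implement a repack-once/rollback cycle so repeated weight reloads (e.g. RL weight sync) work: repacking is skipped when the layer is already in Marlin format, and the recorded GPTQ shapes allow resetting the layer before a fresh load. A condensed sketch of the pattern (the restore body is collapsed in this view, so its internals are assumed):

    # Pattern sketch only; method names from the patch, bodies condensed.
    def process_weights_after_loading(layer):
        if getattr(layer, "is_marlin_converted", False):
            return  # already Marlin-packed; avoid double-packing
        # ... GPTQ -> Marlin repack of weights, scales, and zero points ...
        layer.is_marlin_converted = True

    def restore_weights_before_loading(layer):
        # Assumed: re-create GPTQ-shaped buffers recorded in _original_shapes
        # so a fresh checkpoint can be loaded, then permit repacking again.
        for name, shape in layer._original_shapes.items():
            ...  # reallocate the parameter with its original shape
        layer.is_marlin_converted = False
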
@@ -1154,6 +1243,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
+ w1_zeros=layer.w13_weight_zero_point if not self.sym else None,
+ w2_zeros=layer.w2_weight_zero_point if not self.sym else None,
num_bits=self.num_bits,
is_k_full=self.is_k_full,
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
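
Note: the final hunk threads the repacked zero points into the fused Marlin MoE invocation as w1_zeros / w2_zeros; symmetric checkpoints pass None and keep the existing behavior, so the asymmetric path is strictly additive.
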
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 480579e01..dd8ca7d4f 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
2 changes: 1 addition & 1 deletion docker/version.txt
@@ -1 +1 @@
-nightly-dev-20260119a
+nightly-dev-20260120a