diff --git a/docker/Dockerfile b/docker/Dockerfile index 98c4e85c8..113796997 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -108,9 +108,11 @@ fi # ====================================== Install main package ============================================ -# TODO may improve ARG SLIME_COMMIT=main RUN git clone https://github.com/THUDM/slime.git /root/slime && \ cd /root/slime && \ git checkout ${SLIME_COMMIT} && \ pip install -e . --no-deps + +RUN cd /root/slime/slime/backends/megatron_utils/kernels/int4_qat && \ + pip install . --no-build-isolation diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index 3440f7128..1b99c73e6 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -654,6 +654,19 @@ index a1885fade..14d692365 100644 moe_sum_reduce_torch_compile( intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], +diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +index 839463518..7948779aa 100644 +--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py ++++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +@@ -647,7 +647,7 @@ class FusedMoE(torch.nn.Module): + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ] +- ) ++ ) and "zero" not in weight_name + else loaded_weight + ) + diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py index 00bd68755..5a3ca8a67 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -725,11 +738,87 @@ index 00bd68755..5a3ca8a67 100644 self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) def get_routed_experts( +diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +index b4bdc41b3..3b895ff6a 100644 +--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py ++++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +@@ -442,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig): + ) + is_static = not weight_quant.dynamic + +- return is_channel_group and input_quant_none and is_symmetric and is_static ++ return is_channel_group and input_quant_none and is_static + + def _get_scheme_from_parts( + self, weight_quant: BaseModel, input_quant: BaseModel diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py -index c5e5a11fc..dd321fa13 100644 +index c5e5a11fc..c46526ecc 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py -@@ -1016,13 +1016,37 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): +@@ -30,7 +30,10 @@ from sglang.srt.layers.quantization.fp8_utils import ( + normalize_e4m3fn_to_e4m3fnuz, + ) + from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack +-from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales ++from sglang.srt.layers.quantization.marlin_utils import ( ++ marlin_moe_permute_scales, ++ moe_awq_to_marlin_zero_points ++) + from sglang.srt.layers.quantization.utils import ( + all_close_1d, + per_tensor_dequantize, +@@ -865,7 
+868,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder +- assert config.symmetric, "Only symmetric quantization is supported for MoE" ++ self.sym = config.symmetric + + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value +@@ -920,7 +923,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + + # In the case where we have actorder/g_idx, + # we do not partition the w2 scales +- load_full_w2 = self.actorder and self.group_size != -1 ++ load_full_w2 = (self.actorder != 'static') and self.group_size != -1 + + if load_full_w2: + w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size +@@ -968,6 +971,32 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + ++ # add zero param ++ if not self.sym: ++ w13_qzeros = torch.nn.Parameter( ++ torch.empty( ++ num_experts, ++ num_groups_w13, ++ 2 * intermediate_size_per_partition // self.packed_factor, ++ dtype=torch.int32, ++ ), ++ requires_grad=False, ++ ) ++ layer.register_parameter("w13_weight_zero_point", w13_qzeros) ++ set_weight_attrs(w13_qzeros, extra_weight_attrs) ++ ++ w2_qzeros = torch.nn.Parameter( ++ torch.empty( ++ num_experts, ++ num_groups_w2, ++ hidden_size // self.packed_factor, ++ dtype=torch.int32, ++ ), ++ requires_grad=False, ++ ) ++ layer.register_parameter("w2_weight_zero_point", w2_qzeros) ++ set_weight_attrs(w2_qzeros, extra_weight_attrs) ++ + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, +@@ -1016,13 +1045,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.a2_scale = None layer.marlin_state = GPTQMarlinState.REPACK @@ -738,11 +827,14 @@ index c5e5a11fc..dd321fa13 100644 + + # Force record: these are the target GPTQ shapes for rollback. + layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape) -+ layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) ++ layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) ++ if not self.sym: ++ layer._original_shapes["w13_weight_zero_point"] = w13_qzeros.shape + -+ # Also record the shapes of the scales. ++ layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) + layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) -+ layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) ++ if not self.sym: ++ layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Skip if the layer is already converted to Marlin format to prevent double-packing. 
@@ -769,7 +861,7 @@ index c5e5a11fc..dd321fa13 100644 del new_t num_experts = layer.w13_weight_g_idx.shape[0] -@@ -1078,7 +1102,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): +@@ -1078,7 +1134,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight_packed.shape[2], self.num_bits, ) @@ -778,7 +870,7 @@ index c5e5a11fc..dd321fa13 100644 marlin_w2_qweight = gptq_marlin_moe_repack( layer.w2_weight_packed, layer.w2_g_idx_sort_indices, -@@ -1086,7 +1110,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): +@@ -1086,7 +1142,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight_packed.shape[2], self.num_bits, ) @@ -787,7 +879,7 @@ index c5e5a11fc..dd321fa13 100644 # Repack scales marlin_w13_scales = marlin_moe_permute_scales( layer.w13_weight_scale, -@@ -1094,7 +1118,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): +@@ -1094,7 +1150,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight_scale.shape[2], self.group_size, ) @@ -796,13 +888,31 @@ index c5e5a11fc..dd321fa13 100644 marlin_w2_scales = marlin_moe_permute_scales( layer.w2_weight_scale, -@@ -1103,7 +1127,22 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): +@@ -1103,7 +1159,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight_scale.shape[2], self.group_size, ) - replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) + replace_tensor("w2_weight_scale", marlin_w2_scales) + ++ # Repack zero ++ if not self.sym: ++ marlin_w13_zp = moe_awq_to_marlin_zero_points( ++ layer.w13_weight_zero_point, ++ size_k=layer.w13_weight_zero_point.shape[1], ++ size_n=layer.w13_weight_zero_point.shape[2] * self.packed_factor, ++ num_bits=self.num_bits, ++ ) ++ replace_tensor("w13_weight_zero_point", marlin_w13_zp) ++ ++ marlin_w2_zp = moe_awq_to_marlin_zero_points( ++ layer.w2_weight_zero_point, ++ size_k=layer.w2_weight_zero_point.shape[1], ++ size_n=layer.w2_weight_zero_point.shape[2] * self.packed_factor, ++ num_bits=self.num_bits, ++ ) ++ replace_tensor("w2_weight_zero_point", marlin_w2_zp) ++ + layer.is_marlin_converted = True + + def restore_weights_before_loading(self, layer: torch.nn.Module): @@ -820,6 +930,15 @@ index c5e5a11fc..dd321fa13 100644 def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig +@@ -1154,6 +1243,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, ++ w1_zeros=layer.w13_weight_zero_point if not self.sym else None, ++ w2_zeros=layer.w2_weight_zero_point if not self.sym else None, + num_bits=self.num_bits, + is_k_full=self.is_k_full, + routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 480579e01..dd8ca7d4f 100644 --- a/python/sglang/srt/layers/rotary_embedding.py diff --git a/docker/version.txt b/docker/version.txt index d9fab5e2c..bb1b580ab 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20260119a +nightly-dev-20260120a diff --git a/slime/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py b/slime/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py index c1109743f..026c4424c 100644 --- 
a/slime/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py +++ b/slime/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py @@ -1,9 +1,10 @@ import logging import math import re -from typing import Literal +import fake_int4_quant_cuda import torch +import torch.nn as nn logger = logging.getLogger(__name__) @@ -11,32 +12,130 @@ __all__ = ["quantize_params_compressed_tensors"] +class WQLinear_GEMM(nn.Module): + def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False): + super().__init__() + + if w_bit not in [4]: + raise NotImplementedError("Only 4-bit are supported for now.") + + self.in_features = in_features + self.out_features = out_features + self.w_bit = w_bit + self.group_size = group_size if group_size != -1 else in_features + self.training = training + + # quick sanity check (make sure alignment) + assert self.in_features % self.group_size == 0 + assert out_features % (32 // self.w_bit) == 0 + + self.register_buffer( + "qweight", + torch.zeros( + (in_features, out_features // (32 // self.w_bit)), + dtype=torch.int32, + device=dev, + ), + ) + self.register_buffer( + "qzeros", + torch.zeros( + (in_features // self.group_size, out_features // (32 // self.w_bit)), + dtype=torch.int32, + device=dev, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (in_features // self.group_size, out_features), + dtype=torch.float16, + device=dev, + ), + ) + if bias: + self.register_buffer( + "bias", + torch.zeros( + (out_features), + dtype=torch.float16, + device=dev, + ), + ) + else: + self.bias = None + + @classmethod + def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None): + awq_linear = cls( + w_bit, + group_size, + linear.in_features, + linear.out_features, + linear.bias is not None, + linear.weight.device, + ) + if init_only: # just prepare for loading sd + return awq_linear + + # need scales and zeros info for real quantization + assert scales is not None and zeros is not None + + awq_linear.scales = scales.clone().half() + if linear.bias is not None: + awq_linear.bias = linear.bias.clone().half() + + pack_num = 32 // awq_linear.w_bit + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + repeat_scales = scales.to(device).t().repeat_interleave(group_size, 1) + if isinstance(zeros, torch.Tensor): + repeat_zeros = zeros.to(device).t().repeat_interleave(group_size, 1) + else: + repeat_zeros = zeros + intweight = torch.round(linear.weight.to(device) / repeat_scales + repeat_zeros).to(torch.int).t().contiguous() + intweight = intweight.to(dtype=torch.int32) + del repeat_scales + + intweight = intweight.reshape(-1, intweight.shape[1] // pack_num, pack_num) + + new_order_map = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], device=device) * awq_linear.w_bit + intweight = intweight << new_order_map + intweight = torch.sum(intweight, dim=-1).to(torch.int32) + awq_linear.qweight = intweight + + if isinstance(zeros, torch.Tensor): + zeros = zeros.to(dtype=torch.int32, device=device) + zeros = zeros.reshape(-1, zeros.shape[1] // pack_num, pack_num) + zeros = zeros << new_order_map + qzeros = torch.sum(zeros, dim=-1).to(torch.int32) + + else: + value = 0 + for i in range(pack_num): + value |= zeros << (i * awq_linear.w_bit) + qzeros = ( + torch.ones( + (scales.shape[0], scales.shape[1] // pack_num), + dtype=torch.int32, + device=device, + ) + * value + ) + + awq_linear.qzeros = qzeros + + return awq_linear + + def pack_to_int32( - value: torch.Tensor, - num_bits: 
int, - packed_dim: Literal[0] | Literal[1] = 1, -) -> torch.Tensor: - """ - Packs a tensor of quantized weights stored in int8 into int32s with padding - - Pseudocode: - 1. Shift wrt num_bits to convert to unsigned. num_bits=8 - [1,2] -> [129, 130] - 2. Pad to fill in 32 bits - [129, 130] -> [129, 130, 0, 0] - 3. convert to binary align in order - [129, 130, 0, 0] -> 00000000 00000000 10000010 10000001 - 4. convert aligned binary to number - 00000000000000001000001010000001 -> 33409 - 5. covert back to uint32 - 33409 -> 33409 - - :param value: tensor to pack - :param num_bits: number of bits used to store underlying data, must be at least 1 - :returns: packed int32 tensor - """ - if value.dtype is not torch.int8: - raise ValueError("Tensor must be quantized to torch.int8 before packing") + value, + num_bits, + packed_dim=1, + sym=False, +): + # if value.dtype is not torch.int8: + # raise ValueError("Tensor must be quantized to torch.int8 before packing") if num_bits > 8: raise ValueError("Packing is only supported for less than 8 bits") @@ -45,8 +144,9 @@ def pack_to_int32( raise ValueError(f"num_bits must be at least 1, got {num_bits}") # Convert to unsigned range for packing, matching quantization offset - offset = 1 << (num_bits - 1) - value = (value + offset).to(torch.uint8) + if sym: + offset = 1 << (num_bits - 1) + value = (value + offset).to(torch.uint8) device = value.device pack_factor = 32 // num_bits @@ -74,60 +174,89 @@ def pack_to_int32( return packed -def pack_int4_to_int32(q_weight: torch.Tensor) -> torch.Tensor: - """ - pack int4 to int32 - Args: - q_weight: [N, K] tensor, dtype=int8 or uint8 - Returns: - packed: [N, K // 8] tensor, dtype=int32 - """ - return pack_to_int32(q_weight, 4, -1) - - -def int4_block_quantize(x: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - De-quantized = Scale * Quantized (Zero Point is always 0) - """ - N, K = x.shape - if group_size == -1: - group_size = K - - # Padding - if K % group_size != 0: - import torch.nn.functional as F - - x = F.pad(x, (0, group_size - (K % group_size))) - N, K = x.shape - - num_groups = K // group_size - x_reshaped = x.float().view(N, num_groups, group_size) - - # ========================================================= - # 1. Scale - # Range: [-7, 7] -> dividing by 7.0 - # ========================================================= - x_abs_max = x_reshaped.abs().amax(dim=-1, keepdim=True) - scale = x_abs_max / 7.0 - scale = scale.clamp(min=1e-5) - - # ========================================================= - # 2. Quantize - # ========================================================= - x_int_sym = (x_reshaped / scale).round().clamp(-8, 7) - - out = x_int_sym.to(torch.int8) - - # ========================================================= - # 3. 
Zero Point - # ========================================================= - zero_point = torch.zeros_like(scale) - out = out.view(N, K) - - scale_out = scale.squeeze(-1).contiguous() - zero_out = zero_point.squeeze(-1).contiguous() - - return out, scale_out, zero_out +def round_to_quantized_type_dtype( + tensor, + dtype, + cast_to_original_dtype=False, +): + original_dtype = tensor.dtype + iinfo = torch.iinfo(dtype) + rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max)).to(dtype) + if cast_to_original_dtype: + return rounded.to(original_dtype) + return rounded + + +@torch.no_grad() +def quantize( + x, + scale, + zero_point, + dtype=torch.int8, +): + group_size = x.shape[-1] // scale.shape[-1] + output_dtype = dtype + output = torch.zeros_like(x).to(output_dtype) + + reshaped_dims = ( + math.ceil(x.shape[-1] / group_size), + group_size, + ) + x = x.unflatten(-1, reshaped_dims) + + scaled = x / scale.unsqueeze(-1) + + if zero_point is not None: + zero_point = zero_point.unsqueeze(-1) + scaled += zero_point.to(x.dtype) + + # clamp and round + output = round_to_quantized_type_dtype(tensor=scaled, dtype=dtype) + + output = output.flatten(start_dim=-2) + output = output.to(output_dtype) + + return output + + +def if_quant(name, patterns): + for pattern in patterns: + if re.search(pattern, name): + return True + return False + + +def pack_layer(weight, group_size, sym=True): + w, scale, zp = fake_int4_quant_cuda.fake_int4_quant_cuda(weight, (1, group_size), sym) + w = w.view(weight.shape[0], 1, weight.shape[1] // group_size, group_size) + scale = scale.view(weight.shape[0], 1, weight.shape[1] // group_size, 1) + zp = zp.view(weight.shape[0], 1, weight.shape[1] // group_size, 1) + if sym: + w = w * scale + else: + w = (w - zp) * scale + w = w.view(weight.shape) + scale = scale.view(weight.shape[0], -1).contiguous() + if not sym: + zp = zp.view(weight.shape[0], -1) + zeros = zp.t().contiguous().to(torch.float32) + zeros = zeros.to(dtype=torch.int32, device=w.device) + zeros = zeros.reshape(-1, zeros.shape[1] // 8, 8) + new_order_map = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], device=zeros.device) * 4 + zeros = zeros << new_order_map + packed_zp = torch.sum(zeros, dim=-1).to(torch.int32) + else: + zp = None + packed_zp = None + + quantized_weight = quantize( + x=w, + scale=scale, + zero_point=zp, + dtype=torch.int8 if sym else torch.uint8, + ) + packed_weight = pack_to_int32(quantized_weight, 4, sym=sym) + return packed_weight, scale, packed_zp def quantize_params_compressed_tensors(converted_named_params, quantization_config): @@ -147,45 +276,16 @@ def quantize_params_compressed_tensors(converted_named_params, quantization_conf results.append((name, param)) continue - input_tensor = param.view(-1, param.shape[-1]) if param.dim() > 2 else param - - if group_size != -1 and input_tensor.shape[-1] < group_size: - logger.warning(f"Skipping {name}, K-dim {input_tensor.shape[-1]} < group_size") - results.append((name, param)) - continue - - results.extend(_quantize_param_int4(name, input_tensor, group_size, param.shape, is_symmetric)) # origin shape + qw, s, zp = pack_layer(param, group_size, is_symmetric) + qweight_name = name.replace(".weight", ".weight_packed") + scale_name = name.replace(".weight", ".weight_scale") + weight_shape = torch.tensor(param.shape, dtype=torch.int32, device="cuda") + weight_shape_name = name.replace(".weight", ".weight_shape") + if zp is not None: + zp_name = name.replace(".weight", ".weight_zero_point") + results.append((zp_name, zp)) + results.append((qweight_name, 
qw)) + results.append((scale_name, s)) + results.append((weight_shape_name, weight_shape)) return results - - -def _quantize_param_int4(name: str, weight: torch.Tensor, group_size: int, shape: torch.Tensor, is_symmetric: bool): - """ - Wraps the quantization function, handles renaming and packing. - """ - base_name = name.replace(".weight", "") - - new_base_name = base_name - - original_dtype = weight.dtype - - if group_size == -1: - group_size = weight.shape[1] - elif weight.shape[1] % group_size != 0: - logger.warning( - f"Weight {name} with shape {weight.shape} has K-dimension " - f"not divisible by group_size {group_size}. Skipping." - ) - return [(name, weight.to(original_dtype))] - - q_weight, scales, zeros = int4_block_quantize(weight, group_size) - - packed_q_weight = pack_int4_to_int32(q_weight) - - qweight_name = f"{new_base_name}.weight_packed" - scales_name = f"{new_base_name}.weight_scale" - qweight_shape = f"{new_base_name}.weight_shape" - - q_shape = torch.tensor(shape, dtype=torch.int32, device="cuda") - - return [(qweight_name, packed_q_weight), (scales_name, scales.to(original_dtype)), (qweight_shape, q_shape)]
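Note on the asymmetric (sym=False) path above: both `pack_layer` in the quantizer and `WQLinear_GEMM.from_linear` pack per-group zero points into int32 using the same [0, 4, 1, 5, 2, 6, 3, 7] nibble order before the sglang side unpacks them via `moe_awq_to_marlin_zero_points`. The snippet below is a minimal pure-PyTorch sketch of that packing step for reference and unit testing only; the function name `pack_awq_zero_points` is ours, and the real pipeline produces the zero points through the `fake_int4_quant_cuda` extension rather than this reference code.

import torch

def pack_awq_zero_points(zp: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    # zp: [num_groups, out_features] integer zero points in [0, 2**num_bits - 1]
    pack_factor = 32 // num_bits  # 8 nibbles per int32 for 4-bit
    zp = zp.to(torch.int32)
    zp = zp.reshape(zp.shape[0], zp.shape[1] // pack_factor, pack_factor)
    # AWQ stores the 4-bit values in the order 0,4,1,5,2,6,3,7 inside each int32
    shifts = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], device=zp.device) * num_bits
    return (zp << shifts).sum(dim=-1).to(torch.int32)

For example, a [num_groups, out_features] tensor of 4-bit zero points becomes a [num_groups, out_features // 8] int32 tensor, which matches the trailing `// self.packed_factor` dimension of the w13_weight_zero_point / w2_weight_zero_point parameters registered in CompressedTensorsWNA16MoEMethod.create_weights.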