
Commit 1ce4114

cleanup
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 3ad0656 commit 1ce4114

2 files changed (+13, -23 lines)


vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 10 additions & 16 deletions
@@ -518,7 +518,7 @@ def __init__(
             assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
             logger.debug("WQ = %s", str(self.weight_quant))
             self.weight_block_size = self.weight_quant.block_structure
-            # TODO: self.weight_quant.dynamic
+            assert self.weight_quant.dynamic is not None
         else:
             self.weight_block_size = None
         self.block_quant = self.weight_block_size is not None
@@ -550,15 +550,9 @@ def __init__(
                                 or self.is_fp8_w8a8_sm100)
         self.disable_expert_map = False
 
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
 
         layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.hidden_size = hidden_size
@@ -668,6 +662,8 @@ def create_weights(
                 {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
 
         # INPUT_SCALES
         if self.static_input_scales:
@@ -690,7 +686,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
-            # TODO(bnell): Is this assert right?
             assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
@@ -793,7 +788,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                        device=device,
                                        dtype=torch.int64)
 
-        # XXXXXXXXXXXXX
         if is_deep_gemm_e8m0_used() and self.block_quant:
             assert layer.weight_block_size is not None
             # Re-quantise the expert weights so their scales are UE8M0.
@@ -810,11 +804,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             )
 
             # Ensure column-major TMA alignment expected by DeepGEMM.
-            if _is_col_major(layer.w13_weight_scale):
-                layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
+            if _is_col_major(layer.w13_weight_scale_inv):
+                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                     layer.w13_weight_scale)
-            if _is_col_major(layer.w2_weight_scale):
-                layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
+            if _is_col_major(layer.w2_weight_scale_inv):
+                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                     layer.w2_weight_scale)
 
     def maybe_make_prepare_finalize(

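Note on the added register_parameter lines in the -668 hunk above: a minimal sketch, assuming the w13_weight_scale tensor is also registered under its original name elsewhere in create_weights (not shown in this hunk), of how registering one torch.nn.Parameter under an extra "_inv"-suffixed name simply aliases the same storage, so a loader looking for either name finds the same tensor. The ToyLayer module and shapes below are hypothetical, not vLLM code.

import torch

# Hypothetical toy module: register the same Parameter under two names
# (a sketch of the aliasing effect, not the vLLM implementation).
class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        scale = torch.nn.Parameter(torch.ones(4, 4), requires_grad=False)
        self.register_parameter("w13_weight_scale", scale)
        # Second registration of the *same* object under the "_inv" alias.
        self.register_parameter("w13_weight_scale_inv", scale)

layer = ToyLayer()
# Both attribute names resolve to the same tensor, so an in-place update
# through one name is visible through the other.
assert layer.w13_weight_scale is layer.w13_weight_scale_inv
layer.w13_weight_scale.data.fill_(2.0)
assert bool((layer.w13_weight_scale_inv == 2.0).all())
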
vllm/model_executor/warmup/deep_gemm_warmup.py

Lines changed: 3 additions & 7 deletions
@@ -75,13 +75,9 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
             and module.quant_method.block_quant):
        return False
 
-    try:
-        w, _, block_sizes = _extract_data_from_linear_base_module(module)
-        return (block_sizes == deep_gemm_block_shape() and w.ndim == 2
-                and w.shape[0] % block_size == 0
-                and w.shape[1] % block_size == 0)
-    except Exception:
-        return False
+    w, _, block_sizes = _extract_data_from_linear_base_module(module)
+    return (block_sizes == deep_gemm_block_shape() and w.ndim == 2
+            and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0)
 
 
 def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:

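As a standalone illustration of the divisibility check kept by the hunk above: the predicate passes only when the weight's block sizes match the expected block shape and both weight dimensions are multiples of the block size. This is a hedged sketch; the [128, 128] block shape and the may_use_deep_gemm name are assumptions for illustration, not values read from deep_gemm_block_shape().

import torch

# Assumed block shape for the sketch; not taken from vLLM's deep_gemm_block_shape().
DEEP_GEMM_BLOCK_SHAPE = [128, 128]

def may_use_deep_gemm(w: torch.Tensor, block_sizes: list[int]) -> bool:
    # Both weight dimensions must be whole multiples of the block size,
    # and the layer's block sizes must match the expected block shape.
    block_size = DEEP_GEMM_BLOCK_SHAPE[0]
    return (block_sizes == DEEP_GEMM_BLOCK_SHAPE and w.ndim == 2
            and w.shape[0] % block_size == 0
            and w.shape[1] % block_size == 0)

print(may_use_deep_gemm(torch.empty(1024, 4096), [128, 128]))  # True
print(may_use_deep_gemm(torch.empty(1000, 4096), [128, 128]))  # False: 1000 % 128 != 0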