diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py
index e456bfab83d3..9d6b25da7e6d 100644
--- a/tests/weight_loading/test_weight_loading.py
+++ b/tests/weight_loading/test_weight_loading.py
@@ -12,7 +12,7 @@
                             "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
 REVISION = os.environ.get("REVISION", "main")
 QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
-MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")
+MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80")
 
 
 @pytest.mark.skipif(
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index a43b2e597c1e..de4009d7d04a 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -17,6 +17,7 @@
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
@@ -134,7 +135,12 @@ def get_quant_method(self, layer: torch.nn.Module,
                     self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            return AWQMoEMethod(self)
+            if layer.num_experts > 32:
+                # For MoEs with many experts the moe_wna16 kernel is faster
+                return MoeWNA16Config.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
+            else:
+                return AWQMoEMethod(self)
         return None
 
     @classmethod
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 0a9d86b008db..f421dbd2ce2b 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -10,20 +10,18 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, marlin_moe_permute_scales,
     marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    UnquantizedEmbeddingMethod)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -44,15 +42,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         (8, True): scalar_types.uint8b128,
     }
 
-    def __init__(
-        self,
-        weight_bits: int,
-        group_size: int,
-        desc_act: bool,
-        is_sym: bool,
-        lm_head_quantized: bool,
-        dynamic: Dict[str, Dict[str, Union[int, bool]]],
-    ) -> None:
+    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
+                 is_sym: bool, lm_head_quantized: bool,
+                 dynamic: Dict[str, Dict[str, Union[int, bool]]],
+                 full_config: Dict[str, Any]) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -90,6 +83,7 @@ def __init__(
         self.group_size = group_size
         self.desc_act = desc_act
         self.lm_head_quantized = lm_head_quantized
+        self.full_config = full_config
 
         if (weight_bits, is_sym) not in self.TYPE_MAP:
             raise ValueError("Unsupported quantization config: "
@@ -132,7 +126,7 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
         return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized, dynamic)
+                   lm_head_quantized, dynamic, config)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -155,12 +149,15 @@ def override_quantization_method(cls, hf_quant_cfg,
                         " faster inference")
             return None
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
-                        UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
+            if layer.num_experts > 32:
+                # For MoEs with many experts the moe_wna16 kernel is faster
+                return MoeWNA16Config.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
+            else:
+                return GPTQMarlinMoEMethod(self)
         return get_linear_quant_method(self, layer, prefix,
                                        GPTQMarlinLinearMethod)
 
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index b9460e7d7985..30eb04698d81 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -9,13 +9,8 @@
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supports_layer)
 from vllm.model_executor.utils import set_weight_attrs
@@ -37,6 +32,12 @@ def __init__(self, linear_quant_method: str, weight_bits: int,
         self.linear_quant_method = linear_quant_method
         self.full_config = full_config
         self.use_marlin = False
+        # Avoid circular import
+        from vllm.model_executor.layers.quantization.awq import AWQConfig
+        from vllm.model_executor.layers.quantization.awq_marlin import (
+            AWQMarlinConfig)
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinConfig)
         if self.linear_quant_method == "gptq":
             self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
                 full_config)
@@ -115,6 +116,8 @@ def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
         capability_tuple = current_platform.get_device_capability()
         device_capability = (-1 if capability_tuple is None else
                              capability_tuple.to_int())
+        # Avoid circular import
+        from vllm.model_executor.layers.quantization.awq import AWQConfig
         awq_min_capability = AWQConfig.get_min_capability()
 
         gptq_compatible = quant_method == "gptq" and \
@@ -129,6 +132,13 @@ def get_quant_method(self, layer: torch.nn.Module,
         if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
             return UnquantizedLinearMethod()
         elif isinstance(layer, LinearBase):
+            # Avoid circular import
+            from vllm.model_executor.layers.quantization.awq import AWQConfig
+            from vllm.model_executor.layers.quantization.awq_marlin import (
+                AWQMarlinConfig)
+            from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+            from vllm.model_executor.layers.quantization.gptq_marlin import (
+                GPTQMarlinConfig)
             if self.linear_quant_method == "gptq":
                 if self.use_marlin:
                     return GPTQMarlinConfig.from_config(