diff --git a/vllm/envs.py b/vllm/envs.py
index 089a39d..4f710b7 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -51,7 +51,7 @@
     CMAKE_BUILD_TYPE: Optional[str] = None
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
-
+    VLLM_TEST_FORCE_FP8_MARLIN: bool = False
 
 def get_default_cache_root():
     return os.getenv(
@@ -341,6 +341,12 @@ def get_default_config_root():
     lambda:
     (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
      ("1", "true")),
+    # If set, forces FP8 Marlin to be used for FP8 quantization regardless
+    # of the hardware support for FP8 compute.
+    "VLLM_TEST_FORCE_FP8_MARLIN":
+    lambda:
+    (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
+     ("1", "true")),
 }
 
 # end-env-vars-definition
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index c829cb8..ce76bf3 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -4,6 +4,7 @@
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
@@ -118,7 +119,7 @@ def __init__(self, quant_config: Fp8Config):
         # kernel for fast weight-only FP8 quantization
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
-        self.use_marlin = capability < 89
+        self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
 
     def create_weights(
         self,
@@ -173,6 +174,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                          scale=None)
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                assert weight_scale.numel() == 1
+                weight_scale = convert_to_channelwise(
+                    weight_scale.expand(len(layer.logical_widths)),
+                    layer.logical_widths)
 
             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
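
A note on how the new registry entry is consumed: fp8.py reads the flag as
`envs.VLLM_TEST_FORCE_FP8_MARLIN`, a module attribute. A minimal sketch of
that pattern, assuming vllm/envs.py resolves attributes through a PEP 562
module-level `__getattr__` (a simplification, not the verbatim vLLM code):

import os
from typing import Any, Callable, Dict

# Registry entry: variable name -> zero-arg parser, mirroring the diff above.
environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_TEST_FORCE_FP8_MARLIN":
    lambda:
    (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
     ("1", "true")),
}

# PEP 562 module-level __getattr__: the lambda runs on each attribute
# access, so the environment is re-read at lookup time, not import time.
def __getattr__(name: str) -> Any:
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")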
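
The gating logic in `__init__` folds the (major, minor) compute capability
into a two-digit integer and compares it against 89, i.e. SM 8.9 (Ada), the
first architecture with FP8 tensor cores. A standalone sketch of that
decision, with a hypothetical `decide_use_marlin` helper standing in for the
inline check:

import os

def decide_use_marlin(major: int, minor: int) -> bool:
    # Fold (major, minor) into a two-digit integer, as the diff does.
    capability = major * 10 + minor
    # Same truthy parsing as the VLLM_TEST_FORCE_FP8_MARLIN entry in envs.py.
    force = (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower()
             in ("1", "true"))
    # Marlin is the fallback below SM 8.9 (no FP8 compute), or when forced.
    return capability < 89 or force

assert decide_use_marlin(8, 0)          # A100 (SM 8.0): falls back to Marlin
assert not decide_use_marlin(8, 9)      # L4/L40 (SM 8.9): native FP8 path
os.environ["VLLM_TEST_FORCE_FP8_MARLIN"] = "1"
assert decide_use_marlin(9, 0)          # H100 (SM 9.0): forced onto Marlin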
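
For the Marlin (w8a16) path, the last hunk turns the single per-tensor scale
returned by `scaled_fp8_quant` into per-output-channel scales: the scalar is
first expanded to one entry per logical shard (the fused sub-matrices, e.g.
Q/K/V), then `convert_to_channelwise` repeats each shard's scale across that
shard's output channels. A simplified stand-in for illustration only
(`to_channelwise` is hypothetical and the shard widths are made up):

import torch

def to_channelwise(per_shard_scales: torch.Tensor,
                   logical_widths: list) -> torch.Tensor:
    # One scale per fused shard -> one scale per output channel,
    # repeating each shard's scale across that shard's rows.
    return torch.cat([
        scale.repeat(width)
        for scale, width in zip(per_shard_scales, logical_widths)
    ])

scale = torch.tensor([0.02])            # single per-tensor scale
shard_scales = scale.expand(3)          # one copy per shard, as in the diff
channelwise = to_channelwise(shard_scales, [4096, 1024, 1024])
print(channelwise.shape)                # torch.Size([6144])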