diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml new file mode 100644 index 0000000000000..e40f42a17c18f --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.752 + - name: "exact_match,flexible-extract" + value: 0.752 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml index 02668702b83af..7a89e8e0c76f2 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml @@ -1,4 +1,4 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1 +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1 model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8" tasks: - name: "gsm8k" diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml new file mode 100644 index 0000000000000..bc29002985969 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.728 + - name: "exact_match,flexible-extract" + value: 0.728 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 273c5482db264..3300ca64f44b8 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,2 +1,4 @@ Meta-Llama-3-8B-Instruct.yaml Meta-Llama-3-8B-Instruct-FP8.yaml +Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index a2876bade8893..933733e9c1edf 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \ + --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \ --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ --batch_size $BATCH_SIZE diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 975841dad1c29..7fdce7b53bd7f 100644 --- 
a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -24,7 +24,8 @@ def launch_lm_eval(eval_config): model_args = f"pretrained={eval_config['model_name']}," \ - f"tensor_parallel_size={TP_SIZE}" + f"tensor_parallel_size={TP_SIZE}," \ + f"add_bos_token=true" results = lm_eval.simple_evaluate( model="vllm", diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 4cdda97dc728d..96223a247657b 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -9,7 +9,8 @@ from vllm import SamplingParams from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24, - CompressedTensorsW8A8, CompressedTensorsWNA16) + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationType) @@ -37,12 +38,11 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): CompressedTensorsLinearMethod) assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8) assert qkv_proj.scheme.strategy == strategy assert qkv_proj.scheme.is_static_input_scheme - expected_type = (torch.int8 if quant_type == QuantizationType.INT else - torch.float8_e4m3fn) + expected_type = torch.int8 assert qkv_proj.weight.dtype is expected_type assert o_proj.weight.dtype is expected_type @@ -79,7 +79,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args): qkv_proj = layer.self_attn.qkv_proj assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8) assert not qkv_proj.scheme.is_static_input_scheme assert qkv_proj.scheme.strategy == strategy assert qkv_proj.weight.dtype is torch.int8 @@ -123,3 +123,25 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner): sampling_params = SamplingParams() output = llm.generate("Hello world!", sampling_params=sampling_params) assert output + + +def test_compressed_tensors_fp8(vllm_runner): + model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test" + with vllm_runner(model_path) as llm: + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) + assert qkv_proj.weight.dtype is torch.float8_e4m3fn + assert qkv_proj.input_scale.dtype is torch.float32 + assert qkv_proj.weight_scale.dtype is torch.float32 + # should be scalars after processing + assert len(qkv_proj.input_scale.shape) == 0 + assert len(qkv_proj.weight_scale.shape) == 0 + + sampling_params = SamplingParams() + output = llm.generate("Hello world!", sampling_params=sampling_params) + assert output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 8ca486d95941d..c711fd14c668c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -9,10 +9,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, - CompressedTensorsW8A8, CompressedTensorsWNA16) + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, - find_first_name_or_class_match) + QuantizationType, find_first_name_or_class_match) from vllm.platforms import current_platform @@ -117,6 +118,40 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, return is_8_bits and is_token and is_symmetric and is_dynamic + def _is_fp8_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + # Confirm weights and activations quantized. + if weight_quant is None or input_quant is None: + return False + + # Confirm we have floating points. + if not (weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT): + return False + + # Confirm weight scheme is supported. + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_weight = ( + weight_quant.strategy == QuantizationStrategy.TENSOR) + if not (is_symmetric_weight and is_static_weight + and is_per_tensor_weight): + return False + + # Dynamic quantization is always supported if weights supported. + if input_quant.dynamic: + return True + + # Confirm activation scheme is supported. + is_symmetric_activation = input_quant.symmetric + is_per_tensor_activation = ( + input_quant.strategy == QuantizationStrategy.TENSOR) + if not (is_symmetric_activation and is_per_tensor_activation): + return False + + # All conditions satisfied. 
+ return True + def _is_wNa16_group_channel(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: input_quant_none = input_quant is None @@ -147,14 +182,21 @@ def _get_schema(self, weight_quant: BaseModel, strategy=weight_quant.strategy, group_size=weight_quant.group_size) - if self.quant_format == CompressionFormat.int_quantized.value: + if (self.quant_format == CompressionFormat.int_quantized.value or + self.quant_format == CompressionFormat.float_quantized.value): + if self._is_fp8_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8( + input_dynamic=input_quant.dynamic) + if self._is_static_tensor_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8(strategy=weight_quant.strategy, - is_static_input_scheme=True) + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=True) if self._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8(strategy=weight_quant.strategy, - is_static_input_scheme=False) + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=False) raise NotImplementedError( "No compressed-tensors compatible scheme was found.") @@ -187,7 +229,7 @@ def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - return layer.scheme.process_weights_after_loading(layer) + layer.scheme.process_weights_after_loading(layer) def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 720b8c263298c..dd94c49827f62 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -1,8 +1,19 @@ -from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 -from .compressed_tensors_unquantized import ( # noqa: F401 - CompressedTensorsUnquantized) -from .compressed_tensors_w4a16_24 import ( # noqa: F401 - W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) -from .compressed_tensors_w8a8 import CompressedTensorsW8A8 # noqa: F401 -from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS # noqa: F401 -from .compressed_tensors_wNa16 import CompressedTensorsWNA16 # noqa: F401 +from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_unquantized import CompressedTensorsUnquantized +from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, + CompressedTensorsW4A16Sparse24) +from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 +from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 +from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, + CompressedTensorsWNA16) + +__all__ = [ + "CompressedTensorsScheme", + "CompressedTensorsUnquantized", + "CompressedTensorsWNA16", + "CompressedTensorsW4A16Sparse24", + "CompressedTensorsW8A8Int8", + "CompressedTensorsW8A8Fp8", + "WNA16_SUPPORTED_BITS", + "W4A16SPARSE24_SUPPORTED_BITS", +] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py deleted file mode 100644 index dffe2a284458f..0000000000000 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Callable, List, Tuple, Union - -import torch -from torch.nn import Parameter - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - QuantizationStrategy) -from vllm.model_executor.utils import set_weight_attrs - - -class CompressedTensorsW8A8(CompressedTensorsScheme): - - def __init__(self, strategy: str, is_static_input_scheme: bool): - self.strategy = strategy - self.is_static_input_scheme = is_static_input_scheme - - # Cutlass kernels support only per-tensor and per-channel cases. - # So if we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), we convert to the per-channel case. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if (self.strategy == QuantizationStrategy.TENSOR - and len(self.logical_widths) > 1): - - # Load the N per-tensor scales into the channelwise buffer. - weight_scale_channel = torch.empty( - (sum(self.logical_widths), 1), - dtype=torch.float32, - device=layer.weight_scale.device) - start = 0 - for idx, logical_width in enumerate(self.logical_widths): - end = start + logical_width - weight_scale_channel[start:end, :] = layer.weight_scale[idx] - start = end - - layer.weight_scale = Parameter(weight_scale_channel, - requires_grad=False) - - # transpose weights for cutlass. - weight = layer.weight - layer.weight = Parameter(weight.t(), requires_grad=False) - - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - self.logical_widths = output_partition_sizes - - # WEIGHT SCALE - shape: Union[Tuple[int], Tuple[int, int]] - if self.strategy == QuantizationStrategy.CHANNEL: - shape = (sum(self.logical_widths), 1) - else: - shape = (len(self.logical_widths), ) - - weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32), - requires_grad=False) - layer.register_parameter("weight_scale", weight_scale) - if self.strategy == QuantizationStrategy.CHANNEL: - set_weight_attrs(weight_scale, { - "weight_loader": weight_loader, - "output_dim": 0, - }) - else: - set_weight_attrs(weight_scale, { - "weight_loader": weight_loader, - "needs_scalar_to_array": True, - }) - - # WEIGHT - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=torch.int8), - requires_grad=False) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - "weight_loader": weight_loader, - }) - - # INPUT SCALE - # Static quantization: load from disk. - if self.is_static_input_scheme: - input_scale = Parameter(torch.empty(1, dtype=torch.float32), - requires_grad=False) - layer.register_parameter("input_scale", input_scale) - set_weight_attrs(input_scale, { - "weight_loader": weight_loader, - "ignore_warning": True, - }) - # Dynamic quantization: set to None. - else: - layer.input_scale = None - - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): - # ops.scaled_int8_quant supports both dynamic and static quant. - # * dynamic, layer.input_scale is None and x_scale computed from x. - # * static, layer.input_scale is scalar and x_scale is input_scale. 
- x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale) - - return ops.cutlass_scaled_mm(x_q, - layer.weight, - scale_a=x_scale, - scale_b=layer.weight_scale, - out_dtype=x.dtype) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py new file mode 100644 index 0000000000000..b93425fb2d629 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -0,0 +1,87 @@ +from typing import Callable, List, Optional + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported, + requantize_with_max_scale) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsW8A8Fp8"] + + +class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + + def __init__(self, input_dynamic: bool): + self.input_dynamic = input_dynamic + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # W8A8-Fp8 kernels support only per-tensor and per-channel cases. + # So if we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), we requantize with a single scale. + def process_weights_after_loading(self, layer) -> None: + # Dequant -> Quant with max scale. + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + + # Update layer with new values. + layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False) + layer.weight_scale = torch.nn.Parameter(max_w_scale, + requires_grad=False) + if self.input_dynamic: + layer.input_scale = None + else: + layer.input_scale = torch.nn.Parameter(layer.input_scale.max(), + requires_grad=False) + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + del params_dtype + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = torch.nn.Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + "weight_loader": weight_loader, + }) + + # WEIGHT SCALE + weight_scale = create_per_tensor_scale_param( + output_partition_sizes, weight_loader=weight_loader) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if not self.input_dynamic: + input_scale = create_per_tensor_scale_param( + output_partition_sizes, weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py new file mode 100644 index 0000000000000..e70504ec51cb3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -0,0 +1,85 @@ +from typing import Callable, List + +import torch +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + QuantizationStrategy) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_int8_linear, convert_to_channelwise, create_per_channel_scale_param, + create_per_tensor_scale_param) +from vllm.model_executor.utils import set_weight_attrs + + +class CompressedTensorsW8A8Int8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(self.logical_widths) > 1 + if is_fused_module and self.strategy == QuantizationStrategy.TENSOR: + ws_channelwise = convert_to_channelwise(layer.weight_scale, + self.logical_widths) + layer.weight_scale = Parameter(ws_channelwise, requires_grad=False) + + # INPUT SCALE + if self.is_static_input_scheme: + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + else: + layer.input_scale = None + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + self.logical_widths = output_partition_sizes + + # WEIGHT + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + requires_grad=False) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + "weight_loader": weight_loader, + }) + + # WEIGHT SCALE + layer_kwargs = {"weight_loader": weight_loader} + if self.strategy == QuantizationStrategy.CHANNEL: + scale = create_per_channel_scale_param(output_partition_sizes, + **layer_kwargs) + else: + assert self.strategy == QuantizationStrategy.TENSOR + scale = create_per_tensor_scale_param(output_partition_sizes, + **layer_kwargs) + layer.register_parameter("weight_scale", scale) + + # INPUT SCALE + if self.is_static_input_scheme: + scale = create_per_tensor_scale_param(output_partition_sizes, + **layer_kwargs) + layer.register_parameter("input_scale", scale) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + return apply_int8_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index b2bec9b603d1a..5b44c215535b5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -9,6 +9,7 @@ class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" + float_quantized = "float-quantized" int_quantized = "int-quantized" pack_quantized = "pack-quantized" marlin_24 = "marlin-24" diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 544774891389d..8dba9019f94cf 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import torch from torch.nn import Module @@ -11,11 +11,11 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState, - marlin_permute_scales) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - pack_fp8_to_int32) + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, apply_fp8_linear, create_per_tensor_scale_param, + cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once @@ -25,13 +25,6 @@ logger = init_logger(__name__) -def cutlass_fp8_supported() -> bool: - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - - return ops.cutlass_scaled_mm_supports_fp8(capability) - - class Fp8Config(QuantizationConfig): """Config class for FP8.""" @@ -117,23 +110,6 @@ def __init__(self, quant_config: Fp8Config): capability = capability[0] * 10 + capability[1] self.use_marlin = capability < 89 - def _create_scale_param( - self, - scale_name: str, - layer: torch.nn.Module, - output_partition_sizes: List[int], - **extra_weight_attrs, - ) -> None: - scale = Parameter(torch.empty(len(output_partition_sizes), - dtype=torch.float32), - requires_grad=False) - scale[:] = torch.finfo(torch.float8_e4m3fn).min - layer.register_parameter(scale_name, scale) - set_weight_attrs(scale, { - **extra_weight_attrs, - "needs_scalar_to_array": True, - }) - def create_weights( self, layer: torch.nn.Module, @@ -147,7 +123,6 @@ def create_weights( del input_size, output_size output_size_per_partition = sum(output_partition_sizes) - layer.process_after_load = True layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition @@ -173,144 +148,50 @@ def create_weights( # Otherwise, wait until process_weights_after_loading. 
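# Illustrative sketch (not part of this patch): the hunks below move
# Fp8LinearMethod onto the shared w8a8_utils helpers. Two loading paths are
# involved: serialized-fp8 checkpoints register per-tensor scale parameters
# and load them from disk, while fp16/bf16 checkpoints are quantized later in
# process_weights_after_loading. A minimal sketch of that second path, using a
# hypothetical standalone helper rather than the real layer plumbing:
import torch

from vllm import _custom_ops as ops


def quantize_bf16_weight_to_fp8(weight: torch.Tensor):
    # scale=None asks the op to compute a per-tensor scale from the weight
    # itself, mirroring the non-serialized checkpoint branch below.
    qweight, weight_scale = ops.scaled_fp8_quant(weight, scale=None)
    # The fused GEMM path consumes the weight transposed.
    return qweight.t(), weight_scale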
if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - self._create_scale_param( - scale_name="weight_scale", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs) + scale = create_per_tensor_scale_param(output_partition_sizes, + **extra_weight_attrs) + layer.register_parameter("weight_scale", scale) # INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static": - self._create_scale_param( - scale_name="input_scale", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs) - - # For GPUs without FP8 hardware support, we use Marlin for fast - # fused dequantization - if self.use_marlin: - layer.marlin_state = GPTQMarlinState.REPACK - - def prepare_layer_for_marlin(self, layer: Module) -> None: - print_warning_once( - "Your GPU does not have native support for FP8 computation but " - "FP8 quantization is being used. Weight-only FP8 compression will " - "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") - - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - assert layer.marlin_state == GPTQMarlinState.REPACK - layer.marlin_state = GPTQMarlinState.READY - - device = layer.weight.device - - # WEIGHTS - # Repack weights to gptq format (packed int32 elements) - packed_gptq_qweight = pack_fp8_to_int32(layer.weight) - - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=packed_gptq_qweight, - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - # Currently Marlin doesn't support per-tensor scales, so we - # expand it to channelwise - scales = layer.weight_scale.repeat(1, part_size_n).to( - layer.orig_dtype).to(device) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, - size_k=part_size_k, - size_n=part_size_n, - group_size=-1, - num_bits=8, - ) - layer.weight_scale = Parameter(marlin_scales, requires_grad=False) - - # Allocate marlin workspace - max_workspace_size = ( - part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=device, - requires_grad=False) - - layer.workspace = workspace + scale = create_per_tensor_scale_param(output_partition_sizes, + **extra_weight_attrs) + layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: Module) -> None: - if (not hasattr(layer, "process_after_load") - or not layer.process_after_load): - return - - # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights. + # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + + # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) - layer.logical_widths = None layer.input_scale = None - if self.use_marlin: - self.prepare_layer_for_marlin(layer) - return # If checkpoint is fp8, requantize the separately quantized logical # weights into a single fp8 weight with a single weight scale. else: - # WEIGHT_SCALE / WEIGHT - # Loop over logical weights, requantizing with single scale. 
- max_w_scale = layer.weight_scale.max() - - # QKV / MLP is fused in the on disk checkpoint if any of the - # weight scales are still set to the default since we initialize - # N weight scales for N shards but we only load 1 weight scale - # from disk in this case. As a result, we skip dequant -> requant - # since we already have quantized QKV together. - # Sample Model with fused checkpoint: - # * nm-testing/Phi-3-mini-128k-instruct-FP8 - unfused_module_in_checkpoint = ( - layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min) - - if unfused_module_in_checkpoint: - start = 0 - for idx, logical_width in enumerate(layer.logical_widths): - end = start + logical_width - weight_dq = per_tensor_dequantize( - layer.weight[start:end, :], layer.weight_scale[idx]) - - layer.weight[start:end, :] = per_tensor_quantize( - weight_dq, layer.weight_scale.max()) - start = end - layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + # Dequant -> Quant with max scale. + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) - # WEIGHT - # Transpose weight for passing to torch._scaled_mm - weight = layer.weight + # Update layer with new values. layer.weight = Parameter(weight.t(), requires_grad=False) - - # INPUT ACTIVATION SCALE - # Dynamic: set to None (required input to ops.scaled_fp8_quant). - # Static: set to max of the input_scales (since they are equal). - if self.quant_config.activation_scheme == "dynamic": - layer.input_scale = None - elif self.quant_config.activation_scheme == "static": + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + if self.quant_config.activation_scheme == "static": layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: - raise ValueError( - f"Unknown scheme {self.quant_config.activation_scheme}") + layer.input_scale = None - if self.use_marlin: - self.prepare_layer_for_marlin(layer) + if self.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale def apply(self, layer: torch.nn.Module, @@ -318,65 +199,22 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: if self.use_marlin: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (layer.output_size_per_partition, ) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=layer.weight, - b_scales=layer.weight_scale, + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, workspace=layer.workspace, - num_bits=8, - size_m=reshaped_x.shape[0], size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - else: - - # ops.scaled_fp8_quant supports both dynamic and static quant. 
- # If dynamic, layer.input_scale is None and x_scale computed from x - # If static, layer.input_scale is scalar and x_scale is input_scale + bias=bias) - if bias is None and self.cutlass_fp8_supported: - qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) - - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm( - qinput, - layer.weight, - out_dtype=x.dtype, - scale_a=x_scale, - scale_b=layer.weight_scale, - ) - - else: - qinput, x_scale = ops.scaled_fp8_quant(x, - layer.input_scale, - batch_dim_padding=17) - - # Fused GEMM_DQ -- note we padded the input above because - # torch._scaled_mm is more performant for matrices with - # batch dimension > 16. Note that this could change - # in the future. - output, _ = torch._scaled_mm( - qinput, - layer.weight, - out_dtype=x.dtype, - scale_a=x_scale, - scale_b=layer.weight_scale, - bias=bias, - ) - - return torch.narrow(output, 0, 0, x.shape[0]) + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported) class Fp8MoEMethod(FusedMoEMethodBase): @@ -399,8 +237,6 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - layer.process_after_load = True - if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn @@ -465,9 +301,6 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, layer.a2_scale = None def process_weights_after_loading(self, layer: Module) -> None: - if (not hasattr(layer, "process_after_load") - or not layer.process_after_load): - return # If checkpoint is fp16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: @@ -531,7 +364,7 @@ def process_weights_after_loading(self, layer: Module) -> None: shard_size, :], layer.w13_scale[expert_id][shard_id]) layer.w13_weight[expert_id][ - start:start + shard_size, :] = per_tensor_quantize( + start:start + shard_size, :], _ = ops.scaled_fp8_quant( dq_weight, max_w13_scales[expert_id]) start += shard_size @@ -596,23 +429,3 @@ def process_weights_after_loading(self, layer: Module) -> None: "cause accuracy issues. 
Please make sure kv-cache scaling " "factor is available in the fp8 checkpoint.") del layer.kv_scale - - -def per_tensor_quantize(tensor: torch.Tensor, - inv_scale: Union[float, torch.Tensor]) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) - return qweight.to(torch.float8_e4m3fn) - - -def per_tensor_dequantize( - tensor: torch.Tensor, inv_scale: Union[float, - torch.Tensor]) -> torch.Tensor: - fake_qweight = tensor.to(torch.float16) - dq_weight = fake_qweight * inv_scale - return dq_weight - - -def all_close_1d(x: torch.Tensor) -> bool: - assert len(x.shape) == 1 - return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index a6284d0ed7b1b..6b971f73d45bf 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -11,20 +11,16 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_K, + GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, + GPTQ_MARLIN_SUPPORTED_NUM_BITS, GPTQ_MARLIN_SUPPORTED_SYM, + GPTQ_MARLIN_TILE) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.platforms import current_platform logger = init_logger(__name__) -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] -GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] -GPTQ_MARLIN_SUPPORTED_SYM = [True] - # Permutations for Marlin scale shuffling def get_scale_perms(num_bits: int): diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 66ce1959207cc..9886245269ad3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,9 +1,11 @@ """This file is used for /tests and /benchmarks""" import random +from typing import Optional import numpy import torch +from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.format_24 import ( mask_creator, sparse_semi_structured_from_dense_cutlass) from vllm.model_executor.layers.quantization.utils.marlin_24_perms import ( @@ -13,8 +15,16 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( get_pack_factor, quantize_weights, sort_weights) from vllm.platforms import current_platform +from vllm.utils import print_warning_once -MARLIN_TILE = 16 +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +GPTQ_MARLIN_SUPPORTED_SYM = [True] def is_marlin_supported(): @@ -22,7 +32,92 @@ def is_marlin_supported(): return capability[0] >= 8 -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel 
for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: + print_warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WEIGHTS + # Repack weights to gptq format (packed int32 elements) + packed_gptq_qweight = pack_fp8_to_int32(layer.weight) + + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_gptq_qweight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Currently Marlin doesn't support per-tensor scales, so we + # expand it to channelwise + scales = layer.weight_scale.repeat(1, part_size_n).to( + layer.orig_dtype).to(device) + # Permute scales + num_bits = 8 + marlin_scales = marlin_permute_scales( + s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=-1, + scale_perm=marlin_scale_perm[num_bits], + scale_perm_single=marlin_scale_perm_single[num_bits]) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + # Allocate marlin workspace + max_workspace_size = (part_size_n // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + layer.workspace = workspace + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py new file mode 100644 index 0000000000000..81b7fdb7833d7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -0,0 +1,163 @@ +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + + +def cutlass_fp8_supported() -> bool: + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + + return ops.cutlass_scaled_mm_supports_fp8(capability) + + +def per_tensor_dequantize( + tensor: torch.Tensor, inv_scale: Union[float, + torch.Tensor]) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def create_per_tensor_scale_param( + output_partition_sizes: List[int], + **extra_weight_attrs, 
+) -> Parameter:
+    scale = Parameter(torch.empty(len(output_partition_sizes),
+                                  dtype=torch.float32),
+                      requires_grad=False)
+    scale[:] = torch.finfo(torch.float32).min
+    set_weight_attrs(scale, {
+        "needs_scalar_to_array": True,
+        **extra_weight_attrs
+    })
+    return scale
+
+
+def create_per_channel_scale_param(output_partition_sizes: List[int],
+                                   **extra_weight_attrs) -> Parameter:
+    scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
+                                  dtype=torch.float32),
+                      requires_grad=False)
+    scale[:] = torch.finfo(torch.float32).min
+    set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
+    return scale
+
+
+def convert_to_channelwise(weight_scale: torch.Tensor,
+                           logical_widths: List[int]) -> torch.Tensor:
+    # Create channelwise buffer
+    weight_scale_channel = torch.empty((sum(logical_widths), 1),
+                                       dtype=torch.float32,
+                                       device=weight_scale.device)
+
+    # Expand each scale to match the size of each logical matrix.
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_scale_channel[start:end, :] = weight_scale[idx]
+        start = end
+
+    return weight_scale_channel
+
+
+def requantize_with_max_scale(
+        weight: torch.Tensor, weight_scale: torch.Tensor,
+        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requantization.
+    max_w_scale = weight_scale.max()
+
+    # QKV / MLP is fused in the on-disk checkpoint if any of the
+    # weight scales are still set to the default, since we initialize
+    # N weight scales for N shards but only load 1 weight scale
+    # from disk in this case. Skip requantization in this case, since
+    # we are already quantized with a single scale.
+    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
+    unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
+        torch.float8_e4m3fn).min)
+
+    # If unfused checkpoint, requantize with the single max scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :],
+                                              weight_scale[idx])
+            weight[start:end, :], _ = ops.scaled_fp8_quant(
+                weight_dq, max_w_scale)
+            start = end
+
+    return max_w_scale, weight
+
+
+def apply_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_fp8_supported: bool = True,
+) -> torch.Tensor:
+    # ops.scaled_fp8_quant supports both dynamic and static quant.
+    # If dynamic, layer.input_scale is None and x_scale computed from x.
+    # If static, layer.input_scale is scalar and x_scale is input_scale.
+
+    if bias is None and cutlass_fp8_supported:
+        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
+
+        # Fused GEMM_DQ
+        output = ops.cutlass_scaled_mm(qinput,
+                                       weight,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale)
+
+    else:
+        qinput, x_scale = ops.scaled_fp8_quant(input,
+                                               input_scale,
+                                               batch_dim_padding=17)
+
+        # Fused GEMM_DQ -- note we padded the input above because
+        # torch._scaled_mm is more performant for matrices with
+        # batch dimension > 16. Note that this could change
+        # in the future.
+        output, _ = torch._scaled_mm(qinput,
+                                     weight,
+                                     out_dtype=input.dtype,
+                                     scale_a=x_scale,
+                                     scale_b=weight_scale,
+                                     bias=bias)
+
+    return torch.narrow(output, 0, 0, input.shape[0])
+
+
+def apply_int8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+):
+    if bias is not None:
+        raise NotImplementedError("W8A8 with int8 does not yet support bias.")
+
+    # ops.scaled_int8_quant supports both dynamic and static quant.
+    # * dynamic, layer.input_scale is None and x_scale computed from x.
+    # * static, layer.input_scale is scalar and x_scale is input_scale.
+    x_q, x_scale = ops.scaled_int8_quant(input, input_scale)
+
+    return ops.cutlass_scaled_mm(x_q,
+                                 weight,
+                                 scale_a=x_scale,
+                                 scale_b=weight_scale,
+                                 out_dtype=input.dtype)
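# Illustrative sketch (not part of this patch): how the helpers above compose
# inside the new CompressedTensorsW8A8Fp8 scheme. The tensors and
# `logical_widths` below are hypothetical stand-ins for a fused QKV projection
# whose shards were quantized with separate per-tensor scales.
from typing import List

import torch

from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear, requantize_with_max_scale)


def fuse_scales_and_apply(weight: torch.Tensor, weight_scale: torch.Tensor,
                          logical_widths: List[int],
                          x: torch.Tensor) -> torch.Tensor:
    # 1) At load time: collapse the N per-shard scales into one max scale and
    #    requantize the fp8 weight shards against it.
    max_w_scale, weight = requantize_with_max_scale(
        weight=weight,
        weight_scale=weight_scale,
        logical_widths=logical_widths)

    # 2) At forward time: quantize the activations and run the fused
    #    GEMM + dequant. input_scale=None selects dynamic activation
    #    quantization; the scheme stores the weight transposed before apply.
    #    cutlass_fp8_supported=False is chosen here just to exercise the
    #    torch._scaled_mm fallback path.
    return apply_fp8_linear(input=x,
                            weight=weight.t(),
                            weight_scale=max_w_scale,
                            input_scale=None,
                            bias=None,
                            cutlass_fp8_supported=False)

# The int8 scheme follows the same shape: convert_to_channelwise handles fused
# per-tensor scales, then apply_int8_linear runs the quantize + GEMM path.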