[ Misc ] Support Fp8 via llm-compressor (vllm-project#6110)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Signed-off-by: Alvant <alvasian@yandex.ru>
2 people authored and Alvant committed Oct 26, 2024
1 parent 39cea07 commit 1c8c2bb
Showing 17 changed files with 603 additions and 372 deletions.
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.752
- name: "exact_match,flexible-extract"
value: 0.752
limit: 250
num_fewshot: 5
@@ -1,4 +1,4 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
tasks:
- name: "gsm8k"
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.728
- name: "exact_match,flexible-extract"
value: 0.728
limit: 250
num_fewshot: 5
2 changes: 2 additions & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
@@ -1,2 +1,4 @@
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done

lm_eval --model vllm \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
--batch_size $BATCH_SIZE
3 changes: 2 additions & 1 deletion .buildkite/lm-eval-harness/test_lm_eval_correctness.py
@@ -24,7 +24,8 @@

def launch_lm_eval(eval_config):
model_args = f"pretrained={eval_config['model_name']}," \
f"tensor_parallel_size={TP_SIZE}"
f"tensor_parallel_size={TP_SIZE}," \
f"add_bos_token=true"

results = lm_eval.simple_evaluate(
model="vllm",
32 changes: 27 additions & 5 deletions tests/quantization/test_compressed_tensors.py
@@ -9,7 +9,8 @@
from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8, CompressedTensorsWNA16)
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)

@@ -37,12 +38,11 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.is_static_input_scheme
expected_type = (torch.int8 if quant_type == QuantizationType.INT else
torch.float8_e4m3fn)
expected_type = torch.int8

assert qkv_proj.weight.dtype is expected_type
assert o_proj.weight.dtype is expected_type
@@ -79,7 +79,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
@@ -123,3 +123,25 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output


def test_compressed_tensors_fp8(vllm_runner):
model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.input_scale.dtype is torch.float32
assert qkv_proj.weight_scale.dtype is torch.float32
# should be scalars after processing
assert len(qkv_proj.input_scale.shape) == 0
assert len(qkv_proj.weight_scale.shape) == 0

sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output
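
For reference, a minimal sketch of loading the same FP8 compressed-tensors checkpoint through vLLM's public API (assumed usage, not part of this diff; the model name comes from the test above and the sampling settings are illustrative):

from vllm import LLM, SamplingParams

# Assumed usage: vLLM reads the compressed-tensors FP8 quantization config
# from the checkpoint, so no extra quantization flag should be needed.
llm = LLM(model="nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test")
outputs = llm.generate(["Hello world!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)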
@@ -9,10 +9,11 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8, CompressedTensorsWNA16)
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
find_first_name_or_class_match)
QuantizationType, find_first_name_or_class_match)
from vllm.platforms import current_platform


@@ -117,6 +118,40 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,

return is_8_bits and is_token and is_symmetric and is_dynamic

def _is_fp8_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights and activations quantized.
if weight_quant is None or input_quant is None:
return False

# Confirm we have floating points.
if not (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT):
return False

# Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_weight = (
weight_quant.strategy == QuantizationStrategy.TENSOR)
if not (is_symmetric_weight and is_static_weight
and is_per_tensor_weight):
return False

# Dynamic quantization is always supported if weights supported.
if input_quant.dynamic:
return True

# Confirm activation scheme is supported.
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
if not (is_symmetric_activation and is_per_tensor_activation):
return False

# All conditions satisfied.
return True

def _is_wNa16_group_channel(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None
@@ -147,14 +182,21 @@ def _get_schema(self, weight_quant: BaseModel,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)

if self.quant_format == CompressionFormat.int_quantized.value:
if (self.quant_format == CompressionFormat.int_quantized.value or
self.quant_format == CompressionFormat.float_quantized.value):
if self._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8(
input_dynamic=input_quant.dynamic)

if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8(strategy=weight_quant.strategy,
is_static_input_scheme=True)
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=True)

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8(strategy=weight_quant.strategy,
is_static_input_scheme=False)
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False)

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
@@ -187,7 +229,7 @@ def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return layer.scheme.process_weights_after_loading(layer)
layer.scheme.process_weights_after_loading(layer)

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
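
To illustrate the new dispatch, a hedged sketch of the quantization args that _is_fp8_w8a8 accepts under the float_quantized format; the SimpleNamespace stand-ins below are illustrative only (the real weight_quant/input_quant are pydantic models), but the field values mirror the conditions checked above:

from types import SimpleNamespace

# Weights: static, symmetric, per-tensor FP8 (required by _is_fp8_w8a8).
weight_quant = SimpleNamespace(type="float", symmetric=True, dynamic=False,
                               strategy="tensor")
# Activations: static per-tensor shown here; dynamic=True is also accepted,
# in which case CompressedTensorsW8A8Fp8 is built with input_dynamic=True.
input_quant = SimpleNamespace(type="float", symmetric=True, dynamic=False,
                              strategy="tensor")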
@@ -1,8 +1,19 @@
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w4a16_24 import ( # noqa: F401
W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8 import CompressedTensorsW8A8 # noqa: F401
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS # noqa: F401
from .compressed_tensors_wNa16 import CompressedTensorsWNA16 # noqa: F401
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_unquantized import CompressedTensorsUnquantized
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)

__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsUnquantized",
"CompressedTensorsWNA16",
"CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS",
"W4A16SPARSE24_SUPPORTED_BITS",
]

This file was deleted.
