Commit 72a458f

initial commit on compressed-tensors quantization support for fp8
Signed-off-by: Han Qi <hanq@google.com>
1 parent 60c14f5 commit 72a458f

5 files changed: +403 −28 lines


tpu_inference/layers/vllm/quantization/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -21,10 +21,15 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
         None: VllmUnquantizedConfig,
         "compressed-tensors": VllmCompressedTensorsConfig,
         "awq": VllmAWQConfig,
+        "fp8": VllmCompressedTensorsConfig,
     }
+    # import sys
 
+    # sys.stdin = open(0)
+    # breakpoint()
     if model_config.quantization not in method_to_config:
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"{model_config.quantization} quantization method not supported.")
     quant_config = method_to_config[model_config.quantization]
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)
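
For context, the new "fp8" key reuses the existing compressed-tensors config class rather than adding a dedicated fp8 config. The snippet below is only a minimal, self-contained sketch of that lookup-then-raise dispatch pattern; the class names and pick_config helper are hypothetical stand-ins, not the real tpu_inference or vLLM types.

# Illustrative sketch only; stand-in names, not the real Vllm*Config classes.
from typing import Optional

class UnquantizedConfig:
    pass

class CompressedTensorsConfig:
    pass

METHOD_TO_CONFIG = {
    None: UnquantizedConfig,
    "compressed-tensors": CompressedTensorsConfig,
    # The alias added by this commit: "fp8" routes to the same
    # compressed-tensors config path instead of a separate fp8 config.
    "fp8": CompressedTensorsConfig,
}

def pick_config(quantization: Optional[str]):
    # Mirrors the lookup-then-raise shape of get_tpu_quantization_config.
    if quantization not in METHOD_TO_CONFIG:
        raise NotImplementedError(
            f"{quantization} quantization method not supported.")
    return METHOD_TO_CONFIG[quantization]

assert pick_config("fp8") is CompressedTensorsConfig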

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 12 additions & 10 deletions
@@ -14,10 +14,11 @@
     CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
     CompressedTensorsLinearMethod, CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
-    should_ignore_layer)
+    find_matched_target, should_ignore_layer)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -60,12 +61,12 @@ def get_scheme(self,
             layer_name=layer_name,
             module=layer,
             targets=self.target_scheme_map.keys(),
-            fused_mapping=self.packed_modules_mapping)
+            fused_mapping=self.packed_modules_mapping,
+        )
 
         scheme_dict = self.target_scheme_map[matched_target]
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
-        format = scheme_dict.get("format")
 
         if weight_quant is None:
             logger.warning_once("Acceleration for non-quantized schemes is "
@@ -74,10 +75,10 @@ def get_scheme(self,
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = is_activation_quantization_format( # noqa: F841
-            format
-        ) if format is not None else is_activation_quantization_format(
-            self.quant_format)
+        # act_quant_format = (
+        #     is_activation_quantization_format( # noqa: F841
+        #         format) if format is not None else
+        #     is_activation_quantization_format(self.quant_format))
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -114,8 +115,9 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            raise NotImplementedError(
-                "FusedMoE quantization is currently not supported.")
+            print("HERE", layer)
+            return CompressedTensorsW8A8Fp8MoEMethod(self, layer.quant_config,
+                                                     self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
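
For context, CompressedTensorsW8A8Fp8MoEMethod itself lives in compressed_tensors_moe.py, which is not part of this excerpt. As rough orientation only, here is a minimal sketch of the generic W8A8 fp8 recipe such a method builds on (per-tensor scales into float8_e4m3, matmul, rescale). It is written with jax.numpy, assumes float8_e4m3fn is available in the installed JAX runtime, and is not the repo's actual implementation.

# Illustrative sketch of per-tensor W8A8 fp8 quantization, not the repo's
# CompressedTensorsW8A8Fp8MoEMethod. Assumes jnp.float8_e4m3fn is available.
import jax.numpy as jnp

FP8_MAX = 448.0  # largest finite float8_e4m3fn value

def quantize_fp8(x):
    # Per-tensor scale so the largest magnitude maps onto FP8_MAX.
    scale = jnp.maximum(jnp.max(jnp.abs(x)) / FP8_MAX, 1e-12)
    return (x / scale).astype(jnp.float8_e4m3fn), scale

def w8a8_fp8_matmul(a, w):
    a_q, a_scale = quantize_fp8(a)
    w_q, w_scale = quantize_fp8(w)
    # Accumulate in a wider dtype, then fold both scales back in.
    out = jnp.dot(a_q.astype(jnp.bfloat16), w_q.astype(jnp.bfloat16))
    return out * (a_scale * w_scale)

# Example: (4, 8) @ (8, 16) -> (4, 16)
a = jnp.ones((4, 8), dtype=jnp.bfloat16)
w = 0.5 * jnp.ones((8, 16), dtype=jnp.bfloat16)
print(w8a8_fp8_matmul(a, w).shape)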
