 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.layers.quantization import register_quantization_config
-from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase  # noqa: E501
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase  # noqa: E501
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig,
-    CompressedTensorsKVCacheMethod,
-    CompressedTensorsLinearMethod,
-    CompressedTensorsScheme,
-)
-from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import (
-    CompressedTensorsW8A8Fp8MoEMethod,
-)
+    CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
+    CompressedTensorsLinearMethod, CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target,
-    is_activation_quantization_format,
-    should_ignore_layer,
-)
+    find_matched_target, should_ignore_layer)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import (
-    VllmCompressedTensorsW8A8Fp8,
-)
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import (
-    VllmCompressedTensorsW8A8Int8,
-)
-from tpu_inference.layers.vllm.quantization.unquantized import VllmUnquantizedConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    VllmCompressedTensorsW8A8Fp8
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    VllmCompressedTensorsW8A8Int8
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedConfig
 
 P = PartitionSpec
 logger = init_logger(__name__)
 
 
 @register_quantization_config("jax-compressed-tensors")
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+
     @classmethod
     def get_name(cls) -> str:
         return "jax-compressed-tensors"
 
-    def get_scheme(
-        self, layer: torch.nn.Module, layer_name: Optional[str] = None
-    ) -> Optional["CompressedTensorsScheme"]:
+    def get_scheme(self,
+                   layer: torch.nn.Module,
+                   layer_name: Optional[str] = None
+                   ) -> Optional["CompressedTensorsScheme"]:
         """
         compressed-tensors supports non-uniform quantization in the following way:
 
@@ -71,24 +67,18 @@ def get_scheme(
         scheme_dict = self.target_scheme_map[matched_target]
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
-        format = scheme_dict.get("format")
 
         if weight_quant is None:
-            logger.warning_once(
-                "Acceleration for non-quantized schemes is "
-                "not supported by Compressed Tensors. "
-                "Falling back to UnquantizedLinearMethod"
-            )
+            logger.warning_once("Acceleration for non-quantized schemes is "
+                                "not supported by Compressed Tensors. "
+                                "Falling back to UnquantizedLinearMethod")
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = (
-            is_activation_quantization_format(  # noqa: F841
-                format
-            )
-            if format is not None
-            else is_activation_quantization_format(self.quant_format)
-        )
+        # act_quant_format = (
+        #     is_activation_quantization_format(  # noqa: F841
+        #         format) if format is not None else
+        #     is_activation_quantization_format(self.quant_format))
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -105,28 +95,29 @@ def get_scheme(
             input_symmetric=input_quant.symmetric,
             jax_config=linear_config,
         )
-        raise NotImplementedError("No compressed-tensors compatible scheme was found.")
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")
 
     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        if should_ignore_layer(
-            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
-        ):
+        if should_ignore_layer(prefix,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
             return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
             if scheme is None:
-                return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
+                return VllmUnquantizedConfig.get_quant_method(
+                    self, layer, prefix)
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
             print("HERE", layer)
-            return CompressedTensorsW8A8Fp8MoEMethod(
-                self, layer.quant_config, self.mesh
-            )
+            return CompressedTensorsW8A8Fp8MoEMethod(self, layer.quant_config,
+                                                     self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
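
For reference, register_quantization_config("jax-compressed-tensors") adds this class to vLLM's quantization registry so the method can be resolved by name. A minimal sketch of that lookup, assuming vLLM's get_quantization_config helper (not part of this diff) and that the module above has already been imported so the decorator has run:

    from vllm.model_executor.layers.quantization import get_quantization_config

    # Resolve the class registered under "jax-compressed-tensors" above.
    # Works only after the decorated module has been imported.
    quant_cls = get_quantization_config("jax-compressed-tensors")
    assert quant_cls.get_name() == "jax-compressed-tensors"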
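
The get_scheme() dispatch above keys off the per-target quantization args that CompressedTensorsConfig parses out of the checkpoint's compressed-tensors config. A hypothetical sketch of the shape target_scheme_map encodes for a W8A8 FP8 checkpoint (plain dicts for illustration; the real map holds parsed QuantizationArgs objects, and exact fields vary by checkpoint):

    # Hypothetical per-target scheme map for a W8A8 FP8 checkpoint.
    # get_scheme() looks up the target matched by find_matched_target(),
    # then reads "weights" and "input_activations" to pick a scheme.
    target_scheme_map = {
        "Linear": {
            "weights": {"num_bits": 8, "type": "float", "symmetric": True,
                        "strategy": "channel", "dynamic": False},
            "input_activations": {"num_bits": 8, "type": "float",
                                  "symmetric": True, "strategy": "token",
                                  "dynamic": True},
        },
    }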