     CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
     CompressedTensorsLinearMethod, CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
-    should_ignore_layer)
+    find_matched_target, should_ignore_layer)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -60,12 +61,12 @@ def get_scheme(self,
             layer_name=layer_name,
             module=layer,
             targets=self.target_scheme_map.keys(),
-            fused_mapping=self.packed_modules_mapping)
+            fused_mapping=self.packed_modules_mapping,
+        )
 
         scheme_dict = self.target_scheme_map[matched_target]
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
-        format = scheme_dict.get("format")
 
         if weight_quant is None:
             logger.warning_once("Acceleration for non-quantized schemes is "
@@ -74,10 +75,10 @@ def get_scheme(self,
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = is_activation_quantization_format(  # noqa: F841
-            format
-        ) if format is not None else is_activation_quantization_format(
-            self.quant_format)
+        # act_quant_format = (
+        #     is_activation_quantization_format(  # noqa: F841
+        #         format) if format is not None else
+        #     is_activation_quantization_format(self.quant_format))
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -114,8 +115,8 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            raise NotImplementedError(
-                "FusedMoE quantization is currently not supported.")
+            return CompressedTensorsW8A8Fp8MoEMethod(self, layer.quant_config,
+                                                     self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
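
Below is a minimal, self-contained sketch of the isinstance-based dispatch that get_quant_method performs after this change. The classes here are stand-ins, not the real vLLM/tpu_inference types; only the control flow and the three-argument construction of CompressedTensorsW8A8Fp8MoEMethod mirror the diff.

# Sketch only: stand-in classes, assuming the constructor shape shown above.
class FusedMoE:
    def __init__(self, quant_config=None):
        self.quant_config = quant_config

class Attention:
    pass

class CompressedTensorsW8A8Fp8MoEMethod:
    # Mirrors the (config, quant_config, mesh) construction in the diff.
    def __init__(self, config, quant_config, mesh):
        self.config, self.quant_config, self.mesh = config, quant_config, mesh

class CompressedTensorsKVCacheMethod:
    def __init__(self, config):
        self.config = config

class Config:
    def __init__(self, mesh):
        self.mesh = mesh  # device mesh carried by the jax-aware config

    def get_quant_method(self, layer):
        # FusedMoE layers now resolve to the MoE method instead of
        # raising NotImplementedError.
        if isinstance(layer, FusedMoE):
            return CompressedTensorsW8A8Fp8MoEMethod(self, layer.quant_config,
                                                     self.mesh)
        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)
        return None

method = Config(mesh="tpu-mesh").get_quant_method(FusedMoE())
assert isinstance(method, CompressedTensorsW8A8Fp8MoEMethod)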