1212 QuantizationStrategy ,
1313 QuantizationType )
1414from compressed_tensors .transform import TransformConfig
15- from pydantic import BaseModel
1615
1716import vllm .envs as envs
1817from vllm .logger import init_logger
@@ -268,7 +267,8 @@ def _check_scheme_supported(self,
268267 else :
269268 return False
270269
271- def _is_fp4a4_nvfp4 (self , weight_quant : BaseModel , input_quant : BaseModel ):
270+ def _is_fp4a4_nvfp4 (self , weight_quant : QuantizationArgs ,
271+ input_quant : QuantizationArgs ):
272272
273273 if weight_quant is None or input_quant is None :
274274 return False
@@ -288,8 +288,8 @@ def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
288288 return (is_tensor_group_quant and is_float_type and is_4_bits
289289 and is_group_size_16 and is_symmetric )
290290
291- def _is_fp4a16_nvfp4 (self , weight_quant : BaseModel ,
292- input_quant : BaseModel ):
291+ def _is_fp4a16_nvfp4 (self , weight_quant : QuantizationArgs ,
292+ input_quant : QuantizationArgs ):
293293
294294 is_weight_only = weight_quant is not None and input_quant is None
295295 is_tensor_group_quant = (
@@ -303,8 +303,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
303303 return (is_weight_only and is_tensor_group_quant and is_float_type
304304 and is_4_bits and is_group_size_16 and is_symmetric )
305305
306- def _is_static_tensor_w8a8 (self , weight_quant : BaseModel ,
307- input_quant : BaseModel ) -> bool :
306+ def _is_static_tensor_w8a8 (self , weight_quant : QuantizationArgs ,
307+ input_quant : QuantizationArgs ) -> bool :
308308 is_8_bits = weight_quant .num_bits == input_quant .num_bits == 8
309309 weight_strategy = (
310310 weight_quant .strategy == QuantizationStrategy .TENSOR .value
@@ -317,8 +317,8 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
317317 # Only symmetric weight quantization supported.
318318 return is_8_bits and is_tensor and weight_quant .symmetric and is_static
319319
320- def _is_dynamic_token_w8a8 (self , weight_quant : BaseModel ,
321- input_quant : BaseModel ) -> bool :
320+ def _is_dynamic_token_w8a8 (self , weight_quant : QuantizationArgs ,
321+ input_quant : QuantizationArgs ) -> bool :
322322 is_8_bits = weight_quant .num_bits == input_quant .num_bits == 8
323323 weight_strategy = (
324324 weight_quant .strategy == QuantizationStrategy .TENSOR .value
@@ -331,8 +331,8 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
331331 # Only symmetric weight quantization supported.
332332 return is_8_bits and is_token and weight_quant .symmetric and is_dynamic
333333
334- def _is_dynamic_token_w4a8_int (self , weight_quant : BaseModel ,
335- input_quant : BaseModel ) -> bool :
334+ def _is_dynamic_token_w4a8_int (self , weight_quant : QuantizationArgs ,
335+ input_quant : QuantizationArgs ) -> bool :
336336 is_weight_4_bits = weight_quant .num_bits == 4
337337 is_activation_8_bits = input_quant .num_bits == 8
338338 weight_strategy = (
@@ -347,8 +347,8 @@ def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
347347 return (is_weight_4_bits and is_activation_8_bits and is_token
348348 and weight_quant .symmetric and is_dynamic )
349349
350- def _is_fp8_w8a8 (self , weight_quant : BaseModel ,
351- input_quant : BaseModel ) -> bool :
350+ def _is_fp8_w8a8 (self , weight_quant : QuantizationArgs ,
351+ input_quant : QuantizationArgs ) -> bool :
352352 # Confirm weights and activations quantized.
353353 if weight_quant is None or input_quant is None :
354354 return False
@@ -358,11 +358,12 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
358358 and input_quant .type == QuantizationType .FLOAT )
359359 is_symmetric_weight = weight_quant .symmetric
360360 is_static_weight = not weight_quant .dynamic
361- is_per_tensor_or_channel_weight = (weight_quant .strategy in [
362- QuantizationStrategy .TENSOR , QuantizationStrategy .CHANNEL
361+ is_tensor_or_channel_or_block_weight = (weight_quant .strategy in [
362+ QuantizationStrategy .TENSOR , QuantizationStrategy .CHANNEL ,
363+ QuantizationStrategy .BLOCK
363364 ])
364365 if not (is_floating_point and is_symmetric_weight and is_static_weight
365- and is_per_tensor_or_channel_weight ):
366+ and is_tensor_or_channel_or_block_weight ):
366367 return False
367368
368369 # Dynamic quantization is always supported if weights supported.
@@ -375,8 +376,8 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
375376 input_quant .strategy == QuantizationStrategy .TENSOR )
376377 return is_symmetric_activation and is_per_tensor_activation
377378
378- def _is_fp8_w4a8 (self , weight_quant : BaseModel ,
379- input_quant : BaseModel ) -> bool :
379+ def _is_fp8_w4a8 (self , weight_quant : QuantizationArgs ,
380+ input_quant : QuantizationArgs ) -> bool :
380381 if not weight_quant or not input_quant :
381382 return False
382383 is_weight_4_bits = weight_quant .num_bits == 4
@@ -392,24 +393,24 @@ def _is_fp8_w4a8(self, weight_quant: BaseModel,
392393 return (is_weight_4_bits and is_activation_8_bits and is_token
393394 and is_symmetric and is_dynamic )
394395
395- def _is_fp8_w4a8_sm90 (self , weight_quant : BaseModel ,
396- input_quant : BaseModel ) -> bool :
396+ def _is_fp8_w4a8_sm90 (self , weight_quant : QuantizationArgs ,
397+ input_quant : QuantizationArgs ) -> bool :
397398 return (self ._check_scheme_supported (90 , error = False , match_exact = True )
398399 and self ._is_fp8_w4a8 (weight_quant , input_quant ))
399400
400- def _is_fp8_w8a8_sm90 (self , weight_quant : BaseModel ,
401- input_quant : BaseModel ) -> bool :
401+ def _is_fp8_w8a8_sm90 (self , weight_quant : QuantizationArgs ,
402+ input_quant : QuantizationArgs ) -> bool :
402403 return (self ._check_scheme_supported (90 , error = False , match_exact = True )
403404 and self ._is_fp8_w8a8 (weight_quant , input_quant ))
404405
405- def _is_fp8_w8a8_sm100 (self , weight_quant : BaseModel ,
406- input_quant : BaseModel ) -> bool :
406+ def _is_fp8_w8a8_sm100 (self , weight_quant : QuantizationArgs ,
407+ input_quant : QuantizationArgs ) -> bool :
407408 return (self ._check_scheme_supported (
408409 100 , error = False , match_exact = True )
409410 and self ._is_fp8_w8a8 (weight_quant , input_quant ))
410411
411- def _is_fp8_w8a16 (self , weight_quant : BaseModel ,
412- input_quant : BaseModel ) -> bool :
412+ def _is_fp8_w8a16 (self , weight_quant : QuantizationArgs ,
413+ input_quant : QuantizationArgs ) -> bool :
413414 # Confirm weights quantized.
414415 if weight_quant is None :
415416 return False
@@ -421,18 +422,19 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel,
421422 # Confirm weight scheme is supported.
422423 is_symmetric_weight = weight_quant .symmetric
423424 is_static_weight = not weight_quant .dynamic
424- is_per_tensor_or_channel_weight = (weight_quant .strategy in [
425- QuantizationStrategy .TENSOR , QuantizationStrategy .CHANNEL
425+ is_tensor_or_channel_or_block_weight = (weight_quant .strategy in [
426+ QuantizationStrategy .TENSOR , QuantizationStrategy .CHANNEL ,
427+ QuantizationStrategy .BLOCK
426428 ])
427429 if not (is_symmetric_weight and is_static_weight # noqa: SIM103
428- and is_per_tensor_or_channel_weight ):
430+ and is_tensor_or_channel_or_block_weight ):
429431 return False
430432
431433 # All conditions satisfied.
432434 return True
433435
434- def _is_wNa16_group_channel (self , weight_quant : BaseModel ,
435- input_quant : BaseModel ) -> bool :
436+ def _is_wNa16_group_channel (self , weight_quant : QuantizationArgs ,
437+ input_quant : QuantizationArgs ) -> bool :
436438 input_quant_none = input_quant is None
437439 is_channel_group = (
438440 weight_quant .strategy == QuantizationStrategy .CHANNEL .value
@@ -443,8 +445,8 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
443445
444446 def _get_scheme_from_parts (
445447 self ,
446- weight_quant : BaseModel ,
447- input_quant : BaseModel ,
448+ weight_quant : QuantizationArgs ,
449+ input_quant : QuantizationArgs ,
448450 format : Optional [str ] = None ) -> "CompressedTensorsScheme" :
449451
450452 # use the per-layer format if defined, otherwise, use global format
@@ -496,7 +498,7 @@ def _get_scheme_from_parts(
496498 CompressedTensorsW8A8Fp8 .get_min_capability (), error = False )
497499 if is_fp8_w8a8_supported :
498500 return CompressedTensorsW8A8Fp8 (
499- strategy = weight_quant . strategy ,
501+ weight_quant = weight_quant ,
500502 is_static_input_scheme = (input_quant
501503 and not input_quant .dynamic ))
502504 else :
0 commit comments