from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
+from vllm.utils import round_up

logger = init_logger(__name__)

@@ -785,7 +786,7 @@ def scaled_fp4_quant(
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
        two values are packed into a uint8 and float8_e4m3 scaling factors
-        in a sizzled layout.
+        in the swizzled layout.
    """
    assert input.ndim >= 1, (
        f'input.ndim needs to be >= 1, but got {input.ndim}.')
@@ -803,11 +804,14 @@ def scaled_fp4_quant(
    # Two fp4 values will be packed into an uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

-    # We use the rounded values to store the swizzled values. Then, the scaling
-    # factors in float8_e4m3fn are packed into an int32 for every 4 values.
-    rounded_m = ((m + 128 - 1) // 128) * 128
+    # We use the rounded values to store the swizzled values. Due to the
+    # Tensor Core requirement, the minimum tile for the scales is 128x4, so we
+    # first pad the scales to multiples of 128 and 4. Then, the scales
+    # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
+    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
+    rounded_m = round_up(m, 128)
    scale_n = n // block_size
-    rounded_n = ((scale_n + 4 - 1) // 4) * 4
+    rounded_n = round_up(scale_n, 4)
    output_scale = torch.empty((rounded_m, rounded_n // 4),
                               device=device,
                               dtype=torch.int32)
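
For reference, here is a minimal sketch of the shape arithmetic this hunk introduces, under stated assumptions: `m`, `n`, and `block_size = 16` (the typical NVFP4 group size) are example values chosen for illustration, and `round_up` is re-implemented locally with the semantics assumed for `vllm.utils.round_up`.

```python
# Sketch only, not the vLLM kernel: shape math for the padded, swizzled
# scale tensor allocated above. All concrete values here are assumptions.
def round_up(x: int, y: int) -> int:
    # Assumed semantics of vllm.utils.round_up: smallest multiple of y >= x.
    return ((x + y - 1) // y) * y

m, n, block_size = 100, 256, 16     # example activation shape (assumed)
scale_n = n // block_size           # one e4m3 scale per 16-value fp4 group -> 16
rounded_m = round_up(m, 128)        # pad rows to the 128-row tile          -> 128
rounded_n = round_up(scale_n, 4)    # pad scale columns to a multiple of 4  -> 16
print((rounded_m, rounded_n // 4))  # (128, 4): four e4m3 scales per int32
```

So, under these assumptions, a 100x256 input yields a 128x4 int32 scale tensor, matching the padded `output_scale` allocation in the diff.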