@@ -54,7 +54,12 @@
 from torchao.dtypes.utils import Layout
 from torchao.float8.config import e4m3_dtype, e5m2_dtype
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.inference import Float8MMConfig
+from torchao.float8.inference import (
+    Float8MMConfig,
+    FP8Granularity,
+    _check_hardware_support,
+    _normalize_granularity,
+)
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
 )
@@ -1431,56 +1436,9 @@ def _float8_weight_only_transform(
     return module


-_fp8_granularities = Union[PerTensor, PerRow]
-
-
-# Validate and process granularity input
-def _normalize_granularity(
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ],
-) -> Tuple[_fp8_granularities, _fp8_granularities]:
-    processed_granularity = None
-    if granularity is None:
-        processed_granularity = (PerTensor(), PerTensor())
-    elif isinstance(granularity, (PerTensor, PerRow)):
-        processed_granularity = (granularity, granularity)
-    elif isinstance(granularity, tuple) and len(granularity) == 2:
-        if not (
-            isinstance(granularity[0], (PerTensor, PerRow))
-            and isinstance(granularity[1], (PerTensor, PerRow))
-        ):
-            raise ValueError(
-                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
-            )
-        if not isinstance(granularity[0], type(granularity[1])):
-            raise ValueError(
-                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
-            )
-        processed_granularity = granularity
-    else:
-        raise ValueError(
-            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
-        )
-    # Validate granularity with supported Hardware
-    for _granularity in processed_granularity:
-        if isinstance(_granularity, PerTensor):
-            assert is_sm_at_least_89() or is_MI300(), (
-                "PerTensor quantization only works for CUDA>=8.9 and MI300+"
-            )
-        elif isinstance(_granularity, PerRow):
-            assert is_sm_at_least_90() or is_MI300(), (
-                "PerRow quantization only works for CUDA>=9.0 and MI300+"
-            )
-        else:
-            raise ValueError(f"Invalid granularity type: {_granularity}")
-
-    return processed_granularity
-
-
 def _input_activation_quant_func_fp8(
     x: torch.Tensor,
-    activation_granularity: _fp8_granularities,
+    activation_granularity: FP8Granularity,
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
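
For reference, the block deleted above is a move, not a rewrite: `_normalize_granularity` now lives in `torchao.float8.inference`, alongside the new `_check_hardware_support` helper that takes over the hardware assertions. A minimal sketch of the relocated contract, assuming the relocated functions keep the semantics shown in the removed code:

    from torchao.float8.inference import _check_hardware_support, _normalize_granularity
    from torchao.quantization.granularity import PerTensor

    # None normalizes to per-tensor scaling for both activation and weight
    # (per the removed code's `granularity is None` branch).
    act_gran, wgt_gran = _normalize_granularity(None)
    assert isinstance(act_gran, PerTensor) and isinstance(wgt_gran, PerTensor)

    # The hardware check is now a separate step (per the removed assertions:
    # PerTensor needs SM 8.9+ or MI300+, PerRow needs SM 9.0+ or MI300+).
    _check_hardware_support((act_gran, wgt_gran))
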
@@ -1567,7 +1525,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e4m3_dtype
     weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+        Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]]
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
@@ -1576,6 +1534,11 @@ def __post_init__(self):
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)

+        activation_granularity, weight_granularity = _normalize_granularity(
+            self.granularity
+        )
+        self.granularity = (activation_granularity, weight_granularity)
+


 # for bc
 float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
@@ -1587,7 +1550,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     granularity = config.granularity
     mm_config = config.mm_config

-    activation_granularity, weight_granularity = _normalize_granularity(granularity)
+    # Ensure the configured granularity is supported on the current hardware
+    _check_hardware_support(granularity)
+    activation_granularity, weight_granularity = granularity

     if not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
@@ -1704,7 +1669,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e4m3_dtype
     weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+        Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]]
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
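
Net effect of the diff: granularity normalization now happens once, at config construction, and the quantize path only re-validates hardware support. A minimal usage sketch under that assumption; the model, shapes, and dtype below are illustrative, not taken from the PR:

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.granularity import PerRow
    from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).cuda()

    # __post_init__ now normalizes a single granularity into an
    # (activation, weight) pair, so config.granularity is always a tuple.
    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    assert isinstance(config.granularity, tuple)

    # _check_hardware_support runs inside the transform
    # (PerRow requires SM 9.0+ or MI300+).
    quantize_(model, config)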