     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
-    NVFP4FakeQuantizeConfig,
 )
 from .utils import (
     _fake_quantize_per_channel_group,
@@ -58,6 +57,12 @@ def __repr__(self) -> str:
 
     @staticmethod
     def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
+        # TODO: rewrite using registration API so we don't need to import here
+        from torchao.prototype.qat import (
+            NVFP4FakeQuantizeConfig,
+            NVFP4FakeQuantizer,
+        )
+
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
@@ -95,52 +100,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
-class NVFP4FakeQuantizer(FakeQuantizerBase):
-    """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
-    """
-
-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        from torchao.prototype.mx_formats.nvfp4_tensor import (
-            _nvfp4_quantize,
-            per_tensor_amax_to_scale,
-        )
-
-        block_size = 16
-        original_shape = x.shape
-        if x.dim() == 3:
-            x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
-        else:
-            per_tensor_scale = None
-
-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
-        )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
-
-
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying integer fake quantization to a tensor, as specified in the config.
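
A minimal sketch of how the relocated NVFP4 pieces are presumably reached after this change. It assumes the NVFP4 branch in `from_config` mirrors the existing Intx/Float8 dispatch, that `NVFP4FakeQuantizeConfig` accepts `use_per_tensor_scale` as a constructor argument, and that `FakeQuantizerBase` is importable from `torchao.quantization.qat.fake_quantizer`; none of this is shown in the diff itself.

# Sketch only, not part of the diff. Assumed usage of the prototype NVFP4 QAT path.
import torch

from torchao.prototype.qat import NVFP4FakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import FakeQuantizerBase  # assumed path

# use_per_tensor_scale as a constructor kwarg is an assumption based on the removed forward()
config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)

# from_config is assumed to return the prototype NVFP4FakeQuantizer via the deferred import above
fake_quantizer = FakeQuantizerBase.from_config(config)

# Last dimension chosen as a multiple of the NVFP4 block size (16)
x = torch.randn(128, 256, dtype=torch.bfloat16)
y = fake_quantizer(x)  # fake-quantized output, same shape and dtype as x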