 from dataclasses import dataclass
+from typing import Optional

 import torch

 from torchao.prototype.mx_formats.nvfp4_tensor import (
-    _nvfp4_quantize,
+    NVFP4Tensor,
+    _addmm_nvfp4_dispatch,
     per_tensor_amax_to_scale,
 )
-from torchao.quantization.qat import (
-    FakeQuantizeConfigBase,
-    FakeQuantizerBase,
-)
+from torchao.quantization.qat import FakeQuantizeConfigBase


 @dataclass
@@ -23,47 +22,166 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     Args:
         use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
             after the initial fp8 (e4m3) block-wise scaling (default True)
+        use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        use_triton_kernel (bool): Whether to use triton kernels during fake quantization
     """

     use_per_tensor_scale: bool = True
+    use_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
+# TODO: support emulation on non-Blackwell GPUs
+class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
+    """
+    Autograd function for NVFP4 quantization + addmm in low precision during forward,
+    and fake quantization in high precision during backward.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        activation_config: NVFP4FakeQuantizeConfig,
+        weight_config: NVFP4FakeQuantizeConfig,
+    ) -> torch.Tensor:
+        # quantize input activations
+        if activation_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(_input))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        _input = NVFP4Tensor.to_nvfp4(
+            _input,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=activation_config.use_swizzled_scales,
+            use_triton_kernel=activation_config.use_triton_kernel,
+        )
+
+        # quantize weights
+        if weight_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(weight))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        weight = NVFP4Tensor.to_nvfp4(
+            weight,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=weight_config.use_swizzled_scales,
+            use_triton_kernel=False,
+        )

+        # Follow `NVFP4InferenceConfig`, always use traditional construction
+        # for weights and set `use_triton_kernel` afterwards
+        weight.use_triton_kernel = weight_config.use_triton_kernel

-class NVFP4FakeQuantizer(FakeQuantizerBase):
+        ctx.save_for_backward(_input, weight)
+
+        return _addmm_nvfp4_dispatch(
+            _input,
+            weight.t(),
+            None,  # aten_op, not used
+            bias,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        _input, weight = ctx.saved_tensors
+        assert isinstance(_input, NVFP4Tensor)
+        assert isinstance(weight, NVFP4Tensor)
+        _input = _input.to_dtype(_input._orig_dtype)
+        weight = weight.to_dtype(weight._orig_dtype)
+        grad_input = torch.mm(grad_output, weight)
+        grad_weight = torch.mm(grad_output.t(), _input)
+        return grad_input, grad_weight, None, None, None
+
+
+class NVFP4FakeQuantizedLinear(torch.nn.Linear):
     """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    Linear module for fake quantized NVFP4 weights and/or activations.
+
+    The forward pass exactly follows the quantization and addmm numerics of
+    `NVFP4Tensor` in low precision, while the backward pass uses dequantized
+    (fake quantized) values in high precision.
+
+    Currently this is only supported on Blackwell and later GPU generations.
+    See https://github.com/pytorch/ao/issues/3102 for more details.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+
+        base_config = NVFP4InferenceConfig()
+        quantize_(model, QATConfig(base_config, step="prepare"))
+        # Model contains `NVFP4FakeQuantizedLinear` now
+
+        train_loop(model)
+        quantize_(model, QATConfig(base_config, step="convert"))
+        # Model contains `nn.Linear` with `NVFP4Tensor` weights now
     """

-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+        if weight_config is None:
+            raise ValueError("Must specify `weight_config`")
+        if activation_config is None:
+            raise ValueError("Weight-only NVFP4 QAT is not supported yet")
+        self.activation_config = activation_config
+        self.weight_config = weight_config

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        block_size = 16
-        original_shape = x.shape
         if x.dim() == 3:
+            batch_size = x.shape[0]
             x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = None
+            batch_size = None
+        fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
+            x, self.weight, self.bias, self.activation_config, self.weight_config
+        )
+        assert fq.dtype == x.dtype
+        if batch_size is not None:
+            return fq.view(batch_size, -1, fq.shape[-1])
+        else:
+            return fq

-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+    ):
+        new_linear = NVFP4FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias is not None,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
+        return new_linear
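
For orientation, here is a minimal usage sketch (not part of the diff above) showing how the new `from_linear` constructor and `NVFP4FakeQuantizeConfig` could be applied to a single layer. The import path, tensor shapes, and device are assumptions; per the docstring, the NVFP4 forward currently requires a Blackwell-class GPU.

import torch

# NOTE: import path is an assumption; adjust to wherever this module lives in torchao
from torchao.quantization.qat import (
    NVFP4FakeQuantizeConfig,
    NVFP4FakeQuantizedLinear,
)

# plain bf16 linear layer to be swapped for the fake-quantized version
linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")

# both configs are required: __init__ rejects a missing weight or activation config
config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
fq_linear = NVFP4FakeQuantizedLinear.from_linear(
    linear,
    activation_config=config,
    weight_config=config,
)

x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = fq_linear(x)      # NVFP4-quantized forward via _addmm_nvfp4_dispatch
y.sum().backward()    # high-precision backward on dequantized values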