from dataclasses import dataclass
+ from typing import Optional

import torch

from torchao.prototype.mx_formats.nvfp4_tensor import (
-     _nvfp4_quantize,
+     NVFP4Tensor,
+     _addmm_nvfp4_dispatch,
    per_tensor_amax_to_scale,
)
- from torchao.quantization.qat import (
-     FakeQuantizeConfigBase,
-     FakeQuantizerBase,
- )
+ from torchao.quantization.qat import FakeQuantizeConfigBase


@dataclass
@@ -23,47 +22,162 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
    Args:
        use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
            after the initial fp8 (e4m3) block-wise scaling (default True)
+         use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+         use_triton_kernel (bool): Whether to use triton kernels during fake quantization
    """

    use_per_tensor_scale: bool = True
+     use_swizzled_scales: bool = False
+     use_triton_kernel: bool = False

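# A minimal sketch of constructing these configs by hand; the keyword names
# mirror the dataclass fields above:
#
#     activation_config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
#     weight_config = NVFP4FakeQuantizeConfig(
#         use_per_tensor_scale=True,
#         use_swizzled_scales=False,
#         use_triton_kernel=False,
#     )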
- class NVFP4FakeQuantizer(FakeQuantizerBase):
+ class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
    """
-     (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+     Autograd function for NVFP4 quantization + addmm in low precision during forward,
+     and fake quantization in high precision during backward.
    """

-     def __init__(self, config: NVFP4FakeQuantizeConfig):
-         super().__init__()
-         torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-         self.config = config
+     @staticmethod
+     def forward(
+         ctx,
+         _input: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor],
+         activation_config: NVFP4FakeQuantizeConfig,
+         weight_config: NVFP4FakeQuantizeConfig,
+     ) -> torch.Tensor:
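        # Two-level scaling sketch: when `use_per_tensor_scale` is set, a per-tensor
        # fp32 scale is derived from the tensor's amax via `per_tensor_amax_to_scale`
        # and passed to `NVFP4Tensor.to_nvfp4`, which also applies the fp8 (e4m3)
        # block-wise scaling.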
+         # quantize input activations
+         if activation_config.use_per_tensor_scale:
+             tensor_amax = torch.max(torch.abs(_input))
+             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+         else:
+             per_tensor_scale = None
+         _input = NVFP4Tensor.to_nvfp4(
+             _input,
+             per_tensor_scale=per_tensor_scale,
+             is_swizzled_scales=activation_config.use_swizzled_scales,
+             use_triton_kernel=activation_config.use_triton_kernel,
+         )
+
+         # quantize weights
+         if weight_config.use_per_tensor_scale:
+             tensor_amax = torch.max(torch.abs(weight))
+             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+         else:
+             per_tensor_scale = None
+         weight = NVFP4Tensor.to_nvfp4(
+             weight,
+             per_tensor_scale=per_tensor_scale,
+             is_swizzled_scales=weight_config.use_swizzled_scales,
+             use_triton_kernel=False,
+         )
+
+         # Follow `NVFP4InferenceConfig`, always use traditional construction
+         # for weights and set `use_triton_kernel` afterwards
+         weight.use_triton_kernel = weight_config.use_triton_kernel
+
+         ctx.save_for_backward(_input, weight)
+
+         return _addmm_nvfp4_dispatch(
+             _input,
+             weight.t(),
+             None,  # aten_op, not used
+             bias,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+         _input, weight = ctx.saved_tensors
+         assert isinstance(_input, NVFP4Tensor)
+         assert isinstance(weight, NVFP4Tensor)
+         _input = _input.to_dtype(_input._orig_dtype)
+         weight = weight.to_dtype(weight._orig_dtype)
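        # Standard linear backward on the dequantized (fake-quantized) tensors:
        # for y = x @ W^T, grad_x = grad_y @ W and grad_W = grad_y^T @ x.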
+         grad_input = torch.mm(grad_output, weight)
+         grad_weight = torch.mm(grad_output.t(), _input)
+         return grad_input, grad_weight, None, None, None
+
+
+ class NVFP4FakeQuantizedLinear(torch.nn.Linear):
+     """
+     Linear module for fake quantized NVFP4 weights and/or activations.
+
+     The forward pass exactly follows the quantization and addmm numerics of
+     `NVFP4Tensor` in lower precision, while the backward pass uses dequantized
+     (fake quantized) values in high precision.
+
+     Example usage::
+
+         from torchao.quantization import quantize_
+         from torchao.quantization.qat import QATConfig
+         from torchao.prototype.mx_formats import NVFP4InferenceConfig
+
+         base_config = NVFP4InferenceConfig()
+         quantize_(model, QATConfig(base_config, step="prepare"))
+         # Model now contains `NVFP4FakeQuantizedLinear`
+
+         train_loop(model)
+         quantize_(model, QATConfig(base_config, step="convert"))
+         # Model now contains `nn.Linear` with `NVFP4Tensor` weights
+     """
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = False,
+         activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+         weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(
+             in_features,
+             out_features,
+             bias,
+             *args,
+             **kwargs,
+         )
+         if weight_config is None:
+             raise ValueError("Must specify `weight_config`")
+         if activation_config is None:
+             raise ValueError("Weight only NVFP4 QAT not supported yet")
+         self.activation_config = activation_config
+         self.weight_config = weight_config

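    # The forward below flattens a 3D input of shape [batch, seq_len, dim] to 2D
    # [batch * seq_len, dim] before dispatching to the NVFP4 addmm path, then
    # restores the batch dimension on the way out.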
    def forward(self, x: torch.Tensor) -> torch.Tensor:
-         block_size = 16
-         original_shape = x.shape
        if x.dim() == 3:
+             batch_size = x.shape[0]
            x = x.view(-1, x.shape[-1])
-         if self.config.use_per_tensor_scale:
-             tensor_amax = torch.max(torch.abs(x))
-             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
        else:
-             per_tensor_scale = None
+             batch_size = None
+         fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
+             x, self.weight, self.bias, self.activation_config, self.weight_config
+         )
+         assert fq.dtype == x.dtype
+         if batch_size is not None:
+             return fq.view(batch_size, -1, fq.shape[-1])
+         else:
+             return fq

-         # quantize
-         scale, q = _nvfp4_quantize(
-             x,
-             block_size=block_size,
-             per_tensor_scale=per_tensor_scale,
-             skip_dtype_cast_and_packing=True,
+     @classmethod
+     def from_linear(
+         cls,
+         mod: torch.nn.Linear,
+         activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+         weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+     ):
+         new_linear = NVFP4FakeQuantizedLinear(
+             mod.in_features,
+             mod.out_features,
+             mod.bias is not None,
+             activation_config=activation_config,
+             weight_config=weight_config,
+             device=mod.weight.device,
+             dtype=mod.weight.dtype,
        )
-         if self.config.use_per_tensor_scale:
-             scale = scale * per_tensor_scale
-         assert q.dtype == x.dtype
-         assert scale.dtype == torch.float32
-
-         # dequantize
-         M, K = q.shape[0], q.shape[1]
-         q = q.view(M, K // block_size, block_size)
-         scale = scale.view(M, K // block_size, 1)
-         dq = q * scale
-         return dq.view(original_shape).to(x.dtype)
+         # In distributed training, the model may be instantiated
+         # on the meta device, in which case there is no need to
+         # copy the weights, and doing so will result in an error
+         if mod.weight.device != torch.device("meta"):
+             new_linear.weight = mod.weight
+             new_linear.bias = mod.bias
+         return new_linear
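# A minimal sketch of wrapping a single module by hand with `from_linear`; the
# `quantize_` + `QATConfig` flow shown in the class docstring is the usual entry point:
#
#     # `linear` is any existing torch.nn.Linear instance
#     fq_linear = NVFP4FakeQuantizedLinear.from_linear(
#         linear,
#         activation_config=NVFP4FakeQuantizeConfig(),
#         weight_config=NVFP4FakeQuantizeConfig(),
#     )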