@@ -1,15 +1,14 @@
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 
 from torchao.prototype.mx_formats.nvfp4_tensor import (
-    _nvfp4_quantize,
+    NVFP4Tensor,
+    _addmm_nvfp4_dispatch,
     per_tensor_amax_to_scale,
 )
-from torchao.quantization.qat import (
-    FakeQuantizeConfigBase,
-    FakeQuantizerBase,
-)
+from torchao.quantization.qat import FakeQuantizeConfigBase
 
 
 @dataclass
@@ -23,47 +22,159 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     Args:
         use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
             after the initial fp8 (e4m3) block-wise scaling (default True)
+        use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format (default False)
+        use_triton_kernel (bool): Whether to use triton kernels during fake quantization (default False)
     """
 
     use_per_tensor_scale: bool = True
+    use_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
 
 
-class NVFP4FakeQuantizer(FakeQuantizerBase):
+class _NVFP4FakeQuantizedLinearForward(torch.autograd.Function):
     """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    Autograd function for NVFP4 fake quantization + addmm.
     """
 
-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        activation_config: NVFP4FakeQuantizeConfig,
+        weight_config: NVFP4FakeQuantizeConfig,
+    ) -> torch.Tensor:
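+        # Quantize both operands to NVFP4, then run the same addmm dispatch
+        # `NVFP4Tensor` uses at inference time, so fake-quantized numerics
+        # match the converted model exactly.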
+        # quantize input activations
+        if activation_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(_input))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
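+        # Two-level scaling: an optional fp32 per-tensor scale derived from the
+        # tensor amax, followed by fp8 (e4m3) block-wise scales inside `to_nvfp4`.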
+        _input = NVFP4Tensor.to_nvfp4(
+            _input,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=activation_config.use_swizzled_scales,
+            use_triton_kernel=activation_config.use_triton_kernel,
+        )
+
+        # quantize weights
+        if weight_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(weight))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        weight = NVFP4Tensor.to_nvfp4(
+            weight,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=weight_config.use_swizzled_scales,
+            use_triton_kernel=False,
+        )
+
+        # Follow `NVFP4InferenceConfig` and always use the traditional
+        # construction for weights, setting `use_triton_kernel` afterwards
+        weight.use_triton_kernel = weight_config.use_triton_kernel
+
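+        # Save the NVFP4-quantized operands; backward dequantizes them to
+        # compute straight-through gradients.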
+        ctx.save_for_backward(_input, weight)
+
+        return _addmm_nvfp4_dispatch(
+            _input,
+            weight.t(),
+            None,  # aten_op, not used
+            bias,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        _input, weight = ctx.saved_tensors
+        assert isinstance(_input, NVFP4Tensor)
+        assert isinstance(weight, NVFP4Tensor)
+        _input = _input.to_dtype(_input._orig_dtype)
+        weight = weight.to_dtype(weight._orig_dtype)
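+        # Straight-through estimator: gradients are computed against the
+        # dequantized values of the saved NVFP4 tensors.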
+        grad_input = torch.mm(grad_output, weight)
+        grad_weight = torch.mm(grad_output.t(), _input)
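+        # The trailing Nones correspond to bias and the two config arguments,
+        # which receive no gradients.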
+        return grad_input, grad_weight, None, None, None
+
+
+class NVFP4FakeQuantizedLinear(torch.nn.Linear):
+    """
+    Linear module for fake quantized NVFP4 weights and/or activations.
+
+    The forward pass follows quantization and addmm numerics in `NVFP4Tensor` exactly.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        from torchao.quantization.qat import QATConfig
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+
+        base_config = NVFP4InferenceConfig()
+        quantize_(model, QATConfig(base_config, step="prepare"))
+        # Model contains `NVFP4FakeQuantizedLinear` now
+
+        train_loop(model)
+        quantize_(model, QATConfig(base_config, step="convert"))
+        # Model contains `nn.Linear` with `NVFP4Tensor` weights now
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+        if weight_config is None:
+            raise ValueError("Must specify `weight_config`")
+        if activation_config is None:
+            raise ValueError("Weight-only NVFP4 QAT not supported yet")
+        self.activation_config = activation_config
+        self.weight_config = weight_config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
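+        # Inputs may be 3D [batch, seq_len, hidden]; flatten to 2D for the mm
+        # and restore the batch dimension afterwards.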
-        block_size = 16
-        original_shape = x.shape
         if x.dim() == 3:
+            batch_size = x.shape[0]
             x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = None
+            batch_size = None
+        fq = _NVFP4FakeQuantizedLinearForward.apply(
+            x, self.weight, self.bias, self.activation_config, self.weight_config
+        )
+        assert fq.dtype == x.dtype
+        if batch_size is not None:
+            return fq.view(batch_size, -1, fq.shape[-1])
+        else:
+            return fq
 
-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+    ):
+        new_linear = NVFP4FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias is not None,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
+        return new_linear