 from dataclasses import dataclass
+from typing import Optional

 import torch

 from torchao.prototype.mx_formats.nvfp4_tensor import (
-    _nvfp4_quantize,
+    NVFP4Tensor,
+    _addmm_nvfp4_dispatch,
     per_tensor_amax_to_scale,
 )
-from torchao.quantization.qat import (
-    FakeQuantizeConfigBase,
-    FakeQuantizerBase,
-)
+from torchao.quantization.qat import FakeQuantizeConfigBase


 @dataclass
@@ -23,47 +22,144 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     Args:
         use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
             after the initial fp8 (e4m3) block-wise scaling (default True)
+        use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        use_triton_kernel (bool): Whether to use triton kernels during fake quantization
     """

     use_per_tensor_scale: bool = True
+    use_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
+class _NVFP4FakeQuantizedForward(torch.autograd.Function):
+    """
+    Autograd function that fake quantizes the input activations and weights to
+    NVFP4, performs the matmul on the quantized tensors, and computes gradients
+    with respect to the dequantized (fake-quantized) values in the backward pass.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        activation_config: NVFP4FakeQuantizeConfig,
+        weight_config: NVFP4FakeQuantizeConfig,
+    ) -> torch.Tensor:
+        # quantize input activations
+        if activation_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(_input))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        _input = NVFP4Tensor.to_nvfp4(
+            _input,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=activation_config.use_swizzled_scales,
+            use_triton_kernel=activation_config.use_triton_kernel,
+        )
+
+        # quantize weights
+        if weight_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(weight))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        weight = NVFP4Tensor.to_nvfp4(
+            weight,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=weight_config.use_swizzled_scales,
+            use_triton_kernel=False,
+        )
+
+        # Follow `NVFP4InferenceConfig`, always use traditional construction
+        # for weights and set `use_triton_kernel` afterwards
+        weight.use_triton_kernel = weight_config.use_triton_kernel
+
+        ctx.save_for_backward(_input, weight)
+
+        return _addmm_nvfp4_dispatch(
+            _input,
+            weight.t(),
+            None,  # aten_op, not used
+            bias,
+        )

+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        _input, weight = ctx.saved_tensors
+        assert isinstance(_input, NVFP4Tensor)
+        assert isinstance(weight, NVFP4Tensor)
+        _input = _input.to_dtype(_input._orig_dtype)
+        weight = weight.to_dtype(weight._orig_dtype)
+        grad_input = torch.mm(grad_output, weight)
+        grad_weight = torch.mm(grad_output.t(), _input)
+        return grad_input, grad_weight, None, None, None

-class NVFP4FakeQuantizer(FakeQuantizerBase):
+
+class NVFP4FakeQuantizedLinear(torch.nn.Linear):
     """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    (Prototype) Linear layer that applies NVFP4 fake quantization to its input
+    activations and weights during training, as specified in the given configs.
     """

-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+        if weight_config is None:
+            raise ValueError("Must specify `weight_config`")
+        if activation_config is None:
+            raise ValueError("Weight only NVFP4 QAT not supported yet")
+        self.activation_config = activation_config
+        self.weight_config = weight_config

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        block_size = 16
-        original_shape = x.shape
         if x.dim() == 3:
+            batch_size = x.shape[0]
             x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = None
+            batch_size = None
+        fq = _NVFP4FakeQuantizedForward.apply(
+            x, self.weight, self.bias, self.activation_config, self.weight_config
+        )
+        assert fq.dtype == x.dtype
+        if batch_size is not None:
+            return fq.view(batch_size, -1, fq.shape[-1])
+        else:
+            return fq

-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+    ):
+        new_linear = NVFP4FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias is not None,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
+        return new_linear
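
For reference, a minimal usage sketch of the new module (not part of the commit above). It assumes the names defined in this file are importable, a CUDA device with NVFP4 matmul support, and feature dimensions that are multiples of the 16-element NVFP4 block size; the layer sizes and tensors below are purely illustrative:

import torch

# hypothetical bf16 linear layer to convert; bias=False mirrors the default above
linear = torch.nn.Linear(64, 32, bias=False, device="cuda", dtype=torch.bfloat16)

config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
fq_linear = NVFP4FakeQuantizedLinear.from_linear(
    linear,
    activation_config=config,
    weight_config=config,
)

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
y = fq_linear(x)      # forward fake quantizes activations and weights to NVFP4
y.sum().backward()    # backward uses the dequantized input and weight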
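
The quantize-dequantize path that forward and backward go through can also be reproduced by hand with the same helpers, which may help when checking numerics. Again a sketch under the same hardware assumptions, with an illustrative tensor shape:

import torch

x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)

# two-level scaling: per-tensor fp32 scale on top of the per-block fp8 scales
per_tensor_scale = per_tensor_amax_to_scale(torch.max(torch.abs(x)))
x_nvfp4 = NVFP4Tensor.to_nvfp4(
    x,
    per_tensor_scale=per_tensor_scale,
    is_swizzled_scales=False,
    use_triton_kernel=False,
)

# dequantize back to the original dtype; these are the values the backward
# pass computes its matmuls against
x_fq = x_nvfp4.to_dtype(x.dtype)
assert x_fq.shape == x.shape and x_fq.dtype == x.dtype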