|
16 | 16 |
|
17 | 17 | import torchao |
18 | 18 | from torchao._models.llama.model import prepare_inputs_for_model |
| 19 | +from torchao.prototype.mx_formats.inference_workflow import ( |
| 20 | + MXDynamicActivationMXWeightConfig, |
| 21 | + NVFP4DynamicActivationNVFP4WeightConfig, |
| 22 | +) |
19 | 23 | from torchao.quantization import ( |
20 | 24 | Float8DynamicActivationFloat8WeightConfig, |
21 | 25 | Float8WeightOnlyConfig, |
@@ -170,15 +174,43 @@ def run_evaluation( |
170 | 174 | quantize_( |
171 | 175 | model, |
172 | 176 | Float8DynamicActivationFloat8WeightConfig(granularity=granularity), |
| 177 | + filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) |
| 178 | + and fqn != "output", |
173 | 179 | ) |
174 | 180 | if quantization == "float8_a1x128_w128x128": |
175 | 181 | config = Float8DynamicActivationFloat8WeightConfig( |
176 | 182 | granularity=(PerBlock([1, 128]), PerBlock([128, 128])), |
177 | 183 | activation_value_lb=1e-12, |
178 | 184 | ) |
179 | 185 | # TODO(future): all workflows in this file should be skipping quantization |
180 | | - # of `lm_head` |
| 186 | + # of `lm_head`/`output` |
181 | 187 | quantize_(model, config) |
| 188 | + if quantization == "mxfp8": |
| 189 | + config = MXDynamicActivationMXWeightConfig( |
| 190 | + activation_dtype=torch.float8_e4m3fn, |
| 191 | + weight_dtype=torch.float8_e4m3fn, |
| 192 | + ) |
| 193 | + # TODO(future): all workflows in this file should be skipping quantization |
| 194 | + # of `lm_head`/`output` |
| 195 | + quantize_( |
| 196 | + model, |
| 197 | + config, |
| 198 | + filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) |
| 199 | + and fqn != "output", |
| 200 | + ) |
| 201 | + if quantization == "nvfp4": |
| 202 | + config = NVFP4DynamicActivationNVFP4WeightConfig( |
| 203 | + use_dynamic_per_tensor_scale=True, |
| 204 | + use_triton_kernel=True, |
| 205 | + ) |
| 206 | + # TODO(future): all workflows in this file should be skipping quantization |
| 207 | + # of `lm_head`/`output` |
| 208 | + quantize_( |
| 209 | + model, |
| 210 | + config, |
| 211 | + filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) |
| 212 | + and fqn != "output", |
| 213 | + ) |
182 | 214 | if "autoround" in quantization: |
183 | 215 | from transformers import AutoTokenizer |
184 | 216 |
|
@@ -284,8 +316,8 @@ def run_evaluation( |
284 | 316 |
|
285 | 317 | if compile: |
286 | 318 | # TODO(future PR): clean this up |
287 | | - if quantization == "float8_a1x128_w128x128": |
288 | | - # we don't need max-autotune for float8 blockwise quant |
| 319 | + if quantization in ("float8_a1x128_w128x128", "mxfp8", "nvfp4"): |
| 320 | + # we don't need max-autotune for float8 blockwise, mxfp8, or nvfp4 quant
289 | 321 | model = torch.compile(model) |
290 | 322 | else: |
291 | 323 | model = torch.compile(model, mode="max-autotune", fullgraph=True) |
|
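For reference, a minimal sketch of how the two new inference configs are applied on their own, assuming a GPU with MX/NVFP4 kernel support; the toy model and its module names are illustrative, while the config arguments and the `filter_fn` that skips the `output` projection mirror the diff above.

```python
import torch

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
    NVFP4DynamicActivationNVFP4WeightConfig,
)
from torchao.quantization import quantize_

# Toy model (illustrative): the final projection is registered as "output",
# matching the fqn that the filter below excludes from quantization.
model = torch.nn.Sequential()
model.add_module("up", torch.nn.Linear(1024, 4096, bias=False))
model.add_module("down", torch.nn.Linear(4096, 1024, bias=False))
model.add_module("output", torch.nn.Linear(1024, 8192, bias=False))
model = model.to(torch.bfloat16).cuda()


def skip_output(mod, fqn):
    # quantize only nn.Linear modules, and skip the output/lm_head projection
    return isinstance(mod, torch.nn.Linear) and fqn != "output"


# mxfp8: dynamic MX quantization of activations and weights to e4m3,
# using the same arguments as the diff above.
quantize_(
    model,
    MXDynamicActivationMXWeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
    ),
    filter_fn=skip_output,
)

# nvfp4 follows the same pattern with
# NVFP4DynamicActivationNVFP4WeightConfig(
#     use_dynamic_per_tensor_scale=True, use_triton_kernel=True)

# per the compile branch above, max-autotune is not needed for these workflows
model = torch.compile(model)
```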