
Commit 1568f88

add mxfp8 and nvfp4 to Llama eval scripts
Summary:

Adds mxfp8 and nvfp4 to the Llama eval scripts.

Results:

```
// bf16 baseline
with-proxy time python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande
wikitext: {'alias': 'wikitext', 'word_perplexity,none': 7.5472105433748435, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.459319739134015, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.5452960145272896, 'bits_per_byte_stderr,none': 'N/A'}
winogrande: {'alias': 'winogrande', 'acc,none': 0.7426992896606156, 'acc_stderr,none': 0.012285989618865697}

// mxfp8 with floor scaling; compile turned off as it seemed stuck in coordinate descent tuning
with-proxy time python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande --quantization mxfp8
wikitext: {'alias': 'wikitext', 'word_perplexity,none': 7.609070006132819, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.4615491037668933, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.5474983002838458, 'bits_per_byte_stderr,none': 'N/A'}
winogrande: {'alias': 'winogrande', 'acc,none': 0.7292817679558011, 'acc_stderr,none': 0.012487904760626407}

// mxfp8 with rceil scaling
wikitext: {'alias': 'wikitext', 'word_perplexity,none': 7.605445025927753, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.4614188696390065, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.5473697404554175, 'bits_per_byte_stderr,none': 'N/A'}
winogrande: {'alias': 'winogrande', 'acc,none': 0.7387529597474349, 'acc_stderr,none': 0.012346914863415201}

// nvfp4
wikitext: {'alias': 'wikitext', 'word_perplexity,none': 8.44478255417328, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.4903102070118779, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.5756126578938119, 'bits_per_byte_stderr,none': 'N/A'}
winogrande: {'alias': 'winogrande', 'acc,none': 0.7182320441988951, 'acc_stderr,none': 0.012643326011853038}

// float8 rowwise (for comparison to existing technique)
wikitext: {'alias': 'wikitext', 'word_perplexity,none': 7.618818730886612, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.4618990946965715, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.5478437349532752, 'bits_per_byte_stderr,none': 'N/A'}
winogrande: {'alias': 'winogrande', 'acc,none': 0.7371744277821626, 'acc_stderr,none': 0.01237092252726192}
```

ghstack-source-id: 3a2d8ef
ghstack-comment-id: 3581080988
Pull-Request: #3394
1 parent 16aad7c commit 1568f88
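The commit message compares mxfp8 results under `floor` vs `rceil` scale rounding. As background, here is a heavily simplified sketch of the tradeoff, assuming the usual MX recipe of power-of-two per-block scales derived from the block amax; the function and constant names are illustrative and are not torchao's API:

```python
import torch

F8E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn


def block_scale_exponent(block: torch.Tensor, mode: str) -> torch.Tensor:
    # MX scales are pure powers of two; choose exponent e so that
    # block / 2**e fits into the fp8 element range
    raw = torch.log2(block.abs().amax() / F8E4M3_MAX)
    if mode == "floor":
        return torch.floor(raw)  # may clip the largest values in the block
    if mode == "rceil":
        return torch.ceil(raw)  # never clips, at the cost of some headroom
    raise ValueError(f"unknown mode {mode}")


block = torch.randn(32) * 100  # one 32-element MX block
for mode in ("floor", "rceil"):
    e = block_scale_exponent(block, mode)
    q = torch.clamp(block / 2**e, -F8E4M3_MAX, F8E4M3_MAX).to(torch.float8_e4m3fn)
    dequant = q.to(torch.float32) * 2**e
    print(mode, "max abs error:", (block - dequant).abs().max().item())
```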

File tree

2 files changed: 57 additions, 4 deletions


torchao/_models/llama/eval.py

Lines changed: 35 additions & 3 deletions
```diff
@@ -16,6 +16,10 @@

 import torchao
 from torchao._models.llama.model import prepare_inputs_for_model
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXDynamicActivationMXWeightConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
+)
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
@@ -170,15 +174,43 @@ def run_evaluation(
         quantize_(
             model,
             Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear)
+            and fqn != "output",
         )
     if quantization == "float8_a1x128_w128x128":
         config = Float8DynamicActivationFloat8WeightConfig(
             granularity=(PerBlock([1, 128]), PerBlock([128, 128])),
             activation_value_lb=1e-12,
         )
         # TODO(future): all workflows in this file should be skipping quantization
-        # of `lm_head`
+        # of `lm_head`/`output`
         quantize_(model, config)
+    if quantization == "mxfp8":
+        config = MXDynamicActivationMXWeightConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+        )
+        # TODO(future): all workflows in this file should be skipping quantization
+        # of `lm_head`/`output`
+        quantize_(
+            model,
+            config,
+            filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear)
+            and fqn != "output",
+        )
+    if quantization == "nvfp4":
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=True,
+            use_triton_kernel=True,
+        )
+        # TODO(future): all workflows in this file should be skipping quantization
+        # of `lm_head`/`output`
+        quantize_(
+            model,
+            config,
+            filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear)
+            and fqn != "output",
+        )
     if "autoround" in quantization:
         from transformers import AutoTokenizer

@@ -284,8 +316,8 @@ def run_evaluation(

     if compile:
         # TODO(future PR): clean this up
-        if quantization == "float8_a1x128_w128x128":
-            # we don't need max-autotune for float8 blockwise quant
+        if quantization in ("float8_a1x128_w128x128", "mxfp8", "nvfp4"):
+            # we don't need max-autotune for float8 blockwise, mxfp8, or nvfp4 quant
             model = torch.compile(model)
         else:
             model = torch.compile(model, mode="max-autotune", fullgraph=True)
```
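For readers who want the new recipes outside of the eval script, below is a minimal sketch of the pattern the diff adds: apply a config with `quantize_` to every `nn.Linear` except the final `output` projection. The toy model and its shapes are hypothetical stand-ins for the Llama model, and a GPU with mxfp8/nvfp4 support is assumed:

```python
# Minimal sketch (not part of the commit) of the quantization pattern above.
# ToyModel is hypothetical; supported hardware is assumed.
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
    NVFP4DynamicActivationNVFP4WeightConfig,
)
from torchao.quantization import quantize_


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffn = nn.Linear(256, 256, bias=False)
        self.output = nn.Linear(256, 128, bias=False)  # plays the role of `lm_head`

    def forward(self, x):
        return self.output(self.ffn(x))


model = ToyModel().to(torch.bfloat16).cuda()

config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
)
# nvfp4 variant, mirroring the second new branch in the diff:
# config = NVFP4DynamicActivationNVFP4WeightConfig(
#     use_dynamic_per_tensor_scale=True,
#     use_triton_kernel=True,
# )

# quantize in place, skipping the output projection as eval.py does
quantize_(
    model,
    config,
    filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "output",
)

# per the compile hunk, these recipes use plain torch.compile (no max-autotune)
model = torch.compile(model)
```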

torchao/prototype/mx_formats/README.md

Lines changed: 22 additions & 1 deletion
````diff
@@ -223,7 +223,28 @@ To reproduce this on supported hardware, you can run the following command:

 ## inference

-Coming soon!
+Eval results on LLaMa 3.1 8B on common tasks. `mxfp8` and `nvfp4` recipes quantize all linears except `lm_head`.
+
+Note: the accuracy results below are WIP and are not optimized yet.
+
+| recipe | wikitext word_perplexity | winogrande |
+| ------ | -------- | ---------- |
+| bfloat16 (baseline) | 7.5472105433748435 | 0.7426992896606156 |
+| mxfp8 | 7.609070006132819 | 0.7292817679558011 |
+| nvfp4 | 8.44478255417328 | 0.7182320441988951 |
+
+To reproduce:
+
+```bash
+# baseline
+python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande
+
+# mxfp8
+python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande --quantization mxfp8
+
+# nvfp4
+python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande --quantization nvfp4
+```

 # testing
````
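As a quick sanity check that a recipe really skipped `lm_head`/`output` (the eval commands above pass `--print_model` for the same purpose), one can inspect the linear weights after `quantize_`. A small sketch, continuing the hypothetical toy model from the eval.py section and assuming torchao's usual tensor-subclass weight swap:

```python
# quantized linears carry a tensor-subclass weight whose repr differs from
# the plain bfloat16 weight of the skipped `output` linear
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        print(name, mod.weight)
```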
