38 changes: 35 additions & 3 deletions torchao/_models/llama/eval.py
@@ -16,6 +16,10 @@

import torchao
from torchao._models.llama.model import prepare_inputs_for_model
from torchao.prototype.mx_formats.inference_workflow import (
MXDynamicActivationMXWeightConfig,
NVFP4DynamicActivationNVFP4WeightConfig,
)
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
@@ -170,15 +174,43 @@ def run_evaluation(
quantize_(
model,
Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear)
and fqn != "output",
)
if quantization == "float8_a1x128_w128x128":
config = Float8DynamicActivationFloat8WeightConfig(
granularity=(PerBlock([1, 128]), PerBlock([128, 128])),
activation_value_lb=1e-12,
)
# TODO(future): all workflows in this file should be skipping quantization
-            # of `lm_head`
+            # of `lm_head`/`output`
quantize_(model, config)
if quantization == "mxfp8":
config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
)
# TODO(future): all workflows in this file should be skipping quantization
# of `lm_head`/`output`
quantize_(
model,
config,
filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear)
and fqn != "output",
)
if quantization == "nvfp4":
config = NVFP4DynamicActivationNVFP4WeightConfig(
use_dynamic_per_tensor_scale=True,
use_triton_kernel=True,
)
# TODO(future): all workflows in this file should be skipping quantization
# of `lm_head`/`output`
quantize_(
model,
config,
filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear)
and fqn != "output",
)
if "autoround" in quantization:
from transformers import AutoTokenizer

Expand Down Expand Up @@ -284,8 +316,8 @@ def run_evaluation(

if compile:
# TODO(future PR): clean this up
-        if quantization == "float8_a1x128_w128x128":
-            # we don't need max-autotune for float8 blockwise quant
+        if quantization in ("float8_a1x128_w128x128", "mxfp8", "nvfp4"):
+            # we don't need max-autotune for float8 blockwise, mxfp8, or nvfp4 quant
model = torch.compile(model)
else:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
23 changes: 22 additions & 1 deletion torchao/prototype/mx_formats/README.md
@@ -223,7 +223,28 @@ To reproduce this on supported hardware, you can run the following command:

## inference

-Coming soon!
Eval results for Llama 3.1 8B on common tasks. The `mxfp8` and `nvfp4` recipes quantize all linear layers except `lm_head`.

Note: the accuracy results below are a work in progress and have not been optimized yet.

| recipe | wikitext word_perplexity (lower is better) | winogrande accuracy (higher is better) |
| ------ | -------- | ---------- |
| bfloat16 (baseline) | 7.5472105433748435 | 0.7426992896606156 |
| mxfp8 | 7.609070006132819 | 0.7292817679558011 |
| nvfp4 | 8.44478255417328 | 0.7182320441988951 |
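
For reference, the eval script applies these recipes with torchao's `quantize_` API and a filter that skips the output head. Below is a minimal sketch of that pattern for `mxfp8`; the tiny `nn.Sequential` model is a stand-in for the Llama checkpoint, and it assumes hardware that supports mxfp8 inference:

```python
import torch

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)
from torchao.quantization import quantize_

# Stand-in for the real model; the eval script loads Llama 3.1 8B instead.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
model = model.cuda().to(torch.bfloat16)

# mxfp8 recipe: float8 (e4m3) activations and weights with MX scaling.
config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
)

# Quantize every nn.Linear except the output head (`lm_head`/`output`).
quantize_(
    model,
    config,
    filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "output",
)
```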

To reproduce:

```bash
# baseline
python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande

# mxfp8
python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande --quantization mxfp8

# nvfp4
python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande --quantization nvfp4
```
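
The `nvfp4` recipe follows the same pattern. A hedged sketch, again with a placeholder model, using the config values from the eval script and the plain `torch.compile` call it uses for these recipes (no max-autotune):

```python
import torch

from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,
)
from torchao.quantization import quantize_

# Stand-in for the real model; see the commands above for the actual eval.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
model = model.cuda().to(torch.bfloat16)

# nvfp4 recipe as configured in the eval script: dynamic per-tensor scale
# plus the Triton kernel path.
config = NVFP4DynamicActivationNVFP4WeightConfig(
    use_dynamic_per_tensor_scale=True,
    use_triton_kernel=True,
)
quantize_(
    model,
    config,
    filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "output",
)

# For mxfp8/nvfp4 the eval script compiles without max-autotune.
model = torch.compile(model)
```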

# testing
