Add torchao to PT2 Benchmark Runner (#2268)
Summary:
X-link: pytorch/pytorch#126469


Support torchao performance and accuracy tests in the PT2 Benchmark Runner, using the inductor backend as the baseline.
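For context, a hedged invocation sketch: the entry-point script, model name, and surrounding flags below are assumptions, not part of this commit; only --quantization and its choices come from this change.

    # Hypothetical way to drive the runner with the new flag.
    import subprocess

    subprocess.run(
        [
            "python",
            "benchmarks/dynamo/torchbench.py",   # assumed entry point
            "--performance",
            "--inference",
            "--bfloat16",                        # the torchao path asserts bfloat16
            "--quantization", "int8weightonly",  # one of the new choices
            "--only", "resnet50",                # assumed model selection
        ],
        check=True,
    )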

Reviewed By: jerryzh168

Differential Revision: D57463273
xuzhao9 authored and facebook-github-bot committed May 16, 2024
1 parent 218fffe commit 83e955c
Showing 1 changed file with 21 additions and 0 deletions.
userbenchmark/dynamo/dynamobench/common.py
@@ -3467,6 +3467,12 @@ def get_example_inputs(self):
         action="store_true",
         help="Measure speedup with TorchInductor",
     )
+    group.add_argument(
+        "--quantization",
+        choices=["int8dynamic", "int8weightonly", "int4weightonly", "autoquant", "noquant"],
+        default=None,
+        help="Measure speedup of torchao quantization with TorchInductor baseline",
+    )
     group.add_argument(
         "--export",
         action="store_true",
@@ -3661,6 +3667,9 @@ def run(runner, args, original_dir=None):
     if args.inductor:
         assert args.backend is None
         args.backend = "inductor"
+    if args.quantization:
+        assert args.backend is None
+        args.backend = "torchao"
     if args.dynamic_batch_only:
         args.dynamic_shapes = True
         torch._dynamo.config.assume_static_by_default = True
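A minimal sketch of the consequence of this hunk, assuming a simplified args object: --quantization alone routes the run to the "torchao" backend, while combining it with --inductor would trip the second assert because both branches claim args.backend.

    from types import SimpleNamespace

    # Assumed, simplified mirror of the two branches above.
    args = SimpleNamespace(inductor=False, quantization="int8weightonly", backend=None)

    if args.inductor:
        assert args.backend is None
        args.backend = "inductor"
    if args.quantization:
        assert args.backend is None  # would fail if --inductor were also passed
        args.backend = "torchao"

    assert args.backend == "torchao"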
@@ -3939,6 +3948,18 @@ def run(runner, args, original_dir=None):
 
         # AOTInductor doesn't support control flow yet
         runner.skip_models.update(runner.skip_models_due_to_control_flow)
+    elif args.backend == "torchao":
+        assert "cuda" in args.devices, "Quantization requires CUDA device."
+        assert args.bfloat16, "Quantization requires dtype bfloat16."
+        from .torchao import torchao_optimize_ctx
+        baseline_ctx = functools.partial(
+            torch.compile,
+            backend="inductor",
+            fullgraph=args.nopython,
+            mode=args.inductor_compile_mode,
+        )
+        runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
+        optimize_ctx = torchao_optimize_ctx(args.quantization)
     else:
         optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
         experiment = speedup_experiment
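The torchao_optimize_ctx helper imported above lives in the sibling .torchao module, which this diff does not show. Below is a plausible sketch of its shape, assuming torchao's quantize_ API; the exact body is an assumption, not the committed code.

    def torchao_optimize_ctx(quantization: str):
        """Assumed shape: return a wrapper that quantizes the model once,
        then delegates to the original per-iteration function."""
        from torchao.quantization import (  # APIs from the torchao library
            int4_weight_only,
            int8_dynamic_activation_int8_weight,
            int8_weight_only,
            quantize_,
        )

        def inner(model_iter_fn):
            quantized = False

            def wrapped(model, example_inputs, *args, **kwargs):
                nonlocal quantized
                if not quantized:
                    if quantization == "int8dynamic":
                        quantize_(model, int8_dynamic_activation_int8_weight())
                    elif quantization == "int8weightonly":
                        quantize_(model, int8_weight_only())
                    elif quantization == "int4weightonly":
                        quantize_(model, int4_weight_only())
                    # "autoquant" and "noquant" omitted from this sketch
                    quantized = True
                return model_iter_fn(model, example_inputs, *args, **kwargs)

            return wrapped

        return inner

Note the design choice visible in the diff itself: runner.model_iter_fn is first wrapped with torch.compile(backend="inductor"), so the reported speedup compares torchao quantization against an inductor-compiled bfloat16 baseline rather than plain eager mode.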
