From 32f6e54bf58a9fe7c558128fdaa68b9533e83220 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 14 May 2024 15:26:29 -0700 Subject: [PATCH] Fix CI after do_bench refactor in pytorch inductor (#242) Summary: We are relying on some private APIs from inductor, and a recent refactor: https://github.com/pytorch/pytorch/pull/125736 broke the do_bench API we rely on for autoquant; maybe we should use our own do_bench or rely on triton's directly? Test Plan: regression tests python test/integration/test_integration.py -k test_autoquant_one_input_29_cuda Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/autoquant.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index fc38c04169..4331d9b042 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -15,6 +15,8 @@ except: from torch._inductor.runtime.runtime_utils import do_bench +from .utils import TORCH_VERSION_AFTER_2_4 + aten = torch.ops.aten AUTOQUANT_CACHE = {} @@ -197,7 +199,11 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") + if TORCH_VERSION_AFTER_2_4: + from torch._inductor.runtime.runtime_utils import do_bench_gpu + res = do_bench_gpu(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") + else: + res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") return res def _is_interpolate_mode(mode):