Fix CI after do_bench refactor in pytorch inductor (pytorch#242)

Summary:
We rely on some private APIs from inductor, and a recent refactor (pytorch/pytorch#125736) broke the do_bench API we use for autoquant. Maybe we should use our own do_bench or rely on triton's directly?
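For reference, a minimal sketch of the "rely on triton's directly" option mentioned above, calling triton's public triton.testing.do_bench instead of the private inductor helpers (the exact keyword set varies across triton versions, and bench_graph_replay_ms is a hypothetical helper, not code from this repo):

import torch
from triton.testing import do_bench  # triton's public benchmarking helper

def bench_graph_replay_ms(graph: torch.cuda.CUDAGraph, warmup=25, rep=100):
    # Median runtime, in milliseconds, of replaying an already-captured CUDA graph.
    return do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")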

Test Plan:
regression tests
python test/integration/test_integration.py -k test_autoquant_one_input_29_cuda

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored May 14, 2024
1 parent 530d789 commit 32f6e54
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion torchao/quantization/autoquant.py
@@ -15,6 +15,8 @@
 except:
     from torch._inductor.runtime.runtime_utils import do_bench

+from .utils import TORCH_VERSION_AFTER_2_4
+
 aten = torch.ops.aten

 AUTOQUANT_CACHE = {}
@@ -197,7 +199,11 @@ def do_autoquant_bench(op, *args, **kwargs):
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, stream=stream):
             op(*args, **kwargs)
-        res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
+        if TORCH_VERSION_AFTER_2_4:
+            from torch._inductor.runtime.runtime_utils import do_bench_gpu
+            res = do_bench_gpu(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
+        else:
+            res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
     return res

 def _is_interpolate_mode(mode):
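For context, a minimal sketch of the "use our own do_bench" option from the summary: resolve a benchmark function once at import time instead of branching on the torch version inside do_autoquant_bench. It assumes only the do_bench_gpu name shown in the diff above plus triton's public triton.testing.do_bench; bench_median_ms is a hypothetical wrapper, not code from this commit.

try:
    # post-refactor torch (pytorch/pytorch#125736, 2.4+): GPU benchmark helper
    from torch._inductor.runtime.runtime_utils import do_bench_gpu as _bench
except ImportError:
    # older torch (or a future rename): fall back to triton's public helper,
    # which exposes a similar warmup/rep/return_mode interface
    from triton.testing import do_bench as _bench

def bench_median_ms(fn, warmup=25, rep=100):
    # Hypothetical wrapper: median runtime of fn() in milliseconds.
    return _bench(fn, warmup=warmup, rep=rep, return_mode="median")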
