diff --git a/.github/scripts/bench/bench_op.py b/.github/scripts/bench/bench_op.py
index 4135e3ab8..7bbce06e9 100644
--- a/.github/scripts/bench/bench_op.py
+++ b/.github/scripts/bench/bench_op.py
@@ -15,7 +15,7 @@ def bench_matmul_f16(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(c, inputs=[a, b])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_batch_matmul(params: str, *args, **kwargs) -> float:
     # Default to benchmarking f32 for now, though this op can run other dtypes
@@ -28,7 +28,7 @@ def bench_batch_matmul(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(c, inputs=[a, b])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_conv2d(params: str, *args, **kwargs) -> float:
     x_shape, w_shape = params.split(',')
@@ -40,7 +40,7 @@ def bench_conv2d(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[x, w])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
     x_shape, w_shape = params.split(',')
@@ -52,7 +52,7 @@ def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[x, w])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_attn(params: str, *args, **kwargs) -> float:
     bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')]
@@ -66,7 +66,7 @@ def bench_attn(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[q, k, v])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_attn_mask_add(params: str, *args, **kwargs) -> float:
     bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')]
@@ -82,7 +82,7 @@ def bench_attn_mask_add(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[q, k, v, mask])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_reduce(params: str, *args, **kwargs) -> float:
     x_shape, axis = params.split(',', maxsplit=1)
@@ -95,7 +95,7 @@ def bench_reduce(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[x])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 bench_func_map = {
     'matmul_f16': bench_matmul_f16,
diff --git a/.github/scripts/bench/bench_utils.py b/.github/scripts/bench/bench_utils.py
index 09cf862a8..3921eea7a 100644
--- a/.github/scripts/bench/bench_utils.py
+++ b/.github/scripts/bench/bench_utils.py
@@ -35,9 +35,10 @@ def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10):
     return latency
 
 def enable_compile_server(enable=True):
-    hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME'))
-    hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT')))
-    hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME'))
-    hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD'))
-    hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip())
-    hidet.option.compile_server.enable(flag=enable)
\ No newline at end of file
+    if os.environ.get('CI_CS_HOSTNAME'):
+        hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME'))
+        hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT')))
+        hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME'))
+        hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD'))
+        hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip())
+        hidet.option.compile_server.enable(flag=enable)
\ No newline at end of file
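
Note on the bench_op.py changes: bench_torch_model invokes its first argument as model(*torch_inputs), and the patch wraps the object returned by g.cuda_graph() in a zero-argument lambda at every call site, presumably because the CUDA graph object does not satisfy that callable interface directly. Below is a minimal sketch of the timing contract being relied on, assuming bench_torch_model measures latency with CUDA events; the exact body lives in bench_utils.py and may differ.

    import torch

    def bench_callable(model, torch_inputs, bench_iters=100, warmup_iters=10):
        # Warm up so one-time compilation and caching do not skew the measurement.
        for _ in range(warmup_iters):
            model(*torch_inputs)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(bench_iters):
            model(*torch_inputs)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / bench_iters  # mean latency in ms

    # With an empty input list, the lambda adapts the graph to the expected
    # zero-argument callable, matching the diff's call sites:
    #   latency = bench_callable(lambda: g.run_async(), [])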
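
Note on the bench_utils.py change: guarding the whole block on CI_CS_HOSTNAME means runs outside CI, where the compile-server variables are unset, no longer crash before benchmarking starts. A sketch of the failure mode being avoided, using the same environment variable names as the diff:

    import os

    # Outside CI these variables are unset, so os.environ.get returns None:
    port = os.environ.get('CI_CS_PORT')
    # int(port) would raise TypeError, and
    # os.environ.get('REPO_NAME').strip() would raise AttributeError on None.

    # The guard pattern from the patch: only configure the compile server
    # when the hostname variable is present (CI defines the full set).
    if os.environ.get('CI_CS_HOSTNAME'):
        port = int(os.environ['CI_CS_PORT'])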