
Commit bb52940

xuzhao9 authored and facebook-github-bot committed
Add help for run_benchmark (#2361)
Summary:
Show more help messages for Tritonbench.

```
$ python run_benchmark.py triton --help
usage: run_benchmark.py [-h] [--op OP] [--mode {fwd,bwd,fwd_bwd}] [--bwd] [--fwd_bwd] [--device DEVICE]
                        [--warmup WARMUP] [--iter ITER] [--csv] [--dump-csv] [--skip-print] [--plot] [--ci]
                        [--metrics METRICS] [--only ONLY] [--baseline BASELINE] [--num-inputs NUM_INPUTS]
                        [--keep-going] [--input-id INPUT_ID] [--test-only] [--dump-ir]

options:
  -h, --help            show this help message and exit
  --op OP               Operator to benchmark.
  --mode {fwd,bwd,fwd_bwd}
                        Test mode (fwd, bwd, or fwd_bwd).
  --bwd                 Run backward pass.
  --fwd_bwd             Run both forward and backward pass.
  --device DEVICE       Device to benchmark.
  --warmup WARMUP       Num of warmup runs for each benchmark run.
  --iter ITER           Num of reps for each benchmark run.
  --csv                 Print result as csv.
  --dump-csv            Dump result as csv.
  --skip-print          Skip printing result.
  --plot                Plot the result.
  --ci                  Run in the CI mode.
  --metrics METRICS     Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.
  --only ONLY           Specify one or multiple operator implementations to run.
  --baseline BASELINE   Override default baseline.
  --num-inputs NUM_INPUTS
                        Number of example inputs.
  --keep-going
  --input-id INPUT_ID   Specify the start input id to run. For example, --input-id 0 runs only the first
                        available input sample. When used together like --input-id <X> --num-inputs <Y>,
                        start from the input id <X> and run <Y> different inputs.
  --test-only           Run this under test mode, potentially skipping expensive steps like autotuning.
  --dump-ir             Dump Triton IR
```

```
$ python run_benchmark.py triton --op gemm --num-inputs 1 --only triton_tutorial_matmul
  (M, N, K)      triton_tutorial_matmul-latency
---------------  --------------------------------
(256, 256, 256)                         0.0033702
```

Pull Request resolved: #2361

Reviewed By: jananisriram

Differential Revision: D59374656

Pulled By: xuzhao9

fbshipit-source-id: 139f865895d7550a3475a1a8b4bed037a9ecc769
1 parent 555d05a commit bb52940

File tree

17 files changed: +153 −130 lines changed


run_benchmark.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -24,7 +24,7 @@ def list_benchmarks() -> Dict[str, str]:
 
 def run():
     available_benchmarks = list_benchmarks()
-    parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark")
+    parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark", add_help=False)
     parser.add_argument(
         "bm_name",
         choices=available_benchmarks.keys(),
```
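
The one-line change above is what lets `--help` flow through to the Tritonbench-level parser instead of being swallowed by the top-level one. The following is a minimal, self-contained sketch of that mechanism (illustrative only; the variable names and the simplified `bm_name` argument are not the repository's actual code):

```python
import argparse

# Top-level parser: add_help=False means it defines no -h/--help option,
# so the flag is left over for the benchmark-specific parser to handle.
top = argparse.ArgumentParser(description="Run a TorchBench user benchmark", add_help=False)
top.add_argument("bm_name")

args, extra = top.parse_known_args(["triton", "--help"])
print(args.bm_name)  # triton
print(extra)         # ['--help'] -- forwarded downstream, e.g. to the Tritonbench parser
```

With the default `add_help=True`, the top-level parser would print its own generic help and exit before any benchmark-specific options (such as `--op` or `--metrics`) could ever be shown.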

torchbenchmark/operators/addmm/operator.py

Lines changed: 3 additions & 4 deletions

```diff
@@ -1,6 +1,5 @@
-import csv
 import os
-import statistics
+import argparse
 from typing import Any, Callable, Generator, List, Optional, Tuple
 
 import numpy
@@ -70,8 +69,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "best_config"]
     DEFAULT_PRECISION = "bf16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.k, addmm_args.n)]
```
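
The same constructor change repeats in each operator file below: the framework-level arguments no longer arrive as separate `mode` and `device` parameters but as one parsed `argparse.Namespace` (`tb_args`). Here is a runnable sketch of the pattern, with `BenchmarkOperatorStub` standing in for the real `BenchmarkOperator` base class (whose actual attributes and signature live in `torchbenchmark/util/triton_op.py`):

```python
import argparse
from typing import List, Optional


class BenchmarkOperatorStub:
    """Stand-in for torchbenchmark.util.triton_op.BenchmarkOperator (illustrative only)."""

    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
        # Framework-level settings now travel together in a single Namespace.
        self.mode = tb_args.mode
        self.device = tb_args.device
        self.extra_args = extra_args or []


class Operator(BenchmarkOperatorStub):
    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
        # New convention: pass the Namespace through, rather than unpacking
        # mode/device into keyword arguments as the old signature did.
        super().__init__(tb_args, extra_args)


op = Operator(argparse.Namespace(mode="fwd", device="cuda"), extra_args=["--m", "256"])
print(op.mode, op.device, op.extra_args)  # fwd cuda ['--m', '256']
```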

torchbenchmark/operators/flash_attention/operator.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -107,9 +107,8 @@ def parse_op_args(args: List[str]):
 class Operator(BenchmarkOperator):
     DEFAULT_PRECISION = "bf16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
-        # pass the framework level args (e.g., device, is_training, dtype) to the parent class
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         args = parse_op_args(self.extra_args)
         self.BATCH = args.batch
         self.H = args.n_heads
```

torchbenchmark/operators/fp8_gemm/fp8_gemm.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -7,7 +7,7 @@
 
 from triton.runtime.jit import reinterpret
 
-from typing import Any
+from typing import Any, Optional, List
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
@@ -27,8 +27,8 @@ def parse_args(args):
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
 
-    def __init__(self, mode, device, extra_args):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
     def get_input_iter(self):
```

torchbenchmark/operators/fp8_gemm_blockwise/operator.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -108,8 +108,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
```

torchbenchmark/operators/fp8_gemm_rowwise/operator.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -85,8 +85,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
```

torchbenchmark/operators/gather_gemv/operator.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -4,6 +4,7 @@
 gather + gemv is the primary kernel driving mixtral perf.
 """
 
+import argparse
 import csv
 import os
 import statistics
@@ -38,8 +39,8 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             * 1e-6
         )
 
-    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
 
     @register_benchmark(baseline=True)
     def test_0(self, p1, p2, p3) -> Callable:
```

torchbenchmark/operators/gemm/operator.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -1,3 +1,4 @@
+import argparse
 import csv
 import os
 import statistics
@@ -78,8 +79,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops", "best_config"]
     DEFAULT_PRECISION = "fp16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         gemm_args = parse_args(self.extra_args)
         if gemm_args.input:
             self.shapes = read_shapes_from_csv(gemm_args.input)
```

torchbenchmark/operators/int4_gemm/int4_gemm.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -12,7 +12,7 @@
 import triton.ops
 import triton.language as tl
 
-from typing import Any
+from typing import Any, Optional, List
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
@@ -27,8 +27,8 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency", "best_config"]
 
-    def __init__(self, mode, device, extra_args):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         # `Group size` and `inner K tiles` are defaults from gpt-fast.
         self.group_size = 32
         self.inner_k_tiles = 8
```

torchbenchmark/operators/jagged_mean/operator.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -101,8 +101,8 @@ class Operator(BenchmarkOperator):
         False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
     )
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.sizes = list(range(2, 12, 4)) + list(
             range(12, 23, 3)
         )  # bias towards larger sizes, which are more representative of real-world shapes
```
