diff --git a/run_benchmark.py b/run_benchmark.py index 63bad54040..62c4894894 100644 --- a/run_benchmark.py +++ b/run_benchmark.py @@ -24,7 +24,7 @@ def list_benchmarks() -> Dict[str, str]: def run(): available_benchmarks = list_benchmarks() - parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark") + parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark", add_help=False) parser.add_argument( "bm_name", choices=available_benchmarks.keys(), diff --git a/torchbenchmark/operators/addmm/operator.py b/torchbenchmark/operators/addmm/operator.py index 0e80f10bca..2c775139c3 100644 --- a/torchbenchmark/operators/addmm/operator.py +++ b/torchbenchmark/operators/addmm/operator.py @@ -1,6 +1,5 @@ -import csv import os -import statistics +import argparse from typing import Any, Callable, Generator, List, Optional, Tuple import numpy @@ -70,8 +69,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["tflops", "best_config"] DEFAULT_PRECISION = "bf16" - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) addmm_args = parse_args(self.extra_args) if addmm_args.m and addmm_args.n and addmm_args.k: self.shapes = [(addmm_args.m, addmm_args.k, addmm_args.n)] diff --git a/torchbenchmark/operators/flash_attention/operator.py b/torchbenchmark/operators/flash_attention/operator.py index 0d017ec786..33299564ce 100644 --- a/torchbenchmark/operators/flash_attention/operator.py +++ b/torchbenchmark/operators/flash_attention/operator.py @@ -107,9 +107,8 @@ def parse_op_args(args: List[str]): class Operator(BenchmarkOperator): DEFAULT_PRECISION = "bf16" - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None): - # pass the framework level args (e.g., device, is_training, dtype) to the parent class - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) args = parse_op_args(self.extra_args) self.BATCH = args.batch self.H = args.n_heads diff --git a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py index 5fc29e657b..f25fb31490 100644 --- a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py +++ b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py @@ -7,7 +7,7 @@ from triton.runtime.jit import reinterpret -from typing import Any +from typing import Any, Optional, List from torchbenchmark.util.triton_op import ( BenchmarkOperator, @@ -27,8 +27,8 @@ def parse_args(args): class Operator(BenchmarkOperator): DEFAULT_METRICS = ["tflops", "gbps", "latency"] - def __init__(self, mode, device, extra_args): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) self.extra_args = parse_args(extra_args) def get_input_iter(self): diff --git a/torchbenchmark/operators/fp8_gemm_blockwise/operator.py b/torchbenchmark/operators/fp8_gemm_blockwise/operator.py index 75a7da2861..a1a46219d1 100644 --- a/torchbenchmark/operators/fp8_gemm_blockwise/operator.py +++ b/torchbenchmark/operators/fp8_gemm_blockwise/operator.py @@ -108,8 +108,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["tflops", "speedup", "accuracy"] DEFAULT_PRECISION = "fp32" - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) addmm_args = parse_args(self.extra_args) if addmm_args.m and addmm_args.n and addmm_args.k: self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)] diff --git a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py index 9f7ed10258..d0bba2632e 100644 --- a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py +++ b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py @@ -85,8 +85,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["tflops", "speedup", "accuracy"] DEFAULT_PRECISION = "fp32" - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) addmm_args = parse_args(self.extra_args) if addmm_args.m and addmm_args.n and addmm_args.k: self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)] diff --git a/torchbenchmark/operators/gather_gemv/operator.py b/torchbenchmark/operators/gather_gemv/operator.py index f665e86a6e..3a8d145041 100644 --- a/torchbenchmark/operators/gather_gemv/operator.py +++ b/torchbenchmark/operators/gather_gemv/operator.py @@ -4,6 +4,7 @@ gather + gemv is the primary kernel driving mixtral perf. """ +import argparse import csv import os import statistics @@ -38,8 +39,8 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): * 1e-6 ) - def __init__(self, mode: str, device: str, extra_args: List[str] = []): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) @register_benchmark(baseline=True) def test_0(self, p1, p2, p3) -> Callable: diff --git a/torchbenchmark/operators/gemm/operator.py b/torchbenchmark/operators/gemm/operator.py index 31e904e409..3aa7fdfe9d 100644 --- a/torchbenchmark/operators/gemm/operator.py +++ b/torchbenchmark/operators/gemm/operator.py @@ -1,3 +1,4 @@ +import argparse import csv import os import statistics @@ -78,8 +79,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops", "best_config"] DEFAULT_PRECISION = "fp16" - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) gemm_args = parse_args(self.extra_args) if gemm_args.input: self.shapes = read_shapes_from_csv(gemm_args.input) diff --git a/torchbenchmark/operators/int4_gemm/int4_gemm.py b/torchbenchmark/operators/int4_gemm/int4_gemm.py index 6390300482..8864de6aef 100644 --- a/torchbenchmark/operators/int4_gemm/int4_gemm.py +++ b/torchbenchmark/operators/int4_gemm/int4_gemm.py @@ -12,7 +12,7 @@ import triton.ops import triton.language as tl -from typing import Any +from typing import Any, Optional, List from torchbenchmark.util.triton_op import ( BenchmarkOperator, @@ -27,8 +27,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["tflops", "gbps", "latency", "best_config"] - def __init__(self, mode, device, extra_args): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) # `Group size` and `inner K tiles` are defaults from gpt-fast. self.group_size = 32 self.inner_k_tiles = 8 diff --git a/torchbenchmark/operators/jagged_mean/operator.py b/torchbenchmark/operators/jagged_mean/operator.py index 75856d94ec..374d1fe525 100644 --- a/torchbenchmark/operators/jagged_mean/operator.py +++ b/torchbenchmark/operators/jagged_mean/operator.py @@ -101,8 +101,8 @@ class Operator(BenchmarkOperator): False # enables GPU/CPU sync (for methods like NestedTensor unbind) ) - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) self.sizes = list(range(2, 12, 4)) + list( range(12, 23, 3) ) # bias towards larger sizes, which are more representative of real-world shapes diff --git a/torchbenchmark/operators/jagged_sum/operator.py b/torchbenchmark/operators/jagged_sum/operator.py index 1be5fff261..78120888ec 100644 --- a/torchbenchmark/operators/jagged_sum/operator.py +++ b/torchbenchmark/operators/jagged_sum/operator.py @@ -125,8 +125,8 @@ class Operator(BenchmarkOperator): False # enables GPU/CPU sync (for methods like NestedTensor unbind) ) - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) self.sizes = list(range(2, 12, 4)) + list( range(12, 23, 3) ) # bias towards larger sizes, which are more representative of real-world shapes diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py index 1ccf2a0181..6e06111fe1 100644 --- a/torchbenchmark/operators/sum/operator.py +++ b/torchbenchmark/operators/sum/operator.py @@ -151,8 +151,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "accuracy", "best_config"] - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) args = parse_op_args(self.extra_args) self.input_dim = args.input_dim self.reduce_dim = args.reduce_dim diff --git a/torchbenchmark/operators/template_attention/operator.py b/torchbenchmark/operators/template_attention/operator.py index bc08dd8915..6ce1e1b891 100644 --- a/torchbenchmark/operators/template_attention/operator.py +++ b/torchbenchmark/operators/template_attention/operator.py @@ -1,4 +1,4 @@ - +import argparse import csv import os import statistics @@ -29,8 +29,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "speedup", "accuracy"] - def __init__(self, mode: str, device: str, extra_args: List[str] = []): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) self.shapes = BUILDIN_SHAPES @register_benchmark(baseline=True) diff --git a/torchbenchmark/operators/test_op/operator.py b/torchbenchmark/operators/test_op/operator.py index bf149aaeb8..db41438a42 100644 --- a/torchbenchmark/operators/test_op/operator.py +++ b/torchbenchmark/operators/test_op/operator.py @@ -1,3 +1,4 @@ +import argparse from typing import Generator, List, Optional import torch @@ -14,8 +15,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["test_metric"] - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) @register_benchmark(label="new_op_label") def test_op(self, x: torch.Tensor): diff --git a/torchbenchmark/operators/welford/operator.py b/torchbenchmark/operators/welford/operator.py index d4877fe129..d5f2906ef5 100644 --- a/torchbenchmark/operators/welford/operator.py +++ b/torchbenchmark/operators/welford/operator.py @@ -1,4 +1,4 @@ - +import argparse import csv import os import statistics @@ -38,8 +38,8 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "speedup", "accuracy"] - def __init__(self, mode: str, device: str, extra_args: List[str] = []): - super().__init__(mode=mode, device=device, extra_args=extra_args) + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None): + super().__init__(tb_args, extra_args) self.shapes = BUILDIN_SHAPES @register_benchmark() diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py index 2e37849972..b180c51c71 100644 --- a/torchbenchmark/util/triton_op.py +++ b/torchbenchmark/util/triton_op.py @@ -64,7 +64,7 @@ "fp16": torch.float16, "bf16": torch.bfloat16, } - +_RANGE_NAME = "tritonbench_range" class Mode(Enum): FWD = "fwd" @@ -380,7 +380,6 @@ def _inner(self, *args, **kwargs): return decorator - def register_metric( # Metrics that only apply to non-baseline impls # E.g., accuracy, speedup @@ -408,100 +407,47 @@ def _inner(self, *args, **kwargs): return decorator - -def parse_args( - default_metrics: List[str], - args: List[str], -) -> Tuple[argparse.Namespace, List[str]]: - parser = argparse.ArgumentParser(allow_abbrev=False) - parser.add_argument( - "--metrics", - default=",".join(default_metrics), - help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.", - ) - parser.add_argument( - "--only", - default=None, - help="Specify one or multiple operator implementations to run." - ) - parser.add_argument( - "--baseline", - type=str, - default=None, - help="Override default baseline." - ) - parser.add_argument( - "--num-inputs", - type=int, - help="Number of example inputs.", - ) - parser.add_argument( - "--keep-going", - action="store_true", - ) - parser.add_argument( - "--input-id", - type=int, - default=0, - help="Specify the start input id to run. " \ - "For example, --input-id 0 runs only the first available input sample." \ - "When used together like --input-id --num-inputs , start from the input id " \ - "and run different inputs." - ) - parser.add_argument( - "--test-only", - action="store_true", - help="Run this under test mode, potentially skipping expensive steps like autotuning." - ) - parser.add_argument( - "--dump-ir", - action="store_true", - help="Dump Triton IR", - ) - return parser.parse_known_args(args) - class PostInitProcessor(type): def __call__(cls, *args, **kwargs): obj = type.__call__(cls, *args, **kwargs) obj.__post__init__() return obj - -_RANGE_NAME = "tritonbench_range" - - class BenchmarkOperator(metaclass=PostInitProcessor): mode: Mode = Mode.FWD test: str = "eval" device: str = "cuda" + # By default, only collect latency metrics + # Each operator can override to define their own default metrics + DEFAULT_METRICS = ["latency"] + required_metrics: List[str] _input_iter: Optional[Generator] = None extra_args: List[str] = [] example_inputs: Any = None use_cuda_graphs: bool = True - # By default, only collect latency metrics - # Each operator can override to define their own default metrics - DEFAULT_METRICS = ["latency"] - """ A base class for adding operators to torch benchmark. """ - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None): + def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]]=None): set_random_seed() self.name = _find_op_name_from_module_path(self.__class__.__module__) self._raw_extra_args = copy.deepcopy(extra_args) + self.tb_args = tb_args # we accept both "fwd" and "eval" - if mode == "fwd": + if self.tb_args.mode == "fwd": self.mode = Mode.FWD - elif mode == "fwd_bwd": + elif self.tb_args.mode == "fwd_bwd": self.mode = Mode.FWD_BWD else: assert ( - mode == "bwd" + self.tb_args.mode == "bwd" ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd." self.mode = Mode.BWD - self.dargs, unprocessed_args = parse_decoration_args(self, extra_args) + self.device = tb_args.device + self.required_metrics = list(set(tb_args.metrics.split(","))) if tb_args.metrics else self.DEFAULT_METRICS + self.dargs, self.extra_args = parse_decoration_args(self, extra_args) if self.name not in REGISTERED_X_VALS: REGISTERED_X_VALS[self.name] = "x_val" # This will be changed by the time we apply the decoration args @@ -510,17 +456,11 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None) [x for x in REGISTERED_METRICS.get(self.name, []) if x not in BUILTIN_METRICS] ) self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS)) - self.tb_args, self.extra_args = parse_args( - self.DEFAULT_METRICS, - unprocessed_args - ) if self.tb_args.baseline: BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline - self.required_metrics = list(set(self.tb_args.metrics.split(","))) self._only = _split_params_by_comma(self.tb_args.only) self._input_id = self.tb_args.input_id self._num_inputs = self.tb_args.num_inputs - self.device = device # Run the post initialization def __post__init__(self): diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py index e1581c37ba..f7e5185510 100644 --- a/userbenchmark/triton/run.py +++ b/userbenchmark/triton/run.py @@ -8,49 +8,132 @@ from torchbenchmark.operators import load_opbench_by_name from torchbenchmark.util.triton_op import ( - BenchmarkOperatorResult, DEFAULT_RUN_ITERS, DEFAULT_WARMUP, ) -if not hasattr(torch.version, "git_version"): - from pytorch.benchmark.fb.run_utils import usage_report_logger -else: + +try: + if not hasattr(torch.version, "git_version"): + from pytorch.benchmark.fb.run_utils import usage_report_logger + else: + usage_report_logger = lambda *args, **kwargs: None +except ImportError: usage_report_logger = lambda *args, **kwargs: None TRITON_BENCH_CSV_DUMP_PATH = tempfile.gettempdir() + "/tritonbench/" - -def parse_args(args): +def get_parser(): parser = argparse.ArgumentParser(allow_abbrev=False) - parser.add_argument("--op", type=str, default=None, help="Operator to benchmark.") + parser.add_argument( + "--op", + type=str, + required=True, + help="Operator to benchmark." + ) parser.add_argument( "--mode", choices=["fwd", "bwd", "fwd_bwd"], default="fwd", help="Test mode (fwd, bwd, or fwd_bwd).", ) - parser.add_argument("--bwd", action="store_true", help="Run backward pass.") parser.add_argument( - "--fwd_bwd", action="store_true", help="Run both forward and backward pass." + "--bwd", + action="store_true", + help="Run backward pass." + ) + parser.add_argument( + "--fwd_bwd", + action="store_true", + help="Run both forward and backward pass.", + ) + parser.add_argument( + "--device", + default="cuda", + help="Device to benchmark.", ) - parser.add_argument("--device", default="cuda", help="Device to benchmark.") parser.add_argument( "--warmup", default=DEFAULT_WARMUP, help="Num of warmup runs for reach benchmark run.", ) parser.add_argument( - "--iter", default=DEFAULT_RUN_ITERS, help="Num of reps for each benchmark run." + "--iter", + default=DEFAULT_RUN_ITERS, + help="Num of reps for each benchmark run.", + ) + parser.add_argument( + "--csv", + action="store_true", + help="Print result as csv.", + ) + parser.add_argument( + "--dump-csv", + action="store_true", + help="Dump result as csv.", + ) + parser.add_argument( + "--skip-print", + action="store_true", + help="Skip printing result.", + ) + parser.add_argument( + "--plot", + action="store_true", + help="Plot the result.", + ) + parser.add_argument( + "--ci", + action="store_true", + help="Run in the CI mode.", + ) + parser.add_argument( + "--metrics", + default=None, + help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.", + ) + parser.add_argument( + "--only", + default=None, + help="Specify one or multiple operator implementations to run." + ) + parser.add_argument( + "--baseline", + type=str, + default=None, + help="Override default baseline." + ) + parser.add_argument( + "--num-inputs", + type=int, + help="Number of example inputs.", + ) + parser.add_argument( + "--keep-going", + action="store_true", + ) + parser.add_argument( + "--input-id", + type=int, + default=0, + help="Specify the start input id to run. " \ + "For example, --input-id 0 runs only the first available input sample." \ + "When used together like --input-id --num-inputs , start from the input id " \ + "and run different inputs." + ) + parser.add_argument( + "--test-only", + action="store_true", + help="Run this under test mode, potentially skipping expensive steps like autotuning." + ) + parser.add_argument( + "--dump-ir", + action="store_true", + help="Dump Triton IR", ) - parser.add_argument("--csv", action="store_true", help="Print result as csv.") - parser.add_argument("--dump-csv", action="store_true", help="Dump result as csv.") - parser.add_argument("--skip-print", action="store_true", help="Skip printing result.") - parser.add_argument("--plot", action="store_true", help="Plot the result.") if not hasattr(torch_version, "git_version"): parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.") - parser.add_argument("--ci", action="store_true", help="Run in the CI mode.") - return parser.parse_known_args(args) + return parser def _run(args: argparse.Namespace, extra_args: List[str]) -> None: Opbench = load_opbench_by_name(args.op) @@ -59,8 +142,7 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None: if args.bwd: args.mode = "bwd" opbench = Opbench( - mode=args.mode, - device=args.device, + tb_args=args, extra_args=extra_args, ) try: @@ -74,7 +156,6 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None: print(metrics) if not hasattr(torch_version, "git_version") and args.log_scuba: from userbenchmark.triton.fb import log_benchmark - log_benchmark(metrics) if args.plot: try: @@ -92,7 +173,8 @@ def run(args: List[str] = []): args = sys.argv[1:] # Log the tool usage usage_report_logger(benchmark_name="tritonbench") - args, extra_args = parse_args(args) + parser = get_parser() + args, extra_args = parser.parse_known_args(args) if args.ci: from .ci import run_ci run_ci()