Add help for run_benchmark #2361

Closed · wants to merge 8 commits
run_benchmark.py (1 addition, 1 deletion)

@@ -24,7 +24,7 @@ def list_benchmarks() -> Dict[str, str]:
 
 def run():
     available_benchmarks = list_benchmarks()
-    parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark")
+    parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark", add_help=False)
     parser.add_argument(
         "bm_name",
         choices=available_benchmarks.keys(),
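For context on why `add_help=False` gives `run_benchmark.py` working help: with automatic help disabled on the wrapper, `-h` is no longer consumed there (argparse would otherwise print the wrapper's usage and exit immediately), so the flag survives `parse_known_args` and can reach the selected benchmark's own parser. A minimal, self-contained sketch of that pattern — the benchmark names and the forwarding step are hypothetical, not taken from this PR:

```python
import argparse

def run():
    # add_help=False: the wrapper must not hijack -h/--help, so the flag can
    # be forwarded to the benchmark the user actually asked about.
    parser = argparse.ArgumentParser(
        description="Run a TorchBench user benchmark", add_help=False
    )
    parser.add_argument("bm_name", choices=["triton", "test_bench"])  # hypothetical names
    args, rest = parser.parse_known_args()
    # `rest` still contains -h at this point, so the benchmark's own parser
    # gets to print benchmark-specific help instead of the wrapper's.
    print(f"would forward {rest!r} to benchmark {args.bm_name!r}")

if __name__ == "__main__":
    run()  # e.g. `python run_benchmark.py triton -h` -> rest == ['-h']
```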
torchbenchmark/operators/addmm/operator.py (3 additions, 4 deletions)

@@ -1,6 +1,5 @@
-import csv
 import os
-import statistics
+import argparse
 from typing import Any, Callable, Generator, List, Optional, Tuple
 
 import numpy
@@ -71,8 +70,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops"]
     DEFAULT_PRECISION = "bf16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.k, addmm_args.n)]
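The same two-line constructor change repeats in every operator file below: instead of separate `mode` and `device` keywords, each operator now takes a single framework-level `argparse.Namespace` first. A minimal sketch of instantiation under the new signature — the Namespace fields are the ones this PR's `triton_op.py` reads off `tb_args`, and the `--m/--k/--n` leftovers are inferred from addmm's own `parse_args`, so treat both as assumptions:

```python
import argparse

# Hypothetical call site: the harness builds the shared namespace once...
tb_args = argparse.Namespace(
    mode="fwd",            # consumed by BenchmarkOperator.__init__
    device="cuda",
    metrics="tflops",      # comma-separated; empty falls back to DEFAULT_METRICS
    baseline=None,
    only=None,
    input_id=0,
    num_inputs=None,
)
# ...and hands operator-specific flags through untouched:
# op = Operator(tb_args, extra_args=["--m", "1024", "--k", "512", "--n", "2048"])
```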
torchbenchmark/operators/flash_attention/operator.py (2 additions, 3 deletions)

@@ -107,9 +107,8 @@ def parse_op_args(args: List[str]):
 class Operator(BenchmarkOperator):
     DEFAULT_PRECISION = "bf16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
-        # pass the framework level args (e.g., device, is_training, dtype) to the parent class
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         args = parse_op_args(self.extra_args)
         self.BATCH = args.batch
         self.H = args.n_heads
torchbenchmark/operators/fp8_gemm/fp8_gemm.py (3 additions, 3 deletions)

@@ -7,7 +7,7 @@
 
 from triton.runtime.jit import reinterpret
 
-from typing import Any
+from typing import Any, Optional, List
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
@@ -27,8 +27,8 @@ def parse_args(args):
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
 
-    def __init__(self, mode, device, extra_args):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
     def get_input_iter(self):
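One caveat with this file (and int4_gemm below): the new signature annotates `tb_args` as `argparse.Namespace`, but the visible hunks only extend the `typing` import. Unless these modules already import `argparse` somewhere outside the shown context, the annotation would raise a `NameError` at class-definition time, so an `import argparse` is presumably needed alongside the typing change.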
torchbenchmark/operators/fp8_gemm_blockwise/operator.py (2 additions, 2 deletions)

@@ -108,8 +108,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
torchbenchmark/operators/fp8_gemm_rowwise/operator.py (2 additions, 2 deletions)

@@ -85,8 +85,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
torchbenchmark/operators/gather_gemv/operator.py (3 additions, 2 deletions)

@@ -4,6 +4,7 @@
 gather + gemv is the primary kernel driving mixtral perf.
 """
 
+import argparse
 import csv
 import os
 import statistics
@@ -38,8 +39,8 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             * 1e-6
         )
 
-    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
 
     @register_benchmark(baseline=True)
     def test_0(self, p1, p2, p3) -> Callable:
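Besides the `tb_args` migration, this hunk (and the template_attention and welford hunks below) also replaces the mutable default `extra_args: List[str] = []` with `Optional[List[str]] = None`. A quick illustration of why mutable defaults are avoided — the default list is created once, at function definition time, and shared across all calls that omit the argument:

```python
from typing import List, Optional

def bad(extra_args: List[str] = []):
    # The same list object is reused on every call that omits the argument.
    extra_args.append("--flag")
    return extra_args

def good(extra_args: Optional[List[str]] = None):
    extra_args = [] if extra_args is None else extra_args
    extra_args.append("--flag")
    return extra_args

print(bad())   # ['--flag']
print(bad())   # ['--flag', '--flag']  <- state leaked across calls
print(good())  # ['--flag']
print(good())  # ['--flag']            <- fresh list every call
```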
torchbenchmark/operators/gemm/operator.py (3 additions, 2 deletions)

@@ -1,3 +1,4 @@
+import argparse
 import csv
 import os
 import statistics
@@ -79,8 +80,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops"]
     DEFAULT_PRECISION = "fp16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         gemm_args = parse_args(self.extra_args)
         if gemm_args.input:
             self.shapes = read_shapes_from_csv(gemm_args.input)
torchbenchmark/operators/int4_gemm/int4_gemm.py (3 additions, 3 deletions)

@@ -12,7 +12,7 @@
 import triton.ops
 import triton.language as tl
 
-from typing import Any
+from typing import Any, Optional, List
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
@@ -27,8 +27,8 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
 
-    def __init__(self, mode, device, extra_args):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         # `Group size` and `inner K tiles` are defaults from gpt-fast.
         self.group_size = 32
         self.inner_k_tiles = 8
torchbenchmark/operators/jagged_mean/operator.py (2 additions, 2 deletions)

@@ -101,8 +101,8 @@ class Operator(BenchmarkOperator):
         False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
     )
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.sizes = list(range(2, 12, 4)) + list(
             range(12, 23, 3)
         )  # bias towards larger sizes, which are more representative of real-world shapes
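For reference, the `sizes` expression kept as context here (jagged_sum below uses the identical expression) evaluates to a list whose sampling gets denser above 12, which is what the inline comment means by biasing toward larger sizes:

```python
sizes = list(range(2, 12, 4)) + list(range(12, 23, 3))
assert sizes == [2, 6, 10, 12, 15, 18, 21]  # finer step above 12 -> more large sizes
```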
torchbenchmark/operators/jagged_sum/operator.py (2 additions, 2 deletions)

@@ -126,8 +126,8 @@ class Operator(BenchmarkOperator):
         False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
     )
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.sizes = list(range(2, 12, 4)) + list(
             range(12, 23, 3)
         )  # bias towards larger sizes, which are more representative of real-world shapes
torchbenchmark/operators/sum/operator.py (2 additions, 2 deletions)

@@ -152,8 +152,8 @@ class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["latency", "accuracy"]
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         args = parse_op_args(self.extra_args)
         self.input_dim = args.input_dim
         self.reduce_dim = args.reduce_dim
torchbenchmark/operators/template_attention/operator.py (3 additions, 3 deletions)

@@ -1,4 +1,4 @@
-
+import argparse
 import csv
 import os
 import statistics
@@ -29,8 +29,8 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
 
-    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.shapes = BUILDIN_SHAPES
 
     @register_benchmark(baseline=True)
torchbenchmark/operators/test_op/operator.py (3 additions, 2 deletions)

@@ -1,3 +1,4 @@
+import argparse
 from typing import Generator, List, Optional
 
 import torch
@@ -14,8 +15,8 @@ class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["test_metric"]
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
 
     @register_benchmark(label="new_op_label")
     def test_op(self, x: torch.Tensor):
torchbenchmark/operators/welford/operator.py (3 additions, 3 deletions)

@@ -1,4 +1,4 @@
-
+import argparse
 import csv
 import os
 import statistics
@@ -38,8 +38,8 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
 
-    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.shapes = BUILDIN_SHAPES
 
     @register_benchmark()
torchbenchmark/util/triton_op.py (13 additions, 73 deletions)

@@ -63,7 +63,7 @@
     "fp16": torch.float16,
     "bf16": torch.bfloat16,
 }
-
+_RANGE_NAME = "tritonbench_range"
 
 class Mode(Enum):
     FWD = "fwd"
@@ -388,7 +388,6 @@ def _inner(self, *args, **kwargs):
 
     return decorator
 
-
 def register_metric(
     # Metrics that only apply to non-baseline impls
    # E.g., accuracy, speedup
@@ -416,100 +415,47 @@ def _inner(self, *args, **kwargs):
 
     return decorator
 
-
-def parse_args(
-    default_metrics: List[str],
-    args: List[str],
-) -> Tuple[argparse.Namespace, List[str]]:
-    parser = argparse.ArgumentParser(allow_abbrev=False)
-    parser.add_argument(
-        "--metrics",
-        default=",".join(default_metrics),
-        help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.",
-    )
-    parser.add_argument(
-        "--only",
-        default=None,
-        help="Specify one or multiple operator implementations to run."
-    )
-    parser.add_argument(
-        "--baseline",
-        type=str,
-        default=None,
-        help="Override default baseline."
-    )
-    parser.add_argument(
-        "--num-inputs",
-        type=int,
-        help="Number of example inputs.",
-    )
-    parser.add_argument(
-        "--keep-going",
-        action="store_true",
-    )
-    parser.add_argument(
-        "--input-id",
-        type=int,
-        default=0,
-        help="Specify the start input id to run. " \
-             "For example, --input-id 0 runs only the first available input sample." \
-             "When used together like --input-id <X> --num-inputs <Y>, start from the input id <X> " \
-             "and run <Y> different inputs."
-    )
-    parser.add_argument(
-        "--test-only",
-        action="store_true",
-        help="Run this under test mode, potentially skipping expensive steps like autotuning."
-    )
-    parser.add_argument(
-        "--dump-ir",
-        action="store_true",
-        help="Dump Triton IR",
-    )
-    return parser.parse_known_args(args)
-
 class PostInitProcessor(type):
     def __call__(cls, *args, **kwargs):
         obj = type.__call__(cls, *args, **kwargs)
         obj.__post__init__()
         return obj
 
-
-_RANGE_NAME = "tritonbench_range"
-
-
 class BenchmarkOperator(metaclass=PostInitProcessor):
     mode: Mode = Mode.FWD
     test: str = "eval"
     device: str = "cuda"
+    # By default, only collect latency metrics
+    # Each operator can override to define their own default metrics
+    DEFAULT_METRICS = ["latency"]
+    required_metrics: List[str] = DEFAULT_METRICS
     _input_iter: Optional[Generator] = None
     extra_args: List[str] = []
     example_inputs: Any = None
     use_cuda_graphs: bool = True
 
-    # By default, only collect latency metrics
-    # Each operator can override to define their own default metrics
-    DEFAULT_METRICS = ["latency"]
-
     """
     A base class for adding operators to torch benchmark.
     """
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]]=None):
         set_random_seed()
         self.name = _find_op_name_from_module_path(self.__class__.__module__)
         self._raw_extra_args = copy.deepcopy(extra_args)
+        self.tb_args = tb_args
         # we accept both "fwd" and "eval"
-        if mode == "fwd":
+        if self.tb_args.mode == "fwd":
             self.mode = Mode.FWD
-        elif mode == "fwd_bwd":
+        elif self.tb_args.mode == "fwd_bwd":
             self.mode = Mode.FWD_BWD
         else:
             assert (
-                mode == "bwd"
+                self.tb_args.mode == "bwd"
             ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
             self.mode = Mode.BWD
-        self.dargs, unprocessed_args = parse_decoration_args(self, extra_args)
+        self.device = tb_args.device
+        self.required_metrics = list(set(tb_args.metrics.split(","))) if tb_args.metrics else self.required_metrics
+        self.dargs, self.extra_args = parse_decoration_args(self, extra_args)
         if self.name not in REGISTERED_X_VALS:
             REGISTERED_X_VALS[self.name] = "x_val"
         # This will be changed by the time we apply the decoration args
@@ -518,17 +464,11 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
             [x for x in REGISTERED_METRICS.get(self.name, []) if x not in BUILTIN_METRICS]
         )
         self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
-        self.tb_args, self.extra_args = parse_args(
-            self.DEFAULT_METRICS,
-            unprocessed_args
-        )
         if self.tb_args.baseline:
             BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline
-        self.required_metrics = list(set(self.tb_args.metrics.split(",")))
         self._only = _split_params_by_comma(self.tb_args.only)
         self._input_id = self.tb_args.input_id
         self._num_inputs = self.tb_args.num_inputs
-        self.device = device
 
     # Run the post initialization
     def __post__init__(self):
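The net effect of this file's changes: `BenchmarkOperator` no longer owns a `parse_args` step and instead expects a pre-parsed, framework-level Namespace. A sketch of what the calling side must now provide — the flag definitions below mirror the deleted `parse_args` plus `--mode` and `--device` (which the new `__init__` reads from `tb_args`), but where exactly the harness defines them is not shown in this diff, so the parser itself is an assumption:

```python
import argparse

def make_framework_parser() -> argparse.ArgumentParser:
    # Assumed harness-side parser; the fields are exactly those the new
    # BenchmarkOperator.__init__ reads: mode, device, metrics, baseline,
    # only, input_id, num_inputs. (--keep-going, --test-only, --dump-ir from
    # the deleted parse_args would presumably also live here.)
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--mode", default="fwd", choices=["fwd", "fwd_bwd", "bwd"])
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--metrics", default=None,
                        help="Comma-separated, e.g. latency,tflops,speedup")
    parser.add_argument("--baseline", default=None)
    parser.add_argument("--only", default=None)
    parser.add_argument("--input-id", type=int, default=0)
    parser.add_argument("--num-inputs", type=int)
    return parser

tb_args, extra_args = make_framework_parser().parse_known_args(
    ["--mode", "fwd", "--metrics", "latency,tflops", "--m", "1024"]
)
# extra_args == ["--m", "1024"]: the leftovers each operator's own
# parse_args consumes, e.g. op = SomeOperator(tb_args, extra_args)
```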