From af9c15958fe34ed2fa21fa2b694f3bd131c5f982 Mon Sep 17 00:00:00 2001
From: Xu Zhao <xzhao9@meta.com>
Date: Wed, 3 Jul 2024 19:10:37 -0400
Subject: [PATCH 1/8] Disable run_benchmark's built-in -h/--help handling

---
 run_benchmark.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/run_benchmark.py b/run_benchmark.py
index 63bad54040..62c4894894 100644
--- a/run_benchmark.py
+++ b/run_benchmark.py
@@ -24,7 +24,7 @@ def list_benchmarks() -> Dict[str, str]:
 
 def run():
     available_benchmarks = list_benchmarks()
-    parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark")
+    parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark", add_help=False)
     parser.add_argument(
         "bm_name",
         choices=available_benchmarks.keys(),

From c45919bba74073b85c1d76bc38408e38ab8415a4 Mon Sep 17 00:00:00 2001
From: Xu Zhao <xzhao9@meta.com>
Date: Wed, 3 Jul 2024 19:31:53 -0400
Subject: [PATCH 2/8] Add help to runners

---
 torchbenchmark/util/triton_op.py | 10 +++++-----
 userbenchmark/triton/run.py      | 13 +++++++++----
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index b92568ce22..8ffb60cdfc 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -417,11 +417,10 @@ def _inner(self, *args, **kwargs):
     return decorator
 
 
-def parse_args(
+def get_tbargs_parser(
     default_metrics: List[str],
-    args: List[str],
 ) -> Tuple[argparse.Namespace, List[str]]:
-    parser = argparse.ArgumentParser(allow_abbrev=False)
+    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
     parser.add_argument(
         "--metrics",
         default=",".join(default_metrics),
@@ -466,7 +465,7 @@ def parse_args(
         action="store_true",
         help="Dump Triton IR",
     )
-    return parser.parse_known_args(args)
+    return parser
 
 class PostInitProcessor(type):
     def __call__(cls, *args, **kwargs):
@@ -518,7 +517,8 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None)
             [x for x in REGISTERED_METRICS.get(self.name, []) if x not in BUILTIN_METRICS]
         )
         self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
-        self.tb_args, self.extra_args = parse_args(
+        tb_parser = get_tbargs_parser()
+        self.tb_args, self.extra_args = tb_parser.parse_known_args(
             self.DEFAULT_METRICS,
             unprocessed_args
         )
diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py
index 0a4f60e87c..4906ba1cf0 100644
--- a/userbenchmark/triton/run.py
+++ b/userbenchmark/triton/run.py
@@ -7,7 +7,7 @@
 from torchbenchmark.operators import load_opbench_by_name
 
 from torchbenchmark.util.triton_op import (
-    BenchmarkOperatorResult,
+    get_tbargs_parser,
     DEFAULT_RUN_ITERS,
     DEFAULT_WARMUP,
 )
@@ -15,8 +15,8 @@
 TRITON_BENCH_CSV_DUMP_PATH = tempfile.gettempdir() + "/tritonbench/"
 
 
-def parse_args(args):
-    parser = argparse.ArgumentParser(allow_abbrev=False)
+def get_parser(args):
+    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
     parser.add_argument("--op", type=str, default=None, help="Operator to benchmark.")
     parser.add_argument(
         "--mode",
@@ -84,7 +84,12 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
 def run(args: List[str] = []):
     if args == []:
         args = sys.argv[1:]
-    args, extra_args = parse_args(args)
+    parser = get_parser(args)
+    tb_parser = get_tbargs_parser(default_metrics=[])
+    args, extra_args = parser.parse_known_args(args)
+    if "--help" in extra_args or "-h" in extra_args:
+        parser.print_help()
+        tb_parser.print_help()
     if args.ci:
         from .ci import run_ci
         run_ci()

From 9c4ba0123c50f7fa09319ac7d63ef5321a5c0220 Mon Sep 17 00:00:00 2001
From: Xu Zhao <xzhao9@meta.com>
Date: Wed, 3 Jul 2024 19:39:25 -0400
Subject: [PATCH 3/8] Fold tritonbench args into run.py's parser so --help lists them

---
 torchbenchmark/util/triton_op.py | 76 +++-----------------------
 userbenchmark/triton/run.py      | 93 ++++++++++++++++++++++++++------
 2 files changed, 85 insertions(+), 84 deletions(-)

diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index 8ffb60cdfc..df6e330ca5 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -63,7 +63,7 @@
     "fp16": torch.float16,
     "bf16": torch.bfloat16,
 }
-
+_RANGE_NAME = "tritonbench_range"
 
 class Mode(Enum):
     FWD = "fwd"
@@ -388,7 +388,6 @@ def _inner(self, *args, **kwargs):
 
     return decorator
 
-
 def register_metric(
     # Metrics that only apply to non-baseline impls
     # E.g., accuracy, speedup
@@ -416,67 +415,12 @@ def _inner(self, *args, **kwargs):
 
     return decorator
 
-
-def get_tbargs_parser(
-    default_metrics: List[str],
-) -> Tuple[argparse.Namespace, List[str]]:
-    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
-    parser.add_argument(
-        "--metrics",
-        default=",".join(default_metrics),
-        help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.",
-    )
-    parser.add_argument(
-        "--only",
-        default=None,
-        help="Specify one or multiple operator implementations to run."
-    )
-    parser.add_argument(
-        "--baseline",
-        type=str,
-        default=None,
-        help="Override default baseline."
-    )
-    parser.add_argument(
-        "--num-inputs",
-        type=int,
-        help="Number of example inputs.",
-    )
-    parser.add_argument(
-        "--keep-going",
-        action="store_true",
-    )
-    parser.add_argument(
-        "--input-id",
-        type=int,
-        default=0,
-        help="Specify the start input id to run. " \
-            "For example, --input-id 0 runs only the first available input sample." \
-            "When used together like --input-id <X> --num-inputs <Y>, start from the input id <X> " \
-            "and run <Y> different inputs."
-    )
-    parser.add_argument(
-        "--test-only",
-        action="store_true",
-        help="Run this under test mode, potentially skipping expensive steps like autotuning."
-    )
-    parser.add_argument(
-        "--dump-ir",
-        action="store_true",
-        help="Dump Triton IR",
-    )
-    return parser
-
 class PostInitProcessor(type):
     def __call__(cls, *args, **kwargs):
         obj = type.__call__(cls, *args, **kwargs)
         obj.__post__init__()
         return obj
 
-
-_RANGE_NAME = "tritonbench_range"
-
-
 class BenchmarkOperator(metaclass=PostInitProcessor):
     mode: Mode = Mode.FWD
     test: str = "eval"
@@ -494,21 +438,23 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
     A base class for adding operators to torch benchmark.
     """
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]]=None):
         set_random_seed()
         self.name = _find_op_name_from_module_path(self.__class__.__module__)
         self._raw_extra_args = copy.deepcopy(extra_args)
+        self.tb_args = tb_args
         # we accept both "fwd" and "eval"
-        if mode == "fwd":
+        if self.tb_args.mode == "fwd":
             self.mode = Mode.FWD
-        elif mode == "fwd_bwd":
+        elif self.tb_args.mode == "fwd_bwd":
             self.mode = Mode.FWD_BWD
         else:
             assert (
-                mode == "bwd"
+                self.tb_args.mode == "bwd"
             ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
             self.mode = Mode.BWD
-        self.dargs, unprocessed_args = parse_decoration_args(self, extra_args)
+        self.device = tb_args.device
+        self.dargs, self.extra_args = parse_decoration_args(self, extra_args)
         if self.name not in REGISTERED_X_VALS:
             REGISTERED_X_VALS[self.name] = "x_val"
         # This will be changed by the time we apply the decoration args
@@ -517,18 +463,12 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None)
             [x for x in REGISTERED_METRICS.get(self.name, []) if x not in BUILTIN_METRICS]
         )
         self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
-        tb_parser = get_tbargs_parser()
-        self.tb_args, self.extra_args = tb_parser.parse_known_args(
-            self.DEFAULT_METRICS,
-            unprocessed_args
-        )
         if self.tb_args.baseline:
             BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline
         self.required_metrics = list(set(self.tb_args.metrics.split(",")))
         self._only = _split_params_by_comma(self.tb_args.only)
         self._input_id = self.tb_args.input_id
         self._num_inputs = self.tb_args.num_inputs
-        self.device = device
 
     # Run the post initialization
     def __post__init__(self):
diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py
index 4906ba1cf0..ca2d421cfd 100644
--- a/userbenchmark/triton/run.py
+++ b/userbenchmark/triton/run.py
@@ -14,9 +14,8 @@
 
 TRITON_BENCH_CSV_DUMP_PATH = tempfile.gettempdir() + "/tritonbench/"
 
-
-def get_parser(args):
-    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
+def get_parser():
+    parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--op", type=str, default=None, help="Operator to benchmark.")
     parser.add_argument(
         "--mode",
@@ -35,16 +34,82 @@ def get_parser(args):
         help="Num of warmup runs for reach benchmark run.",
     )
     parser.add_argument(
-        "--iter", default=DEFAULT_RUN_ITERS, help="Num of reps for each benchmark run."
+        "--iter",
+        default=DEFAULT_RUN_ITERS,
+        help="Num of reps for each benchmark run.",
+    )
+    parser.add_argument(
+        "--csv",
+        action="store_true",
+        help="Print result as csv.",
+    )
+    parser.add_argument(
+        "--dump-csv",
+        action="store_true",
+        help="Dump result as csv.",
+    )
+    parser.add_argument(
+        "--skip-print",
+        action="store_true",
+        help="Skip printing result.",
+    )
+    parser.add_argument(
+        "--plot",
+        action="store_true",
+        help="Plot the result.",
+    )
+    parser.add_argument(
+        "--ci",
+        action="store_true",
+        help="Run in the CI mode.",
+    )
+    parser.add_argument(
+        "--metrics",
+        default=None,
+        help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.",
+    )
+    parser.add_argument(
+        "--only",
+        default=None,
+        help="Specify one or multiple operator implementations to run."
+    )
+    parser.add_argument(
+        "--baseline",
+        type=str,
+        default=None,
+        help="Override default baseline."
+    )
+    parser.add_argument(
+        "--num-inputs",
+        type=int,
+        help="Number of example inputs.",
+    )
+    parser.add_argument(
+        "--keep-going",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--input-id",
+        type=int,
+        default=0,
+        help="Specify the start input id to run. " \
+            "For example, --input-id 0 runs only the first available input sample." \
+            "When used together like --input-id <X> --num-inputs <Y>, start from the input id <X> " \
+            "and run <Y> different inputs."
+    )
+    parser.add_argument(
+        "--test-only",
+        action="store_true",
+        help="Run this under test mode, potentially skipping expensive steps like autotuning."
+    )
+    parser.add_argument(
+        "--dump-ir",
+        action="store_true",
+        help="Dump Triton IR",
     )
-    parser.add_argument("--csv", action="store_true", help="Print result as csv.")
-    parser.add_argument("--dump-csv", action="store_true", help="Dump result as csv.")
-    parser.add_argument("--skip-print", action="store_true", help="Skip printing result.")
-    parser.add_argument("--plot", action="store_true", help="Plot the result.")
     if not hasattr(torch_version, "git_version"):
         parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
-    parser.add_argument("--ci", action="store_true", help="Run in the CI mode.")
-    return parser.parse_known_args(args)
+    return parser
 
 def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
     Opbench = load_opbench_by_name(args.op)
@@ -55,6 +120,7 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
     opbench = Opbench(
         mode=args.mode,
         device=args.device,
+        tb_args=args,
         extra_args=extra_args,
     )
     try:
@@ -68,7 +134,6 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
                 print(metrics)
         if not hasattr(torch_version, "git_version") and args.log_scuba:
             from userbenchmark.triton.fb import log_benchmark
-
             log_benchmark(metrics)
         if args.plot:
             try:
@@ -84,12 +149,8 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
 def run(args: List[str] = []):
     if args == []:
         args = sys.argv[1:]
-    parser = get_parser(args)
-    tb_parser = get_tbargs_parser(default_metrics=[])
+    parser = get_parser()
     args, extra_args = parser.parse_known_args(args)
-    if "--help" in extra_args or "-h" in extra_args:
-        parser.print_help()
-        tb_parser.print_help()
     if args.ci:
         from .ci import run_ci
         run_ci()

From 0a51ebb3e52b1a6349dfe113ce1eaf6a559c9b43 Mon Sep 17 00:00:00 2001
From: Xu Zhao <xzhao9@meta.com>
Date: Wed, 3 Jul 2024 19:40:07 -0400
Subject: [PATCH 4/8] Remove stale get_tbargs_parser import from triton/run.py

---
 userbenchmark/triton/run.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py
index ca2d421cfd..84b80f544f 100644
--- a/userbenchmark/triton/run.py
+++ b/userbenchmark/triton/run.py
@@ -7,7 +7,6 @@
 from torchbenchmark.operators import load_opbench_by_name
 
 from torchbenchmark.util.triton_op import (
-    get_tbargs_parser,
     DEFAULT_RUN_ITERS,
     DEFAULT_WARMUP,
 )

From 080bd40b9fc9bdd73a5fce1d9fffa508c8eec52b Mon Sep 17 00:00:00 2001
From: Xu Zhao <xzhao9@meta.com>
Date: Wed, 3 Jul 2024 19:42:50 -0400
Subject: [PATCH 5/8] Make --op required and stop passing mode/device to Opbench

---
 userbenchmark/triton/run.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py
index 84b80f544f..45e07ac013 100644
--- a/userbenchmark/triton/run.py
+++ b/userbenchmark/triton/run.py
@@ -15,18 +15,33 @@
 
 def get_parser():
     parser = argparse.ArgumentParser(allow_abbrev=False)
-    parser.add_argument("--op", type=str, default=None, help="Operator to benchmark.")
+    parser.add_argument(
+        "--op",
+        type=str,
+        required=True,
+        help="Operator to benchmark."
+    )
     parser.add_argument(
         "--mode",
         choices=["fwd", "bwd", "fwd_bwd"],
         default="fwd",
         help="Test mode (fwd, bwd, or fwd_bwd).",
     )
-    parser.add_argument("--bwd", action="store_true", help="Run backward pass.")
     parser.add_argument(
-        "--fwd_bwd", action="store_true", help="Run both forward and backward pass."
+        "--bwd",
+        action="store_true",
+        help="Run backward pass."
+    )
+    parser.add_argument(
+        "--fwd_bwd",
+        action="store_true",
+        help="Run both forward and backward pass.",
+    )
+    parser.add_argument(
+        "--device",
+        default="cuda",
+        help="Device to benchmark.",
     )
-    parser.add_argument("--device", default="cuda", help="Device to benchmark.")
     parser.add_argument(
         "--warmup",
         default=DEFAULT_WARMUP,
@@ -117,8 +132,6 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> None:
     if args.bwd:
         args.mode = "bwd"
     opbench = Opbench(
-        mode=args.mode,
-        device=args.device,
         tb_args=args,
         extra_args=extra_args,
     )

From 485fe2286427e458213aafe77e185209c8127554 Mon Sep 17 00:00:00 2001
From: Xu Zhao <xzhao9@meta.com>
Date: Thu, 4 Jul 2024 10:08:27 -0400
Subject: [PATCH 6/8] Change the interface of all oss operators

---
 torchbenchmark/operators/addmm/operator.py              | 7 +++----
 torchbenchmark/operators/flash_attention/operator.py    | 5 ++---
 torchbenchmark/operators/fp8_gemm/fp8_gemm.py           | 6 +++---
 torchbenchmark/operators/fp8_gemm_blockwise/operator.py | 4 ++--
 torchbenchmark/operators/fp8_gemm_rowwise/operator.py   | 4 ++--
 torchbenchmark/operators/gather_gemv/operator.py        | 5 +++--
 torchbenchmark/operators/gemm/operator.py               | 5 +++--
 torchbenchmark/operators/int4_gemm/int4_gemm.py         | 6 +++---
 torchbenchmark/operators/jagged_mean/operator.py        | 4 ++--
 torchbenchmark/operators/jagged_sum/operator.py         | 4 ++--
 torchbenchmark/operators/sum/operator.py                | 4 ++--
 torchbenchmark/operators/template_attention/operator.py | 6 +++---
 torchbenchmark/operators/test_op/operator.py            | 5 +++--
 torchbenchmark/operators/welford/operator.py            | 6 +++---
 14 files changed, 36 insertions(+), 35 deletions(-)

diff --git a/torchbenchmark/operators/addmm/operator.py b/torchbenchmark/operators/addmm/operator.py
index 97264dd713..a936de6f25 100644
--- a/torchbenchmark/operators/addmm/operator.py
+++ b/torchbenchmark/operators/addmm/operator.py
@@ -1,6 +1,5 @@
-import csv
 import os
-import statistics
+import argparse
 from typing import Any, Callable, Generator, List, Optional, Tuple
 
 import numpy
@@ -71,8 +70,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops"]
     DEFAULT_PRECISION = "bf16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.k, addmm_args.n)]
diff --git a/torchbenchmark/operators/flash_attention/operator.py b/torchbenchmark/operators/flash_attention/operator.py
index 0d017ec786..33299564ce 100644
--- a/torchbenchmark/operators/flash_attention/operator.py
+++ b/torchbenchmark/operators/flash_attention/operator.py
@@ -107,9 +107,8 @@ def parse_op_args(args: List[str]):
 class Operator(BenchmarkOperator):
     DEFAULT_PRECISION = "bf16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
-        # pass the framework level args (e.g., device, is_training, dtype) to the parent class
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         args = parse_op_args(self.extra_args)
         self.BATCH = args.batch
         self.H = args.n_heads
diff --git a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
index 5fc29e657b..f25fb31490 100644
--- a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
+++ b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
@@ -7,7 +7,7 @@
 
 from triton.runtime.jit import reinterpret
 
-from typing import Any
+from typing import Any, Optional, List
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
@@ -27,8 +27,8 @@ def parse_args(args):
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
 
-    def __init__(self, mode, device, extra_args):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
     def get_input_iter(self):
diff --git a/torchbenchmark/operators/fp8_gemm_blockwise/operator.py b/torchbenchmark/operators/fp8_gemm_blockwise/operator.py
index 75a7da2861..a1a46219d1 100644
--- a/torchbenchmark/operators/fp8_gemm_blockwise/operator.py
+++ b/torchbenchmark/operators/fp8_gemm_blockwise/operator.py
@@ -108,8 +108,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
diff --git a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
index 9f7ed10258..d0bba2632e 100644
--- a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
+++ b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
@@ -85,8 +85,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         addmm_args = parse_args(self.extra_args)
         if addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
diff --git a/torchbenchmark/operators/gather_gemv/operator.py b/torchbenchmark/operators/gather_gemv/operator.py
index f665e86a6e..3a8d145041 100644
--- a/torchbenchmark/operators/gather_gemv/operator.py
+++ b/torchbenchmark/operators/gather_gemv/operator.py
@@ -4,6 +4,7 @@
 gather + gemv is the primary kernel driving mixtral perf.
 """
 
+import argparse
 import csv
 import os
 import statistics
@@ -38,8 +39,8 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             * 1e-6
         )
 
-    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
 
     @register_benchmark(baseline=True)
     def test_0(self, p1, p2, p3) -> Callable:
diff --git a/torchbenchmark/operators/gemm/operator.py b/torchbenchmark/operators/gemm/operator.py
index 7ce0f6818b..c37313aeca 100644
--- a/torchbenchmark/operators/gemm/operator.py
+++ b/torchbenchmark/operators/gemm/operator.py
@@ -1,3 +1,4 @@
+import argparse
 import csv
 import os
 import statistics
@@ -79,8 +80,8 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops"]
     DEFAULT_PRECISION = "fp16"
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         gemm_args = parse_args(self.extra_args)
         if gemm_args.input:
             self.shapes = read_shapes_from_csv(gemm_args.input)
diff --git a/torchbenchmark/operators/int4_gemm/int4_gemm.py b/torchbenchmark/operators/int4_gemm/int4_gemm.py
index 9b4b2c925d..6bf1621a64 100644
--- a/torchbenchmark/operators/int4_gemm/int4_gemm.py
+++ b/torchbenchmark/operators/int4_gemm/int4_gemm.py
@@ -12,7 +12,7 @@
 import triton.ops
 import triton.language as tl
 
-from typing import Any
+from typing import Any, Optional, List
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
@@ -27,8 +27,8 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
 
-    def __init__(self, mode, device, extra_args):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         # `Group size` and `inner K tiles` are defaults from gpt-fast.
         self.group_size = 32
         self.inner_k_tiles = 8
diff --git a/torchbenchmark/operators/jagged_mean/operator.py b/torchbenchmark/operators/jagged_mean/operator.py
index 75856d94ec..374d1fe525 100644
--- a/torchbenchmark/operators/jagged_mean/operator.py
+++ b/torchbenchmark/operators/jagged_mean/operator.py
@@ -101,8 +101,8 @@ class Operator(BenchmarkOperator):
         False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
     )
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.sizes = list(range(2, 12, 4)) + list(
             range(12, 23, 3)
         )  # bias towards larger sizes, which are more representative of real-world shapes
diff --git a/torchbenchmark/operators/jagged_sum/operator.py b/torchbenchmark/operators/jagged_sum/operator.py
index f4cc39719f..300993729d 100644
--- a/torchbenchmark/operators/jagged_sum/operator.py
+++ b/torchbenchmark/operators/jagged_sum/operator.py
@@ -126,8 +126,8 @@ class Operator(BenchmarkOperator):
         False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
     )
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.sizes = list(range(2, 12, 4)) + list(
             range(12, 23, 3)
         )  # bias towards larger sizes, which are more representative of real-world shapes
diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py
index 26c8bbdaf2..021ce76497 100644
--- a/torchbenchmark/operators/sum/operator.py
+++ b/torchbenchmark/operators/sum/operator.py
@@ -152,8 +152,8 @@ class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["latency", "accuracy"]
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         args = parse_op_args(self.extra_args)
         self.input_dim = args.input_dim
         self.reduce_dim = args.reduce_dim
diff --git a/torchbenchmark/operators/template_attention/operator.py b/torchbenchmark/operators/template_attention/operator.py
index bc08dd8915..6ce1e1b891 100644
--- a/torchbenchmark/operators/template_attention/operator.py
+++ b/torchbenchmark/operators/template_attention/operator.py
@@ -1,4 +1,4 @@
-
+import argparse
 import csv
 import os
 import statistics
@@ -29,8 +29,8 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
 
-    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.shapes = BUILDIN_SHAPES
 
     @register_benchmark(baseline=True)
diff --git a/torchbenchmark/operators/test_op/operator.py b/torchbenchmark/operators/test_op/operator.py
index bf149aaeb8..db41438a42 100644
--- a/torchbenchmark/operators/test_op/operator.py
+++ b/torchbenchmark/operators/test_op/operator.py
@@ -1,3 +1,4 @@
+import argparse
 from typing import Generator, List, Optional
 
 import torch
@@ -14,8 +15,8 @@ class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["test_metric"]
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
 
     @register_benchmark(label="new_op_label")
     def test_op(self, x: torch.Tensor):
diff --git a/torchbenchmark/operators/welford/operator.py b/torchbenchmark/operators/welford/operator.py
index d4877fe129..d5f2906ef5 100644
--- a/torchbenchmark/operators/welford/operator.py
+++ b/torchbenchmark/operators/welford/operator.py
@@ -1,4 +1,4 @@
-
+import argparse
 import csv
 import os
 import statistics
@@ -38,8 +38,8 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
 
-    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
-        super().__init__(mode=mode, device=device, extra_args=extra_args)
+    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
+        super().__init__(tb_args, extra_args)
         self.shapes = BUILDIN_SHAPES
 
     @register_benchmark()

From fb473ecdc394f370a5aa812abbb7e8f6c6ec3286 Mon Sep 17 00:00:00 2001
From: Xu Zhao <xzhao9@meta.com>
Date: Thu, 4 Jul 2024 10:10:32 -0400
Subject: [PATCH 7/8] Default metrics to DEFAULT_METRICS when --metrics is unset

---
 torchbenchmark/util/triton_op.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index df6e330ca5..9632965211 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -454,6 +454,7 @@ def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]]=
             ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
             self.mode = Mode.BWD
         self.device = tb_args.device
+        self.metrics = tb_args.metrics if tb_args.metrics else self.DEFAULT_METRICS
         self.dargs, self.extra_args = parse_decoration_args(self, extra_args)
         if self.name not in REGISTERED_X_VALS:
             REGISTERED_X_VALS[self.name] = "x_val"

From dd3f48726e8e8f82258e91098a2e0c1568647661 Mon Sep 17 00:00:00 2001
From: Xu Zhao <xzhao9@meta.com>
Date: Thu, 4 Jul 2024 10:13:33 -0400
Subject: [PATCH 8/8] Fix required metrics

---
 torchbenchmark/util/triton_op.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index 9632965211..5cb0bb45dd 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -425,15 +425,15 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
     mode: Mode = Mode.FWD
     test: str = "eval"
     device: str = "cuda"
+    # By default, only collect latency metrics
+    # Each operator can override to define their own default metrics
+    DEFAULT_METRICS = ["latency"]
+    required_metrics: List[str] = DEFAULT_METRICS
     _input_iter: Optional[Generator] = None
     extra_args: List[str] = []
     example_inputs: Any = None
     use_cuda_graphs: bool = True
 
-    # By default, only collect latency metrics
-    # Each operator can override to define their own default metrics
-    DEFAULT_METRICS = ["latency"]
-
     """
     A base class for adding operators to torch benchmark.
     """
@@ -454,7 +454,7 @@ def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]]=
             ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
             self.mode = Mode.BWD
         self.device = tb_args.device
-        self.metrics = tb_args.metrics if tb_args.metrics else self.DEFAULT_METRICS
+        self.required_metrics = list(set(tb_args.metrics.split(","))) if tb_args.metrics else self.required_metrics
         self.dargs, self.extra_args = parse_decoration_args(self, extra_args)
         if self.name not in REGISTERED_X_VALS:
             REGISTERED_X_VALS[self.name] = "x_val"
@@ -466,7 +466,6 @@ def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]]=
         self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
         if self.tb_args.baseline:
             BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline
-        self.required_metrics = list(set(self.tb_args.metrics.split(",")))
         self._only = _split_params_by_comma(self.tb_args.only)
         self._input_id = self.tb_args.input_id
         self._num_inputs = self.tb_args.num_inputs