
Commit 9301125

jananisriram authored and facebook-github-bot committed
Refactor fp8_gemm benchmark to simplify addition of new scaling modes (#500)
Summary: Refactor the `fp8_gemm` benchmark in TritonBench to accept scaling modes as an argument. This diff enables us to extend the `fp8_gemm` benchmark to new scaling modes without adding new benchmarking arguments.

Reviewed By: NikhilAPatel

Differential Revision: D83617233
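With this change, the scaling recipe for each operand is selected via the new `--scaling-pair` flag (default `TensorWise,TensorWise`). A hedged example invocation, assuming TritonBench's usual `run.py` entry point; the shape flags `--m/--n/--k` come from the operator's own argument parser:

    python run.py --op fp8_gemm --scaling-pair RowWise,RowWise --m 4096 --n 4096 --k 4096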
1 parent 7c74b21 commit 9301125

2 files changed: +98 −48 lines


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 90 additions & 41 deletions
@@ -1,4 +1,5 @@
 import argparse
+
 import logging
 
 from typing import Any, Callable, List, Optional
@@ -7,6 +8,8 @@
 import torch._inductor.config as inductor_config
 import triton
 
+from torch._inductor.kernel.mm import scaling_pairs, ScalingType
+
 from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda
 
@@ -46,7 +49,7 @@
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -65,6 +68,58 @@ def get_fp8_dtype():
     return torch.float8_e4m3fnuz
 
 
+def get_scaling_recipe(scaling_recipe: str) -> int:
+    if scaling_recipe == "TensorWise":
+        return ScalingType.TensorWise
+    elif scaling_recipe == "RowWise":
+        return ScalingType.RowWise
+    else:
+        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
+
+
+def get_scale(
+    x: torch.Tensor,
+    scaling_recipe: ScalingType,
+    transpose: bool = False,
+    custom_scale: float = None,
+) -> (torch.Tensor, torch.Tensor):
+    def _get_scale_per_tensor(
+        x: torch.Tensor, custom_scale: float = None
+    ) -> (torch.Tensor, torch.Tensor):
+        # For tensor-wise scaling, kernel requires a float32 scale tensor
+        if custom_scale:
+            return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        x *= scale
+        return x, scale.to(torch.float32)
+
+    def _get_scale_per_row(
+        x: torch.Tensor, transpose: bool = False
+    ) -> (torch.Tensor, torch.Tensor):
+        if transpose:  # scale_b.shape should be [1, N]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=0, keepdim=True).values
+            )
+        else:  # scale_a.shape should be [M, 1]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=1, keepdim=True).values
+            )
+        x = x.mul(scale)
+        return x, scale.to(
+            torch.float32
+        )  # For row-wise scaling, kernel requires a float32 scale tensor
+
+    match scaling_recipe:
+        case ScalingType.TensorWise:
+            return _get_scale_per_tensor(x, custom_scale=custom_scale)
+        case ScalingType.RowWise:
+            return _get_scale_per_row(x, transpose=transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_recipe}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -78,53 +133,39 @@ def __init__(
 
         self.fp8_dtype = get_fp8_dtype()
 
+        scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
+        if (scaling_recipe_a, scaling_recipe_b) not in [
+            (a.name, b.name) for a, b in scaling_pairs
+        ]:
+            raise ValueError(
+                f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
+            )
+        self.scaling_recipe_a = get_scaling_recipe(scaling_recipe_a)
+        self.scaling_recipe_b = get_scaling_recipe(scaling_recipe_b)
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if (
+            self.scaling_recipe_a == ScalingType.TensorWise
+            and self.scaling_recipe_b == ScalingType.TensorWise
+        ):
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
-        def _get_scale_per_tensor(
-            x: torch.Tensor, custom_scale: float = None
-        ) -> torch.Tensor:
-            # For tensor-wise scaling, kernel requires a float32 scale tensor
-            if custom_scale:
-                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-            scale = torch.finfo(self.fp8_dtype).max / x.abs().max()
-            return scale.to(torch.float32)
-
-        def _get_scale_per_row(
-            x: torch.Tensor, transpose: bool = False
-        ) -> torch.Tensor:
-            if transpose:  # scale_b.shape should be [1, N]
-                scale = (
-                    torch.finfo(self.fp8_dtype).max
-                    / x.abs().max(dim=0, keepdim=True).values
-                )
-            else:  # scale_a.shape should be [M, 1]
-                scale = (
-                    torch.finfo(self.fp8_dtype).max
-                    / x.abs().max(dim=1, keepdim=True).values
-                )
-            return scale.to(
-                torch.float32
-            )  # For row-wise scaling, kernel requires a float32 scale tensor
-
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
-                scale_a = _get_scale_per_row(a)
-                scale_b = _get_scale_per_row(b)
-            else:
-                scale_a = _get_scale_per_tensor(
-                    a, custom_scale=self.extra_args.per_tensor_scale_a
-                )
-                scale_b = _get_scale_per_tensor(
-                    b, custom_scale=self.extra_args.per_tensor_scale_b
-                )
+            a, scale_a = get_scale(
+                a,
+                self.scaling_recipe_a,
+                custom_scale=self.extra_args.per_tensor_scale_a,
+            )
+            b, scale_b = get_scale(
+                b,
+                self.scaling_recipe_b,
+                custom_scale=self.extra_args.per_tensor_scale_b,
+            )
 
             # Kernels expect dtype=float8_e4m3fn(uz)
             a = a.to(self.fp8_dtype)
@@ -198,13 +239,21 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
 
     @register_benchmark(enabled=True)
     def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
+        if self.scaling_recipe_a == self.scaling_recipe_b == ScalingType.TensorWise:
+            scaling_recipe_int = 0
+        elif self.scaling_recipe_a == self.scaling_recipe_b == ScalingType.RowWise:
+            scaling_recipe_int = 1
+        else:
+            raise ValueError(
+                f"Invalid scaling pair: {self.scaling_recipe_a}, {self.scaling_recipe_b} for blackwell_persistent_tma_fp8_gemm."
+            )
         return lambda: blackwell_persistent_tma(
             a,
             b,
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            scaling_recipe_int,
         )
 
     @register_benchmark(enabled=True)
tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 8 additions & 7 deletions
@@ -1,10 +1,13 @@
 from functools import lru_cache
+
 from typing import Optional
 
 import torch
 import triton
 import triton.language as tl
 
+from torch._inductor.kernel.mm import ScalingType
+
 from tritonbench.utils.env_utils import is_cuda
 from tritonbench.utils.triton_utils import has_experimental_descriptor
 
@@ -410,9 +413,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
 
 
-def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
-):
+def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
     configs = matmul_configs_blackwell()
 
     # Check constraints.
@@ -471,7 +472,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS,  #
         num_stages=configs[shape_dtype]["num_stages"],  #
         num_warps=configs[shape_dtype]["num_warps"],  #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode,  #
         WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"],  #
         EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"],  #
     )
@@ -504,7 +505,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr,  #
+    SCALING_MODE: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ):  #
@@ -538,7 +539,7 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == ScalingType.RowWise:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
@@ -563,7 +564,7 @@ def blackwell_persistent_tma_kernel(
         b_block = b_desc.load([offs_bn, offs_k])
         accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-        if SCALING_ROWWISE:
+        if SCALING_MODE == ScalingType.RowWise:
             offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
             offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
 