Commit

[TESTING] clean up testing.do_bench (#1513)
ptillet authored Apr 12, 2023
1 parent 081f640 commit 02e3c18
Showing 8 changed files with 31 additions and 28 deletions.
18 changes: 9 additions & 9 deletions python/test/regression/test_performance.py
@@ -67,7 +67,7 @@ def nvsmi(attrs):
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.16, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.30, 'float32': 0.257, 'int8': 0.174},
- (1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
+ (1024, 64, 1024): {'float16': 0.037, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.25, 'float32': 0.230, 'int8': 0.177},
}
@@ -94,10 +94,10 @@ def test_matmul(M, N, K, dtype_str):
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
+ ms = triton.testing.do_bench(fn, warmup=100, rep=300)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
- torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
+ triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)


#######################
@@ -131,8 +131,8 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
'a100': {
1024 * 16: 0.008,
1024 * 64: 0.034,
- 1024 * 256: 0.114,
- 1024 * 1024: 0.315,
+ 1024 * 256: 0.132,
+ 1024 * 1024: 0.352,
1024 * 4096: 0.580,
1024 * 16384: 0.782,
1024 * 65536: 0.850,
@@ -150,10 +150,10 @@ def test_elementwise(N):
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
+ ms = triton.testing.do_bench(fn, warmup=100, rep=500)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
- torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
+ triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)

#######################
# Flash-Attention
@@ -189,7 +189,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
+ ms = triton.testing.do_bench(fn, warmup=100, rep=500)
# compute flops
flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
@@ -201,4 +201,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
- torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
+ triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
2 changes: 1 addition & 1 deletion python/triton/runtime/autotuner.py
@@ -77,7 +77,7 @@ def kernel_call():
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
try:
- return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8))
+ return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return (float('inf'), float('inf'), float('inf'))

13 changes: 6 additions & 7 deletions python/triton/testing.py
@@ -18,7 +18,7 @@ def nvsmi(attrs):


def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
- percentiles=(0.5, 0.2, 0.8),
+ quantiles=None,
fast_flush=False,
return_mode="min"):
assert return_mode in ["min", "max", "mean", "median"]
@@ -35,8 +35,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
- :param percentiles: Performance percentile to return in addition to the median.
- :type percentiles: list[float]
+ :param quantiles: Performance percentile to return in addition to the median.
+ :type quantiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
@@ -84,10 +84,9 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
# Record clocks
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
- if percentiles is not None:
-     percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
-     return tuple(percentiles)
- return getattr(torch, return_mode)(times.item())
+ if quantiles is not None:
+     return torch.quantile(times, torch.tensor(quantiles)).tolist()
+ return getattr(torch, return_mode)(times).item()


def assert_close(x, y, atol=None, rtol=None, err_msg=''):
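For reference, a minimal sketch of how callers use the updated do_bench signature; the matmul workload below is illustrative, not part of this commit. With the new default quantiles=None, do_bench returns a single float in milliseconds selected by return_mode (which the signature above defaults to "min"); passing quantiles returns one timing per requested quantile.

```python
import torch
import triton

# Illustrative workload; any CUDA kernel launch or torch op works here.
x = torch.randn(4096, 4096, device='cuda')
y = torch.randn(4096, 4096, device='cuda')
fn = lambda: torch.matmul(x, y)

# Default: a single scalar in milliseconds, chosen by return_mode ("min" by default).
ms = triton.testing.do_bench(fn, warmup=25, rep=100)

# With quantiles: one timing per requested quantile, e.g. median, 20th and 80th percentile.
med_ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
```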
5 changes: 3 additions & 2 deletions python/tutorials/01-vector-add.py
@@ -120,10 +120,11 @@ def add(x: torch.Tensor, y: torch.Tensor):
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
+ quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
if provider == 'triton':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)

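As a side note on the unchanged gbps() conversion above: vector add moves three float32 tensors (read x, read y, write the output), i.e. 12 bytes per element. A small worked check with hypothetical numbers:

```python
# Hypothetical measurement, for illustration only.
size = 1024 * 1024        # number of elements
ms = 0.05                 # measured time in milliseconds
bytes_moved = 12 * size   # 3 float32 tensors x 4 bytes per element
gbps = bytes_moved / (ms * 1e-3) / 1e9   # equivalent to 12 * size / ms * 1e-6
print(f"{gbps:.1f} GB/s")
```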
7 changes: 4 additions & 3 deletions python/tutorials/02-fused-softmax.py
@@ -177,12 +177,13 @@ def softmax(x):
)
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
+ quantiles = [0.5, 0.2, 0.8]
if provider == 'torch-native':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
if provider == 'triton':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
if provider == 'torch-jit':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)

5 changes: 3 additions & 2 deletions python/tutorials/03-matrix-multiplication.py
@@ -329,10 +329,11 @@ def matmul(a, b, activation=None):
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
+ quantiles = [0.5, 0.2, 0.8]
if provider == 'cublas':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
if provider == 'triton':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)

5 changes: 3 additions & 2 deletions python/tutorials/05-layer-norm.py
@@ -339,6 +339,7 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
+ quantiles = [0.5, 0.2, 0.8]
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
@@ -350,13 +351,13 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
- ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
- grad_to_none=[x], rep=500)
+ quantiles=quantiles, grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)


4 changes: 2 additions & 2 deletions python/tutorials/06-fused-attention.py
@@ -341,7 +341,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
@@ -353,7 +353,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms


