[TESTING] clean up testing.do_bench #1513

Merged · 8 commits · Apr 12, 2023
18 changes: 9 additions & 9 deletions python/test/regression/test_performance.py
@@ -67,7 +67,7 @@ def nvsmi(attrs):
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.16, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.30, 'float32': 0.257, 'int8': 0.174},
- (1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
+ (1024, 64, 1024): {'float16': 0.037, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.25, 'float32': 0.230, 'int8': 0.177},
}
@@ -94,10 +94,10 @@ def test_matmul(M, N, K, dtype_str):
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
+ ms = triton.testing.do_bench(fn, warmup=100, rep=300)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
- torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
+ triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)


#######################
@@ -131,8 +131,8 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
'a100': {
1024 * 16: 0.008,
1024 * 64: 0.034,
- 1024 * 256: 0.114,
- 1024 * 1024: 0.315,
+ 1024 * 256: 0.132,
+ 1024 * 1024: 0.352,
1024 * 4096: 0.580,
1024 * 16384: 0.782,
1024 * 65536: 0.850,
@@ -150,10 +150,10 @@ def test_elementwise(N):
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
+ ms = triton.testing.do_bench(fn, warmup=100, rep=500)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
- torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
+ triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)

#######################
# Flash-Attention
@@ -189,7 +189,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
+ ms = triton.testing.do_bench(fn, warmup=100, rep=500)
# compute flops
flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
@@ -201,4 +201,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
- torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
+ triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
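
For reference, the utilization checks above follow from straightforward FLOP-count arithmetic. A minimal illustration (the shape, timing, and peak-TFLOPS numbers below are made up for the example and are not taken from this PR):

    # Illustrative only: turning a do_bench timing into TFLOPS and utilization.
    M = N = K = 4096            # hypothetical GEMM shape
    ms = 1.0                    # hypothetical time reported by do_bench, in milliseconds
    # 2*M*N*K floating-point ops over (ms * 1e-3) seconds, scaled to TFLOPS:
    # 2*M*N*K / (ms * 1e-3) * 1e-12  ==  2. * M * N * K / ms * 1e-9
    cur_gpu_perf = 2. * M * N * K / ms * 1e-9    # ~137 TFLOPS for these numbers
    max_gpu_perf = 312.0                         # assumed fp16 tensor-core peak, e.g. an A100
    cur_gpu_util = cur_gpu_perf / max_gpu_perf   # fraction compared against the reference table
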
2 changes: 1 addition & 1 deletion python/triton/runtime/autotuner.py
@@ -77,7 +77,7 @@ def kernel_call():
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
try:
- return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8))
+ return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return (float('inf'), float('inf'), float('inf'))

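With quantiles=(0.5, 0.2, 0.8), do_bench hands the autotuner back one timing per quantile, which is why the OutOfResources fallback stays a matching 3-tuple of float('inf'). A minimal sketch of the return shape (the kernel_call below is a stand-in for the autotuner's real closure, assuming a CUDA device):

    import torch
    from triton.testing import do_bench

    x = torch.randn(1024, device='cuda')

    def kernel_call():
        # Stand-in workload; the autotuner's closure launches one candidate config instead.
        torch.relu(x)

    timings = do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
    # timings holds [median_ms, p20_ms, p80_ms]; configs can still be ranked by timings[0].
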
13 changes: 6 additions & 7 deletions python/triton/testing.py
@@ -18,7 +18,7 @@ def nvsmi(attrs):


def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
- percentiles=(0.5, 0.2, 0.8),
+ quantiles=None,
fast_flush=False,
return_mode="min"):
assert return_mode in ["min", "max", "mean", "median"]
@@ -35,8 +35,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
- :param percentiles: Performance percentile to return in addition to the median.
- :type percentiles: list[float]
+ :param quantiles: Performance percentile to return in addition to the median.
+ :type quantiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
@@ -84,10 +84,9 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
# Record clocks
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
- if percentiles is not None:
-     percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
-     return tuple(percentiles)
- return getattr(torch, return_mode)(times.item())
+ if quantiles is not None:
+     return torch.quantile(times, torch.tensor(quantiles)).tolist()
+ return getattr(torch, return_mode)(times).item()


def assert_close(x, y, atol=None, rtol=None, err_msg=''):
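
Taken together, the reworked do_bench supports two calling styles: the default returns a single scalar reduced by return_mode ("min" by default), while passing quantiles returns one timing per requested quantile. A small sketch, assuming a CUDA device and using a plain torch matmul as a stand-in workload:

    import torch
    import triton

    a = torch.randn((1024, 1024), device='cuda', dtype=torch.float16)
    b = torch.randn((1024, 1024), device='cuda', dtype=torch.float16)
    fn = lambda: torch.matmul(a, b)

    # No quantiles: a single float, reduced across timed runs by return_mode.
    ms = triton.testing.do_bench(fn, warmup=100, rep=500)

    # With quantiles: median, 20th and 80th percentile, the pattern the tutorials below unpack.
    med_ms, low_ms, high_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
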
5 changes: 3 additions & 2 deletions python/tutorials/01-vector-add.py
@@ -120,10 +120,11 @@ def add(x: torch.Tensor, y: torch.Tensor):
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
+ quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
if provider == 'triton':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)

7 changes: 4 additions & 3 deletions python/tutorials/02-fused-softmax.py
@@ -177,12 +177,13 @@ def softmax(x):
)
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
+ quantiles = [0.5, 0.2, 0.8]
if provider == 'torch-native':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
if provider == 'triton':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
if provider == 'torch-jit':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)

5 changes: 3 additions & 2 deletions python/tutorials/03-matrix-multiplication.py
@@ -329,10 +329,11 @@ def matmul(a, b, activation=None):
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
+ quantiles = [0.5, 0.2, 0.8]
if provider == 'cublas':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
if provider == 'triton':
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
+ ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)

5 changes: 3 additions & 2 deletions python/tutorials/05-layer-norm.py
@@ -339,6 +339,7 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
+ quantiles = [0.5, 0.2, 0.8]
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
@@ -350,13 +351,13 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
- ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
-     grad_to_none=[x], rep=500)
+     quantiles=quantiles, grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)


4 changes: 2 additions & 2 deletions python/tutorials/06-fused-attention.py
@@ -341,7 +341,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
@@ -353,7 +353,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms

