
Commit

Unit test
zahimoud committed Apr 29, 2023
1 parent bd36694 commit 7ded3c8
Showing 1 changed file with 86 additions and 0 deletions: python/test/unit/language/test_core.py
@@ -1346,6 +1346,92 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
)


def test_sum_kernel_ttgir():
ir = """
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #blocked>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi32, #blocked>
%2 = arith.muli %1, %cst : tensor<128x1xi32, #blocked>
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%4 = tt.expand_dims %3 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked>
%5 = tt.broadcast %2 : (tensor<128x1xi32, #blocked>) -> tensor<128x128xi32, #blocked>
%6 = tt.broadcast %4 : (tensor<1x128xi32, #blocked>) -> tensor<128x128xi32, #blocked>
%7 = arith.addi %5, %6 : tensor<128x128xi32, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<128x128x!tt.ptr<i32>, #blocked>
%9 = tt.addptr %8, %7 : tensor<128x128x!tt.ptr<i32>, #blocked>, tensor<128x128xi32, #blocked>
%10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xi32, #blocked>
%11 = "tt.reduce"(%10) ({
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}) {axis = 1 : i32} : (tensor<128x128xi32, #blocked>) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%12 = tt.expand_dims %11 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi32, #blocked>
%13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%14 = tt.expand_dims %13 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi32, #blocked>
%18 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<128x1x!tt.ptr<i32>, #blocked>
%20 = tt.addptr %18, %14 : tensor<128x1x!tt.ptr<i32>, #blocked>, tensor<128x1xi32, #blocked>
tt.store %20, %12 {cache = 1 : i32, evict = 1 : i32} : tensor<128x1xi32, #blocked>
tt.return
}
}
"""

    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir)
        f.flush()
        kernel = triton.compile(f.name)

    BLOCK_SIZE = 128
    x = np.ones((BLOCK_SIZE, BLOCK_SIZE), dtype=np.int32)
    y = np.zeros((BLOCK_SIZE, 1), dtype=np.int32)
    x_tri = torch.tensor(x, device='cuda')
    y_tri = torch.tensor(y, device='cuda')

    kernel[(1, 1, 1)](x_tri, y_tri)
    y_ref = np.sum(x, axis=1, keepdims=True)

    np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
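
For reference, a minimal Python-level sketch (not part of this commit) of the same row-sum computation that the TTGIR encodes, assuming the triton, tl, torch, and np imports already present at the top of test_core.py; the kernel name row_sum_kernel is illustrative:

@triton.jit
def row_sum_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
    offs = tl.arange(0, BLOCK_SIZE)
    # 2D offsets into a row-major BLOCK_SIZE x BLOCK_SIZE tensor.
    x = tl.load(x_ptr + offs[:, None] * BLOCK_SIZE + offs[None, :])
    # Reduce along axis 1, mirroring the tt.reduce {axis = 1} op above.
    tl.store(y_ptr + offs, tl.sum(x, axis=1))

x = torch.ones((128, 128), dtype=torch.int32, device='cuda')
y = torch.empty((128,), dtype=torch.int32, device='cuda')
row_sum_kernel[(1,)](x, y, BLOCK_SIZE=128)
assert torch.equal(y, x.sum(dim=1).to(torch.int32))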


def test_sum_kernel():

    @triton.jit
    def sum_kernel(
        x_ptr,
        y_ptr,
        STRIDE: tl.constexpr,
        BLOCK_SIZE: tl.constexpr
    ):
        arange = tl.arange(0, BLOCK_SIZE)
        off = arange[:, None] * STRIDE + arange[None, :]
        x_val = tl.load(x_ptr + off)

        # This 2D reduction doesn't work
        x_sum = tl.sum(x_val, axis=1)
        x_sum = tl.sum(x_sum, axis=0)

        # This 1D reduction works
        # x_sum = tl.sum(tl.view(x_val, (BLOCK_SIZE * BLOCK_SIZE,)), axis=0)

        tl.store(y_ptr, x_sum)

    BLOCK_SIZE = 128
    x = np.ones((BLOCK_SIZE, BLOCK_SIZE), dtype=np.int64)
    y = np.zeros((1,), dtype=np.int64)
    x_tri = torch.tensor(x, device='cuda')
    y_tri = torch.tensor(y, device='cuda')

    sum_kernel[(1,)](x_tri, y_tri, x_tri.stride(0), BLOCK_SIZE)

    y_ref = np.sum(x)

    np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
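
The commented-out 1D path above can also be exercised on its own. A sketch (not part of the commit) of that variant, assuming tl.view behaves as the comment implies in this Triton version:

@triton.jit
def sum_kernel_1d(x_ptr, y_ptr, STRIDE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    arange = tl.arange(0, BLOCK_SIZE)
    off = arange[:, None] * STRIDE + arange[None, :]
    x_val = tl.load(x_ptr + off)
    # Flatten the 2D block to 1D, then reduce once along axis 0 instead of
    # chaining two reductions.
    x_sum = tl.sum(tl.view(x_val, (BLOCK_SIZE * BLOCK_SIZE,)), axis=0)
    tl.store(y_ptr, x_sum)

Launched the same way as sum_kernel above, it should produce the same scalar sum.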


def test_generic_reduction(device='cuda'):

    @triton.jit
