diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 747f23641c358..ef0cde55739d4 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1346,6 +1346,92 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
     )
 
 
+def test_sum_kernel_ttgir():
+    ir = """
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+    %cst = arith.constant dense<128> : tensor<128x1xi32, #blocked>
+    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi32, #blocked>
+    %2 = arith.muli %1, %cst : tensor<128x1xi32, #blocked>
+    %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    %4 = tt.expand_dims %3 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked>
+    %5 = tt.broadcast %2 : (tensor<128x1xi32, #blocked>) -> tensor<128x128xi32, #blocked>
+    %6 = tt.broadcast %4 : (tensor<1x128xi32, #blocked>) -> tensor<128x128xi32, #blocked>
+    %7 = arith.addi %5, %6 : tensor<128x128xi32, #blocked>
+    %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<128x128x!tt.ptr<i32>, #blocked>
+    %9 = tt.addptr %8, %7 : tensor<128x128x!tt.ptr<i32>, #blocked>, tensor<128x128xi32, #blocked>
+    %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xi32, #blocked>
+    %11 = "tt.reduce"(%10) ({
+    ^bb0(%arg2: i32, %arg3: i32):
+      %13 = arith.addi %arg2, %arg3 : i32
+      tt.reduce.return %13 : i32
+    }) {axis = 1 : i32} : (tensor<128x128xi32, #blocked>) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %12 = tt.expand_dims %11 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi32, #blocked>
+    %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %14 = tt.expand_dims %13 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi32, #blocked>
+    %18 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<128x1x!tt.ptr<i32>, #blocked>
+    %20 = tt.addptr %18, %14 : tensor<128x1x!tt.ptr<i32>, #blocked>, tensor<128x1xi32, #blocked>
+    tt.store %20, %12 {cache = 1 : i32, evict = 1 : i32} : tensor<128x1xi32, #blocked>
+    tt.return
+  }
+}
+    """
+
+    import tempfile
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
+        f.write(ir)
+        f.flush()
+        kernel = triton.compile(f.name)
+
+    BLOCK_SIZE = 128
+    x = np.ones((BLOCK_SIZE, BLOCK_SIZE), dtype=np.int32)
+    y = np.zeros((BLOCK_SIZE, 1), dtype=np.int32)
+    x_tri = torch.tensor(x, device='cuda')
+    y_tri = torch.tensor(y, device='cuda')
+
+    kernel[(1, 1, 1)](x_tri, y_tri)
+    y_ref = np.sum(x, axis=1, keepdims=True)
+
+    np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
+
+
+def test_sum_kernel():
+
+    @triton.jit
+    def sum_kernel(
+        x_ptr,
+        y_ptr,
+        STRIDE: tl.constexpr,
+        BLOCK_SIZE: tl.constexpr
+    ):
+        arange = tl.arange(0, BLOCK_SIZE)
+        off = arange[:, None] * STRIDE + arange[None, :]
+        x_val = tl.load(x_ptr + off)
+
+        # This 2D reduction doesn't work
+        x_sum = tl.sum(x_val, axis=1)
+        x_sum = tl.sum(x_sum, axis=0)
+
+        # This 1D reduction works
+        # x_sum = tl.sum(tl.view(x_val, (BLOCK_SIZE * BLOCK_SIZE,)), axis=0)
+
+        tl.store(y_ptr, x_sum)
+
+    BLOCK_SIZE = 128
+    x = np.ones((BLOCK_SIZE, BLOCK_SIZE), dtype=np.int64)
+    y = np.zeros((1,), dtype=np.int64)
+    x_tri = torch.tensor(x, device='cuda')
+    y_tri = torch.tensor(y, device='cuda')
+
+    sum_kernel[(1,)](x_tri, y_tri, x_tri.stride(0), BLOCK_SIZE)
+
+    y_ref = np.sum(x)
+
+    np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
+
+
 def test_generic_reduction(device='cuda'):
 
     @triton.jit
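For reference, the commented-out line in test_sum_kernel above points at the workaround for the broken chained 2D reduction: flatten the block with tl.view so a single 1D tl.sum suffices. Below is a minimal standalone sketch of that path, not part of the diff; it assumes a CUDA device and a Triton build contemporary with this PR, and the kernel name sum_kernel_1d is invented for illustration.

import torch
import triton
import triton.language as tl


@triton.jit
def sum_kernel_1d(x_ptr, y_ptr, STRIDE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Build the same 2D offsets as test_sum_kernel.
    arange = tl.arange(0, BLOCK_SIZE)
    off = arange[:, None] * STRIDE + arange[None, :]
    x_val = tl.load(x_ptr + off)
    # Workaround: collapse the 2D block to 1D and reduce once,
    # instead of chaining tl.sum over axis=1 then axis=0.
    x_sum = tl.sum(tl.view(x_val, (BLOCK_SIZE * BLOCK_SIZE,)), axis=0)
    tl.store(y_ptr, x_sum)


BLOCK_SIZE = 128
x_tri = torch.ones((BLOCK_SIZE, BLOCK_SIZE), dtype=torch.int64, device='cuda')
y_tri = torch.zeros((1,), dtype=torch.int64, device='cuda')
sum_kernel_1d[(1,)](x_tri, y_tri, x_tri.stride(0), BLOCK_SIZE)
assert y_tri.item() == BLOCK_SIZE * BLOCK_SIZE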