From 3449a9d40d72beb489b9b13b1324a1e61fe13ba7 Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com>
Date: Mon, 1 May 2023 18:00:23 -0700
Subject: [PATCH] Zahi/slice reduce rebased (#1594)

[BACKEND] Enable slice layout support for reduce op
---
 include/triton/Dialect/TritonGPU/IR/Dialect.h | 27 +++++++
 lib/Analysis/Utility.cpp                      | 15 +++-
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp        | 19 +++--
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 78 +++++++++++++++++--
 python/test/unit/language/test_core.py        | 64 +++++++++++++++
 5 files changed, 187 insertions(+), 16 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 4658fe59ce0d..c12795ffcdd3 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -31,10 +31,37 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
 
 SmallVector<unsigned> getSizePerThread(Attribute layout);
 
+// Returns the number of contiguous elements that each thread
+// has access to, on each dimension of the tensor. E.g.
+// for a blocked layout with sizePerThread = [1, 4], returns [1, 4],
+// regardless of the shape of the tensor.
 SmallVector<unsigned> getContigPerThread(Attribute layout);
 
+// Returns the number of non-replicated contiguous elements that each thread
+// has access to, on each dimension of the tensor. For a blocked layout with
+// sizePerThread = [1, 4] and tensor shape = [128, 1], the elements for thread
+// 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], so this returns [1, 1];
+// for a tensor shape of [128, 128], the elements for thread 0 would be
+// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], so this returns [1, 4].
 SmallVector<unsigned> getUniqueContigPerThread(Type type);
 
+// Returns the number of threads per warp that have access to non-replicated
+// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
+// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17
+// have access to the full tensor, whereas the other threads have access to
+// replicated elements, so this function returns [2, 2].
+SmallVector<unsigned>
+getThreadsPerWarpWithUniqueData(Attribute layout,
+                                ArrayRef<int64_t> tensorShape);
+
+// Returns the number of warps per CTA that have access to non-replicated
+// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
+// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
+// returns [1, 1], since the first warp has access to the full tensor, whereas
+// the other warps have access to replicated elements.
+SmallVector<unsigned>
+getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
+
 SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
 
 SmallVector<unsigned>
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index dfd6302e74fc..4167ccb3057b 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -17,19 +17,23 @@ unsigned ReduceOpHelper::getInterWarpSize() {
   auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
   unsigned sizeIntraWarps = getIntraWarpSize();
   return std::min(srcReduceDimSize / sizeIntraWarps,
-                  triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
+                  triton::gpu::getWarpsPerCTAWithUniqueData(
+                      getSrcLayout(), getSrcShape())[axis]);
 }
 
 unsigned ReduceOpHelper::getIntraWarpSize() {
   auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
   return std::min(srcReduceDimSize,
-                  triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
+                  triton::gpu::getThreadsPerWarpWithUniqueData(
+                      getSrcLayout(), getSrcShape())[axis]);
 }
 
 unsigned ReduceOpHelper::getThreadsReductionAxis() {
   auto srcLayout = getSrcLayout();
-  return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
-         triton::gpu::getWarpsPerCTA(srcLayout)[axis];
+  auto srcShape = getSrcShape();
+  return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout,
+                                                      srcShape)[axis] *
+         triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
 }
 
 SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
@@ -88,6 +92,9 @@ bool ReduceOpHelper::isSupportedLayout() {
       return true;
     }
   }
+  if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
+    return true;
+  }
   return false;
 }
 
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 6b4058149778..85b8de9fc36e 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -87,6 +87,15 @@ struct ReduceOpConversion
                           Attribute layout, SmallVector<Value> &index,
                           SmallVector<Value> &writeIdx,
                           std::map<int, Value> &ints, unsigned axis) const {
+    if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+      auto dim = sliceLayout.getDim();
+      assert(dim != axis && "Reduction axis cannot be sliced");
+      auto parentLayout = sliceLayout.getParent();
+      getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints,
+                         axis);
+      return;
+    }
+
     writeIdx = index;
     auto sizePerThread = triton::gpu::getSizePerThread(layout);
     Value axisSizePerThread = ints[sizePerThread[axis]];
@@ -100,9 +109,10 @@ struct ReduceOpConversion
       // to map every `axisSizePerThread` to 1 value in smem as:
       // writeIdx[axis] = index[axis] / axisSizePerThread
       writeIdx[axis] = udiv(index[axis], axisSizePerThread);
-    }
-    auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
-    if (mmaLayout && mmaLayout.isAmpere()) {
+    } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      if (!mmaLayout.isAmpere()) {
+        llvm::report_fatal_error("Unsupported layout");
+      }
       if (axis == 0) {
         // Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8
         // rows in smem would correspond to a warp. The mapping
@@ -113,8 +123,7 @@ struct ReduceOpConversion
         // Same as BlockedEncodingAttr case
         writeIdx[axis] = udiv(index[axis], axisSizePerThread);
       }
-    }
-    if (mmaLayout && !mmaLayout.isAmpere()) {
+    } else {
       llvm::report_fatal_error("Unsupported layout");
     }
   }
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index d81b934cb31c..11f317b42fbc 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -81,10 +81,41 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
     if (mmaLayout.isAmpere())
       return {8, 4};
   }
+  if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+    auto parent = sliceLayout.getParent();
+    auto parentThreadsPerWarp = getThreadsPerWarp(parent);
+    SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
+    threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+    for (unsigned i = 0; i < threadsPerWarp.size(); i++)
+      threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
+    return threadsPerWarp;
+  }
   assert(0 && "getThreadsPerWarp not implemented");
   return {};
 }
 
+SmallVector<unsigned>
+getThreadsPerWarpWithUniqueData(Attribute layout,
+                                ArrayRef<int64_t> tensorShape) {
+  if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+    auto parentLayout = sliceLayout.getParent();
+    auto parentShape = sliceLayout.paddedShape(tensorShape);
+    auto parentThreadsPerWarp =
+        getThreadsPerWarpWithUniqueData(parentLayout, parentShape);
+    SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
+    threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+    return threadsPerWarp;
+  }
+  auto threadsPerWarp = getThreadsPerWarp(layout);
+  assert(threadsPerWarp.size() == tensorShape.size() &&
+         "layout and tensor shape must have the same rank");
+  for (unsigned i = 0; i < threadsPerWarp.size(); i++) {
+    threadsPerWarp[i] = std::min<unsigned>(threadsPerWarp[i], tensorShape[i]);
+  }
+
+  return threadsPerWarp;
+}
+
 SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
@@ -94,10 +125,43 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
     return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
                                  mmaLayout.getWarpsPerCTA().end());
   }
+  if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+    auto parent = sliceLayout.getParent();
+    auto parentWarpsPerCTA = getWarpsPerCTA(parent);
+    SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
+    warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
+    for (unsigned i = 0; i < warpsPerCTA.size(); i++)
+      warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
+    return warpsPerCTA;
+  }
   assert(0 && "getWarpsPerCTA not implemented");
   return {};
 }
 
+SmallVector<unsigned>
+getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape) {
+  if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+    auto parentLayout = sliceLayout.getParent();
+    auto parentShape = sliceLayout.paddedShape(tensorShape);
+    auto parentWarpsPerCTA =
+        getWarpsPerCTAWithUniqueData(parentLayout, parentShape);
+    SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
+    warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
+    return warpsPerCTA;
+  }
+  auto warpsPerCTA = getWarpsPerCTA(layout);
+  assert(warpsPerCTA.size() == tensorShape.size() &&
+         "layout and tensor shape must have the same rank");
+  for (unsigned i = 0; i < warpsPerCTA.size(); i++) {
+    auto sizePerWarp =
+        getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i];
+    auto maxWarpsPerDim = ceil<unsigned>(tensorShape[i], sizePerWarp);
+    warpsPerCTA[i] = std::min<unsigned>(warpsPerCTA[i], maxWarpsPerDim);
+  }
+
+  return warpsPerCTA;
+}
+
 SmallVector<unsigned> getSizePerThread(Attribute layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
@@ -189,7 +253,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
       threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
                         blockedLayout.getWarpsPerCTA()[d]);
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-    if (mmaLayout.getVersionMajor() == 2) {
+    if (mmaLayout.isAmpere()) {
       threads = {8 * mmaLayout.getWarpsPerCTA()[0],
                  4 * mmaLayout.getWarpsPerCTA()[1]};
     } else
@@ -1074,9 +1138,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
       return mlir::failure();
     }
     auto newType = op->getResult(0).getType().cast<RankedTensorType>();
-    // Ensure that the new insert_slice op is placed in the same place as the
-    // old insert_slice op. Otherwise, the new insert_slice op may be placed
-    // after the async_wait op, which is not allowed.
+    // Ensure that the new insert_slice op is placed in the same place as
+    // the old insert_slice op. Otherwise, the new insert_slice op may be
+    // placed after the async_wait op, which is not allowed.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(insert_slice);
     auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -1104,9 +1168,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
     auto resType = RankedTensorType::get(
         origResType.getShape(), origResType.getElementType(),
         extract_slice.getType().cast<RankedTensorType>().getEncoding());
-    // Ensure that the new extract_slice op is placed in the same place as the
-    // old extract_slice op. Otherwise, the new extract_slice op may be placed
-    // after the async_wait op, which is not allowed.
+    // Ensure that the new extract_slice op is placed in the same place as
+    // the old extract_slice op. Otherwise, the new extract_slice op may be
+    // placed after the async_wait op, which is not allowed.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(extract_slice);
     auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d0f56d8c12aa..4ea448e450d8 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1420,6 +1420,70 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
     )
 
 
+layouts = [
+    BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
+    BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
+    BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
+]
+
+
+@pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
+@pytest.mark.parametrize("src_layout", layouts)
+def test_reduce_2d(M, N, src_layout, device='cuda'):
+    ir = f"""
+    #src = {src_layout}
+    module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
+    tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
+        %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
+        %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
+        %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
+        %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
+        %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
+        %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
+        %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
+        %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
+        %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
+        %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
+        %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
+        %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
+        %11 = "tt.reduce"(%10) ({{
+        ^bb0(%arg2: i32, %arg3: i32):
+          %13 = arith.addi %arg2, %arg3 : i32
+          tt.reduce.return %13 : i32
+        }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
+        %12 = "tt.reduce"(%11) ({{
+        ^bb0(%arg2: i32, %arg3: i32):
+          %13 = arith.addi %arg2, %arg3 : i32
+          tt.reduce.return %13 : i32
+        }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
+        tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
+        tt.return
+    }}
+    }}
+    """
+    import tempfile
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
+        f.write(ir)
+        f.flush()
+        kernel = triton.compile(f.name)
+
+    rs = RandomState(17)
+    x = rs.randint(0, 4, (M, N)).astype('int32')
+    x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')
+
+    z = np.zeros((1,)).astype('int32')
+
+    x_tri = torch.tensor(x, device=device)
+    z_tri = torch.tensor(z, device=device)
+
+    pgm = kernel[(1, 1, 1)](x_tri, z_tri)
+
+    z_ref = np.sum(x)
+
+    np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
+
+
 def test_generic_reduction(device='cuda'):
 
     @triton.jit
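
For reference, the reduction pattern exercised by the TTGIR test above (a 2-D tensor reduced along axis 1 into a 1-D tensor carried in a slice layout, which is then reduced again) can also be reached from the Python frontend. The sketch below is illustrative only: it is not part of the patch, the kernel and buffer names are made up, and it assumes a Triton build that includes this change.

import torch
import triton
import triton.language as tl


@triton.jit
def _sum_2d_kernel(x_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    # Load the full M x N tile; the loaded tensor uses a 2-D blocked layout.
    x = tl.load(x_ptr + offs_m[:, None] * N + offs_n[None, :])
    # Reducing along axis=1 yields a 1-D tensor whose encoding is a slice of
    # the parent blocked layout -- the case enabled by this patch.
    row_sums = tl.sum(x, axis=1)
    # Reducing the sliced tensor again collapses it to a scalar.
    total = tl.sum(row_sums, axis=0)
    tl.store(out_ptr, total)


x = torch.randint(0, 4, (32, 128), dtype=torch.int32, device='cuda')
out = torch.zeros(1, dtype=torch.int32, device='cuda')
_sum_2d_kernel[(1,)](x, out, M=32, N=128)
assert out.item() == x.sum().item()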