From 2b2032f25d8da6808a9af2d20d4f4671a883e78e Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 22 Nov 2024 13:41:38 +0000 Subject: [PATCH 1/6] Use LLs for Hopper whenever we wouldn't use ldmatrix The legacy path has some bugs for cases like `kWidth=1`. I'm starting to port Hopper to use LLs to try to isolate them. --- .gitignore | 3 ++ .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 4 +- lib/Analysis/Utility.cpp | 3 +- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 10 +++-- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 40 +++++++++++++------ python/test/unit/language/test_core.py | 3 ++ .../ConvertLayoutOpToLLVM.cpp | 3 -- 7 files changed, 44 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 64843d551d5f..da0d23eefc9c 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,9 @@ python/triton/language/extra # Proton python/triton/profiler +# Pytest +pytest.ini + # Instrumentation python/triton/instrumentation diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 9af550aae9c4..fee5e0afe311 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -360,7 +360,7 @@ compared to 1*64 when the hasLeadingOffset is false. int k = (needTrans) ? matShape[0] : matShape[2]; int vec = (order[0] == rank-1) ? k : m; int mmaStride = (order[0] == rank-1) ? m : k; - int maxPhase = mmaStride / perPhase; + int maxPhase = std::max(mmaStride / perPhase, 1); return get(context, vec, perPhase, maxPhase, order, CTALayout); } @@ -373,7 +373,7 @@ compared to 1*64 when the hasLeadingOffset is false. int k = needTrans ? matShape[1] : matShape[2]; int vec = (order[0] == rank-1) ? n : k; int mmaStride = (order[0] == rank-1) ? 
k : n; - int maxPhase = mmaStride / perPhase; + int maxPhase = std::max(mmaStride / perPhase, 1); return get(context, vec, perPhase, maxPhase, order, CTALayout); } diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 5fd87e4c0169..d890674410e4 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -629,7 +629,8 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, dotOperandLayout.getOpIdx() == 0 && mmaLayout.getWarpsPerCTA()[1] == 1 && !cvtNeedsSharedMemory(parentTy, srcTy) && - (elementTypeSize == 16 || elementTypeSize == 8); + (elementTypeSize == 16 || elementTypeSize == 8) && + dotOperandLayout.getKWidth() == 32 / elementTypeSize; return ans; } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index f0026c199324..32a6ea34a019 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -384,10 +384,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (isa(parent) && useLegacyMMAConversion) { return false; } - if (auto nvidiaMma = dyn_cast(parent)) { - if (nvidiaMma.isAmpere()) { - return true; - } + if (isa(parent)) { + return true; } if (isa(parent)) { return true; @@ -408,6 +406,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) { return failure(); } + // FIXME [Dot LL] Remove this once we implement this trick in LLs + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) { + return failure(); + } // The following check can be removed when generalized warp shuffle // conversions are ready: diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index b090670d955c..82448f1efbc6 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -138,18 +138,34 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { // FIXME [Dot LL] // Do for all DotOperandEncodingAttr once we have LLs for all of them - static bool isSupportedDotOpLayout(RankedTensorType type) { - auto layout = type.getEncoding(); - auto bitwidth = type.getElementType().getIntOrFloatBitWidth(); - if (auto dot = dyn_cast(layout)) { + static bool isSupportedDotOpLayout(MemDescType srcTy, + RankedTensorType dstTy) { + auto srcLayout = cast(srcTy.getEncoding()); + auto dstLayout = dstTy.getEncoding(); + auto bitwidth = dstTy.getElementTypeBitWidth(); + auto rank = dstTy.getRank(); + if (auto dot = dyn_cast(dstLayout)) { + auto vecWidth = 32 / bitwidth; auto kWidth = dot.getKWidth(); - // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy: - // - kWidth == 8 - // - kWidth == 4, bitwidth = 32 + auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2; if (auto mma = dyn_cast(dot.getParent())) { + auto needTrans = kOrder != srcLayout.getOrder()[0]; + auto canUseLdmatrix = + (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth); + if (mma.isHopper()) { + // I think we should be able to remove this condition, but it's here + // as the legacy ldmatrix path does not support it + canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32; + } + // If we remove this one, ldmatrix will IMA. 
It can probably be relaxed + // though + canUseLdmatrix &= + srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth; + // To be removed in https://github.com/triton-lang/triton/pull/5154 bool legacyLoweringIsBuggy = - kWidth >= 8 || (kWidth == 4 && bitwidth == 32); - return legacyLoweringIsBuggy && mma.isAmpere(); + (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere(); + return (mma.isHopper() && !canUseLdmatrix) || + (mma.isAmpere() && legacyLoweringIsBuggy); } if (isa(dot.getParent())) return true; @@ -162,12 +178,12 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { MemDescType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); + auto srcLayout = cast(srcTy.getEncoding()); Attribute dstLayout = dstTy.getEncoding(); if (isa(srcLayout) && (isa(dstLayout) || - isSupportedDotOpLayout(dstTy))) { + isSupportedDotOpLayout(srcTy, dstTy))) { return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); } @@ -206,7 +222,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstTy = op.getResult().getType(); auto dstShape = dstTy.getShape(); auto srcSharedLayout = cast(srcTy.getEncoding()); - assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && + assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 23c598ee168c..3f5aeab8d101 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5250,6 +5250,9 @@ def kernel(Out): # TODO: backend should be tested separately layouts = [ + MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1), BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 1324511aeb89..763a38058c13 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -459,9 +459,6 @@ void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { // For now give ConvertLayoutOpConversion higher benefit, I can split before // merging - // - // TODO(jlebar): lowerDistributedToDistributed does not get hit in any - // testcases. Is this dead code? Does the benefit need to be increased? 
patterns.add(typeConverter, targetInfo, benefit); // Same default benefit patterns.add(typeConverter, benefit); From cbf00e2521e7d90f3d4ac036f0c50935b4c21fab Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 25 Nov 2024 12:35:45 +0000 Subject: [PATCH 2/6] Relax vec constraint to accomodate for mixed types matmul --- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 7 +------ test/Conversion/tritongpu_to_llvm_hopper.mlir | 3 ++- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index aee7da8a7579..3dfc28202e4e 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -241,12 +241,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, llvm::report_fatal_error("Illegal shared layout"); } - int vec = 8 * 16 / elemBitWidth; - if (vec != shared.getVec()) { - llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec - << ": " << shared << "\n"; - llvm::report_fatal_error("Illegal shared layout"); - } + int vec = shared.getVec(); StringAttr colDimName = outDimNames[colDim]; StringAttr rowDimName = outDimNames[rowDim]; diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 1f35d8fdd68b..83eacfa84378 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -121,13 +121,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } -// + // ----- #blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: dot_reg_operand_upcast tt.func @dot_reg_operand_upcast(%a_desc: !triton_gpu.memdesc<128x64xi8, #shared>, %b: !triton_gpu.memdesc<64x64xf16, #shared>, %acc: tensor<128x64xf32, #mma>) { %a_dotop = triton_gpu.local_load %a_desc : !triton_gpu.memdesc<128x64xi8, #shared> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> From baf6783701c2f6ddd693a1e7e9b8f964c4fe3646 Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 25 Nov 2024 16:17:38 +0000 Subject: [PATCH 3/6] Filter hopper in Ampere targets --- python/test/unit/language/test_core.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3f5aeab8d101..373d2e8e5709 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -28,6 +28,7 @@ dtypes_with_bfloat16, is_cuda, is_interpreter, + is_hopper, is_hip, is_hip_cdna, is_hip_mi200, @@ -195,7 +196,12 @@ def is_layout_applicable(layout) -> bool: if layout in common_layouts: return True elif is_cuda(): - return isinstance(layout, MmaLayout) + mma_layout = 
layout.parent if isinstance(layout, DotOperandLayout) else layout + if not isinstance(mma_layout, MmaLayout): + return False + if mma_layout.version[0] >= 3 and not is_hopper(): + return False + return True elif is_hip(): target_arch = triton.runtime.driver.active.get_current_target().arch if "gfx11" in target_arch: @@ -5300,9 +5306,9 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): @pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) @pytest.mark.parametrize("dtype", ['float16']) -@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) @pytest.mark.parametrize("interm_layout", intermediate_layouts) -@pytest.mark.parametrize("dst_layout", layouts) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): if str(src_layout) == str(dst_layout): pytest.skip() From 91e46b3edd8ed149288b8aa28d8c98173c3c57ad Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 25 Nov 2024 19:37:28 +0000 Subject: [PATCH 4/6] Simplify layoutIsOK --- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 32a6ea34a019..3f22e91d6e1a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -376,26 +376,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // completed before we can remove the layoutIsOK check: // 1. Support for AMD's WMMA std::function layoutIsOK = [&](Attribute layout) { - if (isa(layout)) { - return !useLegacyMMAConversion; - } if (auto dotOperand = dyn_cast(layout)) { - auto parent = dotOperand.getParent(); - if (isa(parent) && useLegacyMMAConversion) { - return false; - } - if (isa(parent)) { - return true; - } - if (isa(parent)) { - return true; - } - return false; + layout = layout.getParent(); } - if (isa(layout)) { - return true; + + if (isa(layout)) { + return !useLegacyMMAConversion; } - if (isa(layout)) { + if (isa(layout)) { return true; } if (auto slice = dyn_cast(layout)) { From 68c8513894fc7b4094e37db3d61f99b68286f69a Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 25 Nov 2024 19:41:58 +0000 Subject: [PATCH 5/6] tiny --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 3f22e91d6e1a..b52aee34a431 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -377,7 +377,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // 1. 
Support for AMD's WMMA std::function layoutIsOK = [&](Attribute layout) { if (auto dotOperand = dyn_cast(layout)) { - layout = layout.getParent(); + layout = dotOperand.getParent(); } if (isa(layout)) { From de18e21ddf5bf03f17f779fef032d53ea87a53a0 Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 26 Nov 2024 13:27:18 +0000 Subject: [PATCH 6/6] Address reviews --- .../TritonGPUToLLVM/TargetInfoBase.h | 7 +++ .../TritonGPU/IR/LinearLayoutConversions.h | 10 ++-- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 33 ++++++------ .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 8 ++- .../TritonGPU/IR/LinearLayoutConversions.cpp | 51 +++---------------- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 9 ++++ .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 5 ++ .../ConvertLayoutOpToLLVM.cpp | 16 +++--- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 41 ++++++++++++++- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 6 +++ 10 files changed, 107 insertions(+), 79 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 68f430d054e2..87db94f25f1d 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -4,6 +4,7 @@ #include "triton/Conversion/MLIRTypes.h" namespace mlir::triton { + class TargetInfoBase { public: virtual bool supportMaximumMinimum() const = 0; @@ -37,6 +38,12 @@ class TargetInfoBase { pred); } + virtual bool canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const = 0; + virtual void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const = 0; diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index 5140a03e78d7..7c81b2496cdf 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -241,11 +241,11 @@ LinearLayout chooseShemLayoutForRegToRegConversion( // TODO(Keren): We should replace tensorTy with a LinearLayout and the element // bit width of the tensor in the future to support more flexible tensor // encodings -std::optional -chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, - ArrayRef repShape, - ArrayRef paddedRepShape, - ArrayRef order, int swizzleByteSize); +LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize); } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index b52aee34a431..74a68939b72a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -494,19 +494,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // don't need to avoid duplicate writes. 
// Input dims: [reg, lane, warp] // Output dims: [offset, iteration] - std::optional shmemStoreLayout = - chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape, - scratchConfig.paddedRepShape, scratchConfig.order, - /*swizzleByteSize=*/0); - bool isStMatrix = shmemStoreLayout.has_value(); - if (!isStMatrix) { - shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout); - } - assert(shmemStoreLayout.has_value()); + bool isStMatrix = targetInfo.canUseStMatrix( + op.getSrc().getType(), scratchConfig.repShape, + scratchConfig.paddedRepShape, scratchConfig.order, + /*swizzleByteSize=*/0); + LinearLayout shmemStoreLayout = + isStMatrix ? chooseStMatrixLayout( + ctx, op.getSrc().getType(), scratchConfig.repShape, + scratchConfig.paddedRepShape, scratchConfig.order, + /*swizzleByteSize=*/0) + : srcLayout.invertAndCompose(sharedLayout); const int shmemAllocatedNumElems = getNumScratchElements(scratchConfig.paddedRepShape); - assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems); + assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems); // Layout for the load from shmem to registers. LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout); @@ -514,14 +515,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Check that the `register` fully determines the `iteration`. That is, // each thread does exactly the same reads and writes to shmem on each // iteration, just with different input/output registers. - assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock}, - {kIteration})); + assert( + shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); assert( shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); // iteration -> registers SmallVector> inRegsForIter = - collectRegsForIter(ctx, *shmemStoreLayout); + collectRegsForIter(ctx, shmemStoreLayout); SmallVector> outRegsForIter = collectRegsForIter(ctx, shmemLoadLayout); @@ -578,7 +579,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return vecAddr; }; - auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout, + auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout, {{kRegister, i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, @@ -601,11 +602,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // When using `stmatrix`, we can store `inVec` elements even if they are // not contiguous - auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut() + auto inVec = isStMatrix ? 
shmemStoreLayout.getNumConsecutiveInOut() : scratchConfig.inVec; for (int j = 0; j < inVals.size() / iterations; j += inVec) { auto inRegSlice = inRegs[j]; - Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice); + Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice); SmallVector inValsVec; for (int k = 0; k < inVec; k++) inValsVec.push_back(inVals[inRegSlice + k]); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 82448f1efbc6..27fd26800b12 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -178,12 +178,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { MemDescType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); - auto srcLayout = cast(srcTy.getEncoding()); Attribute dstLayout = dstTy.getEncoding(); - if (isa(srcLayout) && - (isa(dstLayout) || - isSupportedDotOpLayout(srcTy, dstTy))) { + if (isa(dstLayout) || + isSupportedDotOpLayout(srcTy, dstTy)) { return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 3dfc28202e4e..742c4f4460b4 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -853,40 +853,7 @@ LinearLayout chooseShemLayoutForRegToRegConversion( } namespace { - -// TODO (Keren): Currently, we have more restrictions than necessary when using -// stmatrix. These restrictions are retained from legacy code, and we could -// relax some of them in the future. -bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, - ArrayRef paddedRepShape, ArrayRef order, - int swizzleByteSize) { - auto mmaLayout = - mlir::dyn_cast(tensorTy.getEncoding()); - if (!mmaLayout || !mmaLayout.isHopper()) - return false; - if (isa(tensorTy.getElementType())) - return false; - if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) - return false; - if (order[0] != 1) - return false; - - auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape()); - if (tensorShapePerCTA.size() != 2) - return false; - auto numIterations = ceil(tensorShapePerCTA[1], repShape[1]) * - ceil(tensorShapePerCTA[0], repShape[0]); - if (numIterations > 1) - return false; - if (paddedRepShape[1] % 8 != 0) - return false; - if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 && - swizzleByteSize != 128) - return false; - return true; -} - -std::optional chooseStMatrixLayoutLeadingOffset( +LinearLayout chooseStMatrixLayoutLeadingOffset( MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, ArrayRef paddedRepShape, ArrayRef order, int swizzleByteSize) { @@ -957,7 +924,7 @@ std::optional chooseStMatrixLayoutLeadingOffset( .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}}); } -std::optional chooseStMatrixLayoutNoLeadingOffset( +LinearLayout chooseStMatrixLayoutNoLeadingOffset( MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, ArrayRef paddedRepShape, ArrayRef order) { StringAttr kReg = S("register"); @@ -997,15 +964,11 @@ std::optional chooseStMatrixLayoutNoLeadingOffset( } // anonymous namespace -std::optional -chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, - ArrayRef repShape, - ArrayRef paddedRepShape, - ArrayRef order, int swizzleByteSize) { - 
if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order, - swizzleByteSize)) - return std::nullopt; - +LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) { if (swizzleByteSize == 0) return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape, paddedRepShape, order); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 525361fee603..cca1714f6581 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -85,6 +85,15 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, mlir::LLVM::AMD::llStore(rewriter, loc, ptr, val, pred); } +bool TargetInfo::canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const { + // AMD does not support stmatrix + return false; +} + void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const { llvm::report_fatal_error("AMDGPU does not support stmatrix"); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 0ce38d4d7660..30a9f7306b98 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -27,6 +27,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Type elemTy, Value pred) const override; + + bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const override; void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const override; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 763a38058c13..d4613fef4321 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -380,11 +380,12 @@ struct LocalAllocOpConversion SmallVector shape = convertType(srcTy.getShape()); auto order = sharedLayout.getOrder(); + if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, + swizzleByteSize)) { + return failure(); + } auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, shape, order, swizzleByteSize); - if (!layout.has_value()) - return failure(); - Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); auto smemPtrTy = ptr_ty(ctx, 3); @@ -394,23 +395,22 @@ struct LocalAllocOpConversion auto kBlock = str_attr("block"); Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(layout->getInDimSize(kLane)); + Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); Value laneId = urem(threadId, threadsPerWarp); Value warpId = udiv(threadId, threadsPerWarp); - auto regBase = applyLinearLayout(loc, rewriter, *layout, + auto regBase = applyLinearLayout(loc, rewriter, layout, {{kRegister, i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, {kBlock, i32_val(0)}})[0] .second; auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - auto srcVec = layout->getNumConsecutiveInOut(); + auto srcVec = layout.getNumConsecutiveInOut(); Type llvmElemTy = 
typeConverter->convertType(srcTy.getElementType()); for (int i = 0; i < srcVals.size(); i += srcVec) { auto regIdx = - layout - ->apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] + layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] .second; Value offset = xor_(regBase, i32_val(regIdx)); auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 76b565365406..7c4a9e5b92df 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -5,7 +5,6 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "llvm/Support/MathExtras.h" using namespace mlir; @@ -468,6 +467,46 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, return false; } +// TODO (Keren): Currently, we have more restrictions than necessary when using +// stmatrix. These restrictions are retained from legacy code, and we could +// relax some of them in the future. +// TODO (Lezcano): The proper way of doing this is to directly try to fit the +// relevant layout and return an std::optional. I'm keeping this +// split to keep the current PR smaller +bool TargetInfo::canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const { + if (computeCapability < 90) { + return false; + } + auto mmaLayout = + mlir::dyn_cast(tensorTy.getEncoding()); + if (!mmaLayout || !mmaLayout.isHopper()) + return false; + if (isa(tensorTy.getElementType())) + return false; + if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) + return false; + if (order[0] != 1) + return false; + + auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape()); + if (tensorShapePerCTA.size() != 2) + return false; + auto numIterations = ceil(tensorShapePerCTA[1], repShape[1]) * + ceil(tensorShapePerCTA[0], repShape[0]); + if (numIterations > 1) + return false; + if (paddedRepShape[1] % 8 != 0) + return false; + if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 && + swizzleByteSize != 128) + return false; + return true; +} + void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const { auto vals = unpackLLVector(loc, val, rewriter); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index ed9bd91a8d34..eedab90c98e3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -23,6 +23,12 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Type elemTy, Value pred) const override; + + bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const override; + void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const override;
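
As a rough illustration of the eligibility check that PATCH 1/6 adds to LocalLoadOpConversion::isSupportedDotOpLayout in MemoryOpToLLVM.cpp, below is a minimal Python sketch of the ldmatrix conditions, appended here only as a reading aid and not part of the patches. The standalone helper name, argument list, and the concrete example values are hypothetical; the thresholds mirror the C++ in the hunk above.

def can_use_ldmatrix(bitwidth, k_width, k_order, shared_order0,
                     src_elem_bitwidth, src_shape, is_hopper):
    """Hypothetical standalone mirror of the ldmatrix eligibility conditions
    from the MemoryOpToLLVM.cpp hunk above (the real check lives inside
    LocalLoadOpConversion and operates on MLIR types)."""
    vec_width = 32 // bitwidth
    need_trans = k_order != shared_order0
    # ldmatrix needs 16-bit elements unless no transpose is required, and the
    # dot-operand kWidth must match the 32-bits-per-thread vector width.
    ok = (bitwidth == 16 or not need_trans) and k_width == vec_width
    if is_hopper:
        # The legacy ldmatrix path only handles 32 bits per thread here.
        ok = ok and src_elem_bitwidth * k_width == 32
    # Tiles smaller than one ldmatrix fragment would read out of bounds (IMA).
    ok = ok and src_shape[0] >= 8 and src_shape[1] >= 4 * k_width
    return ok

# Hypothetical example: fp16 operand A with kWidth=2 on Hopper, 128x64 tile.
print(can_use_ldmatrix(bitwidth=16, k_width=2, k_order=1, shared_order0=1,
                       src_elem_bitwidth=16, src_shape=(128, 64),
                       is_hopper=True))  # True: keeps the ldmatrix path

When this returns False on Hopper (or when the legacy lowering is known to be buggy on Ampere), the patch routes the local_load through the linear-layout path instead.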