From e275dcceb842b24602d72fde52a606814bfebbf5 Mon Sep 17 00:00:00 2001 From: Jingning Tang Date: Thu, 14 Nov 2024 19:07:19 +0000 Subject: [PATCH] introduce AMD inThreadTranspose for K-major dot operand --- bin/RegisterTritonDialects.h | 1 + .../TritonGPU/IR/LinearLayoutConversions.h | 8 +- .../TritonGPU/IR/LinearLayoutConversions.cpp | 99 +++++++- test/TritonGPU/amd/in-thread-transpose.mlir | 54 +++++ .../include/TritonAMDGPUTransforms/Passes.h | 2 + .../include/TritonAMDGPUTransforms/Passes.td | 15 ++ .../lib/TritonAMDGPUTransforms/CMakeLists.txt | 1 + .../inThreadTranspose.cpp | 211 ++++++++++++++++++ third_party/amd/python/triton_amd.cc | 2 + .../TritonGPU/LinearLayoutConversionsTest.cpp | 33 +++ 10 files changed, 424 insertions(+), 2 deletions(-) create mode 100644 test/TritonGPU/amd/in-thread-transpose.mlir create mode 100644 third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index c69e46792c3d..35da847a41e3 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -66,6 +66,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUConvertToBufferOps(); mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + mlir::registerTritonAMDGPUInThreadTranspose(); // TODO: register Triton & TritonGPU passes registry.insert toLinearLayout(ArrayRef shape, Attribute layout, - std::optional elemBitWidth = std::nullopt); + std::optional elemBitWidth = std::nullopt, + bool inThreadTranspose = false); // Given a linear layout where the input dimensions contain a "block" dimension, // this method sets the "block" dimension to 0 and removes the corresponding diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index aee7da8a7579..9fe08316830c 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -73,6 +73,60 @@ LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, return ret; } +// For sizePerThread = [4, 8], the regular linear layout will express it as +// the following +// - register=1 -> (1, 0) +// register=2 -> (2, 0) +// register=4 -> (4, 0) +// register=8 -> (0, 1) +// register=16 -> (0, 2) +// where out dims are: [dim1 (size 8), dim0 (size 4)] +// If we take the binary form, it will be an identity matrix. If we traverse +// from the dim of 4, it will be like the following +// - register=1 -> (0, 1) +// register=2 -> (0, 2) +// register=4 -> (1, 0) +// register=8 -> (2, 0) +// register=16 -> (4, 0) +// where out dims are: [dim1 (size 8), dim0 (size 4)] +// Inside the function we only change the register layout generation, so +// register layout is created by newly introduced transpose2D and the rest still +// comes from identityStandardND. +// Note that simply reversing the for-loop identityStandardND will not work +// because it will change the most minor dimension from dim1 to dim0, and still +// keep it as an identity matrix. 
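To make the two traversals above concrete, a minimal standalone sketch (illustration only, not part of this patch): it enumerates which (dim0, dim1) element each register index addresses for sizePerThread = [4, 8]. The identity traversal walks the dim of 8 first, while the transposed traversal produced by transpose2D below walks the dim of 4 first, so registers 0..3 cover a column that is contiguous in dim0.

// Editor-added illustration, not part of the patch.
#include <cstdio>

int main() {
  const int size0 = 4, size1 = 8; // sizePerThread = [4, 8], order = [1, 0]
  for (int reg = 0; reg < size0 * size1; ++reg) {
    // Identity traversal: the fastest-varying index is dim1 (the leading dim).
    int id0 = reg / size1, id1 = reg % size1;
    // Transposed traversal (what transpose2D encodes): dim0 varies fastest,
    // so 4 consecutive registers cover 4 rows of the same column.
    int tr0 = reg % size0, tr1 = reg / size0;
    std::printf("reg=%2d identity=(%d,%d) transposed=(%d,%d)\n", reg, id0, id1,
                tr0, tr1);
  }
  return 0;
}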
+LinearLayout transpose2D(StringAttr inDimName, ArrayRef<unsigned> shape,
+                         ArrayRef<unsigned> order) {
+  assert(shape.size() == order.size());
+  assert((order.size() == 2) && "only support dim of 2 now");
+
+  MLIRContext *ctx = inDimName.getContext();
+  StringAttr kRegister = S("register");
+
+  std::vector<std::vector<int32_t>> bases;
+  // traverse the 2nd dimension (the K dim in the GEMM case)
+  int dim = order[1];
+  for (int basis = 1; basis < shape[dim]; basis <<= 1) {
+    bases.push_back({0, basis});
+  }
+  // traverse the 1st dimension (the N dim of a non-KContig B tensor in GEMM);
+  // this is the dimension loaded contiguously from global memory
+  dim = order[0];
+  for (int basis = 1; basis < shape[dim]; basis <<= 1) {
+    bases.push_back({basis, 0});
+  }
+
+  auto dimMinor = "dim" + std::to_string(order[0]);
+  auto dimMajor = "dim" + std::to_string(order[1]);
+  StringAttr kDimMinor = S(dimMinor);
+  StringAttr kDimMajor = S(dimMajor);
+  auto ret = LinearLayout(
+      {{kRegister, bases}},
+      {{kDimMinor, shape[order[0]]}, {kDimMajor, shape[order[1]]}}, false);
+
+  return ret;
+}
+
 // Make a LinearLayout that maps a block-id to an N-dimensional index.
 //
 // The tensor is split up into CTAsPerCGA pieces, which are distributed among
@@ -274,6 +328,45 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
 }
 
+// This function converts a BlockedEncodingAttr to a linear layout in a special
+// way. It accompanies the AMDGPUInThreadTranspose pass, which transposes a
+// non-KContig tensor into a KContig one prior to writing it into LDS (shared
+// memory). The conversion treats sizePerThread as a 2D matrix and uses a
+// different access pattern.
+//
+// For example, consider the following blocked layout generated by
+// AMDGPUInThreadTranspose: #blocked1 = #triton_gpu.blocked<{sizePerThread =
+// [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>.
+// Since sizePerThread is 2D, there are two ways to traverse it: along the dim
+// of 8 or along the dim of 4. The regular toLinearLayout() would walk it in
+// the leading order, i.e. the dim of 8, but since we want to transpose it
+// in-thread, we iterate over the 2nd order, i.e. the dim of 4, so that the 4
+// elements are packed into a single vector and the AMD backend LLVM compiler
+// places them in consecutive VGPRs, writing data that is contiguous in the K
+// dimension into LDS. This guarantees vectorized ds_read, while ds_write can
+// be vectorized to 64 bit or 32 bit depending on the block size and the
+// number of warps.
+//
+// The function is named ThreadRake because each thread rakes through multiple
+// rows at the same time, as opposed to each warp raking through a cluster of
+// rows, or the Triton way, which iterates through every available warp and
+// then tiles it over the entire block.
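To put numbers on the ds_write claim above, a minimal sketch (illustration only, not part of this patch) that applies the same min(load-iterations, 64 / bitwidth) heuristic used by getThreadRakedBlockedEnc later in this patch to the first lit test below:

// Editor-added illustration, not part of the patch.
#include <algorithm>
#include <cstdio>

int main() {
  // Numbers from the first lit test below: a 64x256 fp16 B operand with
  // #blocked1 = sizePerThread=[1,8], threadsPerWarp=[2,32], warpsPerCTA=[8,1].
  const int numElems = 64 * 256;
  const int numThreads = (2 * 32) * 8; // threadsPerWarp * warpsPerCTA
  const int elemsPerIter = 1 * 8;      // product(sizePerThread)
  const int bitwidth = 16;             // fp16
  int numIters = numElems / numThreads / elemsPerIter; // 32 / 8 = 4
  int kExtent = std::min(numIters, 64 / bitwidth);     // min(4, 4) = 4
  // New sizePerThread becomes [4, 8]; 4 K-contiguous fp16 values per thread
  // allow a 64-bit LDS write (ds_write_b64).
  std::printf("new sizePerThread = [%d, 8], ds_write_b%d\n", kExtent,
              kExtent * bitwidth);
  return 0;
}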
+LinearLayout blockedToLinearLayoutThreadRake(ArrayRef shape, + BlockedEncodingAttr blocked) { + MLIRContext *ctx = blocked.getContext(); + int rank = shape.size(); + auto outDimNames = standardOutDimNames(ctx, rank); + const auto &order = blocked.getOrder(); + auto sizePerThread = blocked.getSizePerThread(); + + auto ctaLayout = + transpose2D(S("register"), sizePerThread, order) * + identityStandardND(S("lane"), blocked.getThreadsPerWarp(), order) * + identityStandardND(S("warp"), blocked.getWarpsPerCTA(), order); + + return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape); +} + } // anonymous namespace std::optional @@ -790,9 +883,13 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { std::optional toLinearLayout(ArrayRef shape, Attribute layout, - std::optional elemBitWidth /*= std::nullopt*/) { + std::optional elemBitWidth /*= std::nullopt*/, + bool inThreadTranspose /*= false*/) { // Layouts are distributed or shared if (auto distributed = dyn_cast(layout)) { + auto blocked = dyn_cast(distributed); + if (blocked && inThreadTranspose) + return blockedToLinearLayoutThreadRake(shape, blocked); return distributed.toLinearLayout(shape); } else if (auto shared = dyn_cast(layout)) { if (shared.getHasLeadingOffset()) { diff --git a/test/TritonGPU/amd/in-thread-transpose.mlir b/test/TritonGPU/amd/in-thread-transpose.mlir new file mode 100644 index 000000000000..747bfd6477f6 --- /dev/null +++ b/test/TritonGPU/amd/in-thread-transpose.mlir @@ -0,0 +1,54 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-in-thread-transpose | FileCheck %s + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + +// CHECK: [[threadrake_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x256x!tt.ptr, [[threadrake_layout]]> +// CHECK: {{.*}} = tt.load [[load_ptr]] : tensor<64x256x!tt.ptr, [[threadrake_layout]]> + tt.func public @threadRake_transpose_b(%arg0: tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr, #blocked1>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %1 = tt.load %arg1 : tensor<64x256x!tt.ptr, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> + tt.return + } +} + +// ----- + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + +// CHECK: [[threadrake_layout:#.*]] = 
#triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<32x128x!tt.ptr, [[threadrake_layout]]> +// CHECK: {{.*}} = tt.load [[load_ptr]] : tensor<32x128x!tt.ptr, [[threadrake_layout]]> + tt.func public @threadRake_transpose_b_no_change(%arg0: tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x128x!tt.ptr, #blocked1>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %1 = tt.load %arg1 : tensor<32x128x!tt.ptr, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<32x128xf16, #blocked1> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma> + tt.return + } +} + + +// ----- +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + +// CHECK-NOT: {{.*}} = triton_gpu.convert_layout {{.*blocked.*}} -> {{.*blocked.*}} + tt.func public @threadRake_no_transpose(%arg0: tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr, #blocked1>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %1 = tt.load %arg1 : tensor<64x256x!tt.ptr, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> + tt.return + } +} diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 636743d305f9..066889077c4d 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -27,6 +27,8 @@ std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass(); +std::unique_ptr createTritonAMDGPUInThreadTransposePass(); + /// Generate the code for registering passes. 
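As a usage sketch (illustration only, not part of this patch), the factory declared above can also be scheduled programmatically; this mirrors the triton-opt -tritonamdgpu-in-thread-transpose invocation in the lit test above. The helper name and the exact include set are assumptions.

// Editor-added usage sketch, not part of the patch.
#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper: run only the new pass on a TritonGPU module.
static mlir::LogicalResult runInThreadTranspose(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createTritonAMDGPUInThreadTransposePass());
  return pm.run(module);
}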
 #define GEN_PASS_REGISTRATION
 #include "TritonAMDGPUTransforms/Passes.h.inc"
diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
index 85604dcaca18..8a0cfdb00690 100644
--- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
+++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -124,4 +124,19 @@ def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> {
   let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];
 }
 
+def TritonAMDGPUInThreadTranspose: Pass<"tritonamdgpu-in-thread-transpose", "mlir::ModuleOp"> {
+  let summary = "Transpose the K-outer dot operand while its data is still in registers, right before writing to LDS";
+
+  let description = [{
+    Transpose a non-KContig dot operand (one that is not contiguous along the K dimension) right before its data
+    is written into LDS. The transpose happens right after the data has been loaded from global memory into
+    thread-local registers; it promotes (but does not guarantee) vectorized LDS writes, while SharedEncodingAttr
+    guarantees vectorized LDS reads, at the cost of a few VALU instructions that perform the in-thread transpose.
+  }];
+
+  let constructor = "mlir::createTritonAMDGPUInThreadTransposePass()";
+
+  let dependentDialects = [];
+}
+
 #endif
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
index c3a69a5f9a2a..e1c49bf58659 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_triton_library(TritonAMDGPUTransforms
   ReorderInstructions.cpp
   StreamPipelineV2.cpp
   MfmaGroup.cpp
+  inThreadTranspose.cpp
 
   DEPENDS
   TritonAMDGPUIR
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp
new file mode 100644
index 000000000000..f8aa60e6a5d1
--- /dev/null
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp
@@ -0,0 +1,211 @@
+#include "TritonAMDGPUTransforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tritonamdgpu-in-thread-transpose"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+#define GEN_PASS_CLASSES
+#include "TritonAMDGPUTransforms/Passes.h.inc"
+
+using namespace mlir;
+namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
+
+static Type getNewType(Type type, Attribute encoding) {
+  RankedTensorType tensorType = dyn_cast<RankedTensorType>(type);
+  return RankedTensorType::get(tensorType.getShape(),
+                               tensorType.getElementType(), encoding);
+}
+
+// This function is mostly copied over from coalesce.cpp since it uses almost
+// the same functionality.
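Before the helpers below, a minimal sketch (illustration only, not part of this patch) of the K-contiguity check that decides whether the rewrite fires, mirroring needCvtToThreadRaked further down and the three lit tests above; the standalone function name is assumed for illustration.

// Editor-added illustration, not part of the patch.
#include <array>
#include <cstdio>

// Mirrors the check in needCvtToThreadRaked: for dot operand opIdx (0 = A,
// 1 = B), K is dim 1 resp. dim 0; the rewrite is only needed when the blocked
// layout's fastest dimension (order[0]) is not the K dimension.
static bool needsInThreadTranspose(int opIdx, std::array<int, 2> order) {
  int kDim = (opIdx == 0) ? 1 : 0;
  return order[0] != kDim;
}

int main() {
  // B operand with order = [1, 0]: K (dim 0) is not contiguous -> transpose.
  std::printf("%d\n", needsInThreadTranspose(1, {1, 0})); // prints 1
  // B operand with order = [0, 1]: K is already contiguous -> left unchanged.
  std::printf("%d\n", needsInThreadTranspose(1, {0, 1})); // prints 0
  return 0;
}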
+void convertLayout(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands + // For load with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + newTypes.push_back(getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(), + newArgs, newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); +} + +SmallVector getLoadInsts(Operation *op) { + SmallVector ret; + auto v = op->getOperand(0); + auto prevOp = v.getDefiningOp(); + if (isa(prevOp)) { + // Deal with the case that convert_layout intakes from scf.if, etc. + LDBG("Dealing with scf blocks"); + auto idx = cast(v).getResultNumber(); + llvm::SmallVector yieldOps; + prevOp->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) { + yieldOps.push_back(yieldOp); + } + }); + + for (auto yieldOp : yieldOps) { + auto maybeLoadOp = yieldOp.getOperand(idx).getDefiningOp(); + if (isa(maybeLoadOp)) + ret.push_back(maybeLoadOp); + } + } else if (isa(prevOp)) { + // regular case + LDBG("Regular cases"); + ret.push_back(prevOp); + } else { + // can't find any loadOp + LDBG("we assume load->convert_layout->dot chain but we cannot find it."); + } + return ret; +} + +bool needCvtToThreadRaked(Value operand) { + auto opTensorTy = cast(operand.getType()); + auto opEnc = opTensorTy.getEncoding(); + auto opDotOpEnc = dyn_cast(opEnc); + // dotOperand has to have dotOp and MFMA encoding + if (!opDotOpEnc) + return false; + if (!isa(opDotOpEnc.getParent())) { + LDBG("Operand's parent encoding is not MFMA"); + return false; + } + auto cvtOp = operand.getDefiningOp(); + // make sure the previous op is convert_layout + if (!cvtOp || !isa(cvtOp)) + return false; + auto cvtOperand = cvtOp->getOperand(0); + auto cvtOperandEnc = + cast(cvtOperand.getType()).getEncoding(); + auto blockedEnc = dyn_cast(cvtOperandEnc); + // make sure it is converted from blocked layout + if (!blockedEnc) + return false; + // check whether it's contiguous on K dimension + int kDimNum = opDotOpEnc.getOpIdx() == 0 ? 1 : 0; + auto order = blockedEnc.getOrder(); + if (order[0] != kDimNum) { + return true; + } + + return false; +} + +ttg::BlockedEncodingAttr getThreadRakedBlockedEnc(Value operand, + ModuleOp &mod) { + // get the K dim according to dotOp operand's index + auto tensorTy = cast(operand.getType()); + auto shape = tensorTy.getShape(); + auto opEnc = tensorTy.getEncoding(); + auto opDotOpEnc = dyn_cast(opEnc); + int kDimNum = opDotOpEnc.getOpIdx() == 0 ? 
1 : 0;
+  // get the current blocked encoding
+  auto cvtOperand = operand.getDefiningOp()->getOperand(0);
+  auto cvtOperandEnc =
+      cast<RankedTensorType>(cvtOperand.getType()).getEncoding();
+  auto blockedEnc = dyn_cast<ttg::BlockedEncodingAttr>(cvtOperandEnc);
+  // compute the sizePerThread for the new encoding
+  auto sizePerThread = blockedEnc.getSizePerThread();
+  auto elemsPerIter = product(sizePerThread);
+  auto elemsTotal = blockedEnc.getTotalElemsPerThread(shape, tensorTy);
+  // we need to know how many iterations each thread will load
+  LDBG("elemsPerIter = " << elemsPerIter << "; elemsTotal = " << elemsTotal);
+  auto numMaxIters = elemsTotal / elemsPerIter;
+  auto bitwidth = tensorTy.getElementType().getIntOrFloatBitWidth();
+  // Currently the widest LDS write is set to ds_write_b64
+  auto newKOuterDim = std::min(numMaxIters, 64 / bitwidth);
+  LDBG("Choose the minimum of numIters: " << numMaxIters << " and numDtype: "
+                                          << 64 / bitwidth);
+  SmallVector newSizePerThread(sizePerThread);
+  newSizePerThread[kDimNum] = newKOuterDim;
+
+  // return the new blocked encoding
+  auto order = blockedEnc.getOrder();
+  int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+  int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
+  int numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod);
+  return ttg::BlockedEncodingAttr::get(mod.getContext(), shape,
+                                       newSizePerThread, order, numWarps,
+                                       threadsPerWarp, numCTAs);
+}
+
+class TritonAMDGPUInThreadTransposePass
+    : public TritonAMDGPUInThreadTransposeBase<
+          TritonAMDGPUInThreadTransposePass> {
+
+public:
+  TritonAMDGPUInThreadTransposePass() = default;
+
+  void runOnOperation() override {
+    ModuleOp m = getOperation();
+
+    m.walk([&](Operation *op) {
+      auto dotOp = dyn_cast<tt::DotOp>(op);
+      if (!dotOp)
+        return;
+
+      LDBG("DotOp under inspection: " << dotOp);
+      auto mod = op->getParentOfType<ModuleOp>();
+
+      // helper function
+      auto cvtNonKContigDotOperand = [&](Value op) {
+        if (needCvtToThreadRaked(op)) {
+          auto loadOps = getLoadInsts(op.getDefiningOp());
+          // bail out when we cannot find the associated loadOp
+          if (!loadOps.size())
+            return;
+          auto newBlockedEnc = getThreadRakedBlockedEnc(op, mod);
+          LDBG("newBlockedEnc = " << newBlockedEnc);
+          for (auto loadOp : loadOps)
+            convertLayout(newBlockedEnc, (Operation *)loadOp);
+        }
+      };
+
+      cvtNonKContigDotOperand(dotOp.getA());
+      cvtNonKContigDotOperand(dotOp.getB());
+    });
+  }
+};
+
+std::unique_ptr<Pass> mlir::createTritonAMDGPUInThreadTransposePass() {
+  return std::make_unique<TritonAMDGPUInThreadTransposePass>();
+}
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index a9bd3e9b7fb7..1ea5f1734f88 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -74,6 +74,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
                      mlir::createTritonAMDGPUReorderInstructionsPass);
   ADD_PASS_WRAPPER_2("add_stream_pipelinev2",
                      mlir::createTritonAMDGPUStreamPipelineV2Pass, int, int);
+  ADD_PASS_WRAPPER_0("add_in_thread_transpose",
+                     mlir::createTritonAMDGPUInThreadTransposePass);
 }
 
 void addControlConstant(llvm::Module *module, const char *name,
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index af6242b59662..59a3b1959ea6 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -307,6 +307,39 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) {
                          {S("dim0"), S("dim1"), S("dim2"), S("dim3")}));
 }
 
+TEST_F(LinearLayoutConversionsTest, inThreadTranspose_4x8) {
+  auto ll =
toLinearLayout( + {64, 256}, + blocked({4, 8}, {2, 32}, {8, 1}, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + std::nullopt, true); + EXPECT_EQ(ll, + LinearLayout( + { + {S("register"), {{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), + {{0, 8}, {0, 16}, {0, 32}, {0, 64}, {0, 128}, {4, 0}}}, + {S("warp"), {{8, 0}, {16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, inThreadTranspose_1x8) { + auto ll = toLinearLayout( + {32, 128}, + blocked({1, 8}, {4, 16}, {8, 1}, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + std::nullopt, true); + EXPECT_EQ(ll, LinearLayout( + { + {S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), + {{0, 8}, {0, 16}, {0, 32}, {0, 64}, {1, 0}, {2, 0}}}, + {S("warp"), {{4, 0}, {8, 0}, {16, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { EXPECT_EQ(toLinearLayout({16, 16}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})),