From 242cd5a9cd0647649181547d40e752c6235bb8d6 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 20 Dec 2024 20:47:51 +0000 Subject: [PATCH] Revert "Revert "[Backend] Improve dot support to target FMA (#4516)"" This reverts commit c5f5ac1fd98b982caf76fd218948fc5fce84f4fb. --- .../Conversion/TritonGPUToLLVM/Utility.h | 24 + include/triton/Dialect/TritonGPU/IR/Dialect.h | 6 + .../SharedToDotOperandFMA.cpp | 495 +++++++++++------- .../TritonGPUToLLVM/DotOpToLLVM/FMA.cpp | 118 ++--- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 40 ++ .../TritonToTritonGPUPass.cpp | 5 + lib/Dialect/TritonGPU/IR/Dialect.cpp | 73 ++- .../test/unit/language/test_compile_errors.py | 14 +- python/test/unit/language/test_core.py | 25 +- .../amd/accelerate-amd-matmul-mfma.mlir | 17 + third_party/amd/backend/compiler.py | 29 +- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 4 +- .../SharedToDotOperandMMAv2OrV3.cpp | 17 - 13 files changed, 573 insertions(+), 294 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 999355b77b..0f7b983929 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -348,12 +348,18 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, SmallVector delinearize(RewriterBase &rewriter, Location loc, Value linear, ArrayRef shape); +SmallVector delinearize(unsigned linear, ArrayRef shape, + ArrayRef order); + Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, ArrayRef shape, ArrayRef order); Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, ArrayRef shape); +size_t linearize(ArrayRef multiDim, ArrayRef shape, + ArrayRef order); + Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, StringRef content); @@ -496,6 +502,24 @@ inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, return ret; } +/// Extend 2d shared object to 3d. +/// +/// If tensor has 3 dimensions, returns original shared object. +/// If tensor shape is [M, N], return shared object describing shape [1, M, N] +/// +/// This Function is used to simplify processing of 2d and 3d dot operands, +/// particularly in the conversion of local_load operation. +/// +/// \param rewriter +/// \param loc +/// \param smemObj +/// \param shape shape of a tensor represented by smemObj +/// \returns shared object describing 3d tensor +SharedMemoryObject +getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, + SharedMemoryObject smemObj, + ArrayRef shape); + // ----------------------------------------------------------------------- // Blocked layout indices // ----------------------------------------------------------------------- diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index e592a9d6d1..b81ecf103a 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -234,6 +234,12 @@ void dumpHWLayout(RankedTensorType tensorType); // Return a string representation of the layout of the tensor. std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView); +template +llvm::SmallVector expandMatrixShapeWithBatch(llvm::ArrayRef s); + +llvm::SmallVector +expandMatrixOrderWithBatch(llvm::ArrayRef o); + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index 5bc5a2a204..0ea294d53b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -1,11 +1,14 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" using ValueTable = std::map, Value>; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::expandMatrixOrderWithBatch; +using ::mlir::triton::gpu::expandMatrixShapeWithBatch; using ::mlir::triton::gpu::getContigPerThread; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; @@ -14,47 +17,6 @@ using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::SharedEncodingAttr; -SmallVector -getThreadIds(Value threadId, ArrayRef shapePerCTATile, - ArrayRef sizePerThread, ArrayRef order, - ConversionPatternRewriter &rewriter, Location loc) { - int dim = order.size(); - SmallVector threadIds(dim); - for (unsigned k = 0; k < dim - 1; k++) { - Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]); - Value rem = urem(threadId, dimK); - threadId = udiv(threadId, dimK); - threadIds[order[k]] = rem; - } - Value dimK = i32_val(shapePerCTATile[order[dim - 1]]); - threadIds[order[dim - 1]] = urem(threadId, dimK); - return threadIds; -} - -// Get shapePerCTATile for M or N axis. -int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { - auto order = layout.getOrder(); - auto shapePerCTATile = getShapePerCTATile(layout); - - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - return isM ? mShapePerCTATile : nShapePerCTATile; -} - -// Get sizePerThread for M or N axis. -int getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { - auto order = layout.getOrder(); - auto sizePerThread = getSizePerThread(layout); - - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nSizePerThread = - order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - return isM ? mSizePerThread : nSizePerThread; -} - Value getStructFromValueTable(ArrayRef vals, ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter *typeConverter, @@ -70,154 +32,329 @@ Value getStructFromValueTable(ArrayRef vals, return packLLElements(loc, typeConverter, elems, rewriter, structTy); } -ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA, - int sizePerThread, - ConversionPatternRewriter &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter, - Type type) { - ValueTable res; - auto elems = unpackLLElements(loc, val, rewriter); - int index = 0; - for (unsigned k = 0; k < K; ++k) { - for (unsigned m = 0; m < n0; m += shapePerCTA) - for (unsigned mm = 0; mm < sizePerThread; ++mm) { - res[{m + mm, k}] = elems[index++]; - } - } - return res; +bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; } + +SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, + Location loc, SmallVector rawIndices, + SharedEncodingAttr layout) { + const auto &order = layout.getOrder(); + auto rank = order.size(); + + if (!isSwizzled(layout)) + return rawIndices; + + auto vec = i32_val(layout.getVec()); + auto perPhase = i32_val(layout.getPerPhase()); + auto maxPhase = i32_val(layout.getMaxPhase()); + + auto fastIdx = rawIndices[order[0]]; + auto secondIdx = rawIndices[order[1]]; + // Original algorithm taken from getSwizzledSharedPtrs function + // (TritonGPUToLLVMBase.h) + // + // phase = (secondIdx // perPhase) % maxPhase + // swizzledGroup = ((fastIdx // vec) ^ phase) * vec + // groupRemainder = fastIdx % vec + // colOff = swizzledGroup + groupRemainder + auto phase = urem(udiv(secondIdx, perPhase), maxPhase); + auto swizzledGroup = mul(xor_(udiv(fastIdx, vec), phase), vec); + auto groupRemainder = urem(fastIdx, vec); + auto colOff = add(swizzledGroup, groupRemainder); + + SmallVector swizzledIndices = rawIndices; + swizzledIndices[order[0]] = colOff; + + return swizzledIndices; } -Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, - Location loc, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - auto aTensorTy = cast(A.getType()); - auto aLayout = cast(aTensorTy.getEncoding()); - auto aShapePerCTA = getShapePerCTA(aTensorTy); - - auto aOrder = aLayout.getOrder(); - auto order = dLayout.getOrder(); - - bool isARow = aOrder[0] == 1; - - auto aSmem = getSharedMemoryObjectFromStruct( - loc, llA, typeConverter->convertType(aTensorTy.getElementType()), - rewriter); - Value strideAM = aSmem.strides[0]; - Value strideAK = aSmem.strides[1]; - Value strideA0 = isARow ? strideAK : strideAM; - Value strideA1 = isARow ? strideAM : strideAK; - int aNumPtr = 8; - int K = aShapePerCTA[1]; - int M = aShapePerCTA[0]; - - auto shapePerCTATile = getShapePerCTATile(dLayout); - auto sizePerThread = getSizePerThread(dLayout); - - Value _0 = i32_val(0); - - Value mContig = i32_val(sizePerThread[order[1]]); - - // threadId in blocked layout - auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, - rewriter, loc); - Value threadIdM = threadIds[0]; - - Value offA0 = isARow ? _0 : mul(threadIdM, mContig); - Value offA1 = isARow ? mul(threadIdM, mContig) : _0; - SmallVector aOff(aNumPtr); - for (int i = 0; i < aNumPtr; ++i) { - aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1)); +struct DimIdx { + unsigned batch; + unsigned k; + unsigned nonK; +}; + +/// Put elements from Value vec to appropriate indexes in opValues array. +/// +/// This function maps elements of 3d sub-tensor in linear array. +/// Axes are arranged in an order provided "opOrder" argument +void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc, + SmallVector &opValues, Value vec, + ArrayRef perThreadTileShape, + unsigned kIdx, unsigned nonKIdx, unsigned bIdx, + const DimIdx &dim, int vecDim, + ArrayRef opOrder) { + auto vecTy = cast(vec.getType()); + auto vectorSize = vecTy.getNumElements(); + auto elemTy = vecTy.getElementType(); + for (int elem = 0; elem < vectorSize; ++elem) { + unsigned spatialIdx[3] = {}; + spatialIdx[dim.batch] = bIdx; + spatialIdx[dim.k] = kIdx; + spatialIdx[dim.nonK] = nonKIdx; + spatialIdx[vecDim] += elem; + + unsigned linearIdx = linearize(spatialIdx, perThreadTileShape, opOrder); + opValues[linearIdx] = extract_element(elemTy, vec, i32_val(elem)); } - auto elemTy = typeConverter->convertType(aTensorTy.getElementType()); - - Type ptrTy = aSmem.base.getType(); - SmallVector aPtrs(aNumPtr); - for (int i = 0; i < aNumPtr; ++i) - aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); - - SmallVector vas; - - int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/); - int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/); +} - for (unsigned k = 0; k < K; ++k) - for (unsigned m = 0; m < M; m += mShapePerCTATile) - for (unsigned mm = 0; mm < mSizePerThread; ++mm) { - Value offset = - add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); - Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); - Value va = load(elemTy, pa); - vas.emplace_back(va); - } +void verifyCTALayout(CTALayoutAttr ctaLayout) { + auto ctaSplit = ctaLayout.getCTASplitNum(); + for (auto split : ctaSplit) { + if (split != 1) + llvm::report_fatal_error("tensors splited in CGA(thread group clusters) " + "are not supported in FMA dot yet."); + } +} - return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy); +/// Get a linear offset of first element loaded by thread. +/// +/// In unswizzled case offset of any element computed with formula: +/// smem.base + first_element_offset + constant_offset. +/// +/// first_element_offset depends on lane Id and warp Id +/// constant_offset depends on value number, which is same for all threads. +/// \returns first_element_offset +Value getUnswizzledFirstElemOffset(ConversionPatternRewriter &rewriter, + Location loc, unsigned B, unsigned NonK, + Value bTileOffset, Value nonKTileOffset, + Value bStride, Value nonKStride) { + auto bOffset = mul(urem(bTileOffset, i32_val(B)), bStride); + auto nonKOffset = mul(urem(nonKTileOffset, i32_val(NonK)), nonKStride); + Value threadIdDependantOffset = add(bOffset, nonKOffset); + return threadIdDependantOffset; } -Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, - Location loc, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - auto bTensorTy = cast(B.getType()); - auto bLayout = cast(bTensorTy.getEncoding()); - auto bShapePerCTA = getShapePerCTA(bTensorTy); +/// \returns number of elements stored by one thread across each dimension +SmallVector getElemsPerThreadInOp(ArrayRef opTensorShape, + ArrayRef shapePerCTATile, + ArrayRef sizePerThread) { + int rank = opTensorShape.size(); + SmallVector elemsPerThread(rank); + for (int d = 0; d < rank; ++d) { + auto numReps = + ceil(static_cast(opTensorShape[d]), shapePerCTATile[d]); + elemsPerThread[d] = numReps * sizePerThread[d]; + } + return elemsPerThread; +} - auto bOrder = bLayout.getOrder(); - auto order = dLayout.getOrder(); +struct Indexes { + unsigned bTile; + unsigned b; + unsigned k; + unsigned nonKTile; + unsigned nonK; +}; + +/// Computes a linear memory offset of a given element relative to +/// beginning of shared memory object. +Value computeSwizzledOffset(ConversionPatternRewriter &rewriter, Location loc, + const Indexes &i, const DimIdx &dim, + Value bTileOffset, Value nonKTileOffset, + unsigned shapePerCTABTile, + unsigned shapePerCTANonKTile, + SharedEncodingAttr sharedLayout, + ArrayRef opTensorShape, + ArrayRef strides) { + Value offset = i32_val(0); + // Compute unswizzled multi dim coordinates in shared memmory object + SmallVector elemMultiDimIndices(3); + elemMultiDimIndices[dim.batch] = + add(bTileOffset, i32_val(i.bTile * shapePerCTABTile + i.b)); + elemMultiDimIndices[dim.nonK] = + add(nonKTileOffset, i32_val(i.nonKTile * shapePerCTANonKTile + i.nonK)); + elemMultiDimIndices[dim.k] = i32_val(i.k); + + // Apply swizzling pattern to fastest dimension + SmallVector swizzledIndices = + swizzleIndices(rewriter, loc, elemMultiDimIndices, sharedLayout); + + // Linearize shared mem object dimensions into flat offset + for (int d = 0; d < 3; ++d) { + // wrap index if it is larger than tensor + auto wrappedDimIndex = urem(swizzledIndices[d], i32_val(opTensorShape[d])); + auto dimOffset = mul(wrappedDimIndex, strides[d]); + offset = add(offset, dimOffset); + } + return offset; +} - bool isBRow = bOrder[0] == 1; +/// Computes memory offset of a given element relative to the +/// first element loaded by a thread. +Value computeNonSwizzledOffset(ConversionPatternRewriter &rewriter, + Location loc, const Indexes &i, + const DimIdx &dim, ArrayRef tensorShape, + unsigned shapePerCTABTile, + unsigned shapePerCTANonKTile, + ArrayRef strides) { + SmallVector offsetIndices(3); + offsetIndices[dim.batch] = + i32_val((i.bTile * shapePerCTABTile + i.b) % tensorShape[dim.batch]); + offsetIndices[dim.nonK] = i32_val( + (i.nonKTile * shapePerCTANonKTile + i.nonK) % tensorShape[dim.nonK]); + offsetIndices[dim.k] = i32_val(i.k); + + Value offset = i32_val(0); + for (int d = 0; d < 3; ++d) + offset = add(offset, mul(offsetIndices[d], strides[d])); + return offset; +} - auto bSmem = getSharedMemoryObjectFromStruct( - loc, llB, typeConverter->convertType(bTensorTy.getElementType()), +/// Generates llvm IR for loading FMA dot operand from shared memory. +/// +/// \param srcVal triton_gpu MemDescType value +/// \param llVal llvm IR values corresponding to srcVal +/// \param dLayout parent dot operand layout +/// \param thread thread id +/// \param loc +/// \param typeConverter +/// \param rewriter +/// \param dotOpNo +/// \returns llvm value with loaded elements +Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, + Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const int dotOpNo) { + verifyCTALayout(dLayout.getCTALayout()); + + DimIdx dim; + dim.batch = 0; + dim.k = dotOpNo == 0 ? 2 : 1; + dim.nonK = dotOpNo == 0 ? 1 : 2; + auto opTensorTy = cast(srcVal.getType()); + auto opTensorShape = expandMatrixShapeWithBatch(opTensorTy.getShape()); + auto sharedLayout = cast(opTensorTy.getEncoding()); + + auto opOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); + + auto origSmem = getSharedMemoryObjectFromStruct( + loc, llVal, typeConverter->convertType(opTensorTy.getElementType()), rewriter); - Value strideBN = bSmem.strides[1]; - Value strideBK = bSmem.strides[0]; - Value strideB0 = isBRow ? strideBN : strideBK; - Value strideB1 = isBRow ? strideBK : strideBN; - int bNumPtr = 8; - int K = bShapePerCTA[0]; - int N = bShapePerCTA[1]; - - auto shapePerCTATile = getShapePerCTATile(dLayout); - auto sizePerThread = getSizePerThread(dLayout); - - Value _0 = i32_val(0); - - Value nContig = i32_val(sizePerThread[order[0]]); - - // threadId in blocked layout - auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, - rewriter, loc); - Value threadIdN = threadIds[1]; - - Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; - Value offB1 = isBRow ? _0 : mul(threadIdN, nContig); - SmallVector bOff(bNumPtr); - for (int i = 0; i < bNumPtr; ++i) { - bOff[i] = add(mul(offB0, strideB0), mul(offB1, strideB1)); + auto smem = getExpandedSharedMemoryObject(rewriter, loc, origSmem, + opTensorTy.getShape()); + auto strides = smem.strides; + int B = opTensorShape[dim.batch]; + int K = opTensorShape[dim.k]; + int NonK = opTensorShape[dim.nonK]; + + auto shapePerCTATile = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); + shapePerCTATile[dim.k] = K; + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); + sizePerThread[dim.k] = K; + auto threadsPerWarp = + expandMatrixShapeWithBatch(ArrayRef(dLayout.getThreadsPerWarp())); + auto warpsPerCTA = + expandMatrixShapeWithBatch(ArrayRef(dLayout.getWarpsPerCTA())); + + auto warpSize = i32_val(triton::gpu::getWarpSize(dLayout)); + auto laneId = urem(thread, warpSize); + auto warpId = udiv(thread, warpSize); + auto laneIds = + mlir::LLVM::delinearize(rewriter, loc, laneId, threadsPerWarp, opOrder); + auto warpIds = + mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, opOrder); + auto sizePerWarpB = sizePerThread[dim.batch] * threadsPerWarp[dim.batch]; + auto sizePerWarpNonK = sizePerThread[dim.nonK] * threadsPerWarp[dim.nonK]; + + Value bTileOffset = + mul(laneIds[dim.batch], i32_val(sizePerThread[dim.batch])); + bTileOffset = + add(bTileOffset, mul(warpIds[dim.batch], i32_val(sizePerWarpB))); + Value nonKTileOffset = + mul(laneIds[dim.nonK], i32_val(sizePerThread[dim.nonK])); + nonKTileOffset = + add(nonKTileOffset, mul(warpIds[dim.nonK], i32_val(sizePerWarpNonK))); + + auto elemTy = typeConverter->convertType(opTensorTy.getElementType()); + Type ptrTy = smem.base.getType(); + + auto sharedOrder = expandMatrixOrderWithBatch(sharedLayout.getOrder()); + // compute contiguity of fastest dimension in shared layout. + unsigned vectorSize = sizePerThread[sharedOrder[0]]; + vectorSize = std::min(vectorSize, 128 / elemTy.getIntOrFloatBitWidth()); + + bool swizzlePath = isSwizzled(sharedLayout); + + if (swizzlePath) + vectorSize = std::min(vectorSize, sharedLayout.getVec()); + auto vecTy = vec_ty(elemTy, vectorSize); + // loop increments depend on fastest dim + unsigned dimStep[3] = {1, 1, 1}; + dimStep[sharedOrder[0]] = vectorSize; + + auto shapePerCTABTile = shapePerCTATile[dim.batch]; + auto shapePerCTANonKTile = shapePerCTATile[dim.nonK]; + auto sizeBPerThread = sizePerThread[dim.batch]; + auto sizeNonKPerThread = sizePerThread[dim.nonK]; + auto numBTiles = std::max(1u, B / shapePerCTABTile); + auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile); + + auto perThreadShape = + getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread); + + SmallVector opValues(numBTiles * sizeBPerThread * K * numNonKTiles * + sizeNonKPerThread); + + // In swizzled memory case basePtr stores pointer to the beginning of shared + // memmory object. + // + // If memory is not swizzled, algorithm breaks element offset pointer into + // constant and non-constant part. Non-constant (depends on thread id) part is + // the offset of the first element of the thread, which is same for all + // elements of the thread. It is computed only once. basePtr stores this + // non-constant part + Value basePtr; + if (swizzlePath) { + basePtr = smem.base; + } else { + auto laneOffset = getUnswizzledFirstElemOffset( + rewriter, loc, B, NonK, bTileOffset, nonKTileOffset, strides[dim.batch], + strides[dim.nonK]); + basePtr = gep(ptrTy, elemTy, smem.base, laneOffset); } - auto elemTy = typeConverter->convertType(bTensorTy.getElementType()); - - Type ptrTy = bSmem.base.getType(); - SmallVector bPtrs(bNumPtr); - for (int i = 0; i < bNumPtr; ++i) - bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); - - SmallVector vbs; - - int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/); - int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/); - - for (unsigned k = 0; k < K; ++k) - for (unsigned n = 0; n < N; n += nShapePerCTATile) - for (unsigned nn = 0; nn < nSizePerThread; ++nn) { - Value offset = - add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); - Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); - Value vb = load(elemTy, pb); - vbs.emplace_back(vb); - } - return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy); + // This loop nest iterates over all values loaded in one thread across batch, + // k and nonK dimensions. Blocked dot operand layout allocates data in tiles + // of size ** for batch and nonK + // dimensions. If tensor shape is larger than tile, pattern repeats. To take + // these repeats into account iterations for batch and nonK are split into + // "intra tile" + "inter tile" indexes: b + bTile, nonK + nonKTile + for (unsigned bTile = 0; bTile < numBTiles; ++bTile) + for (unsigned b = 0; b < sizeBPerThread; b += dimStep[dim.batch]) + for (unsigned k = 0; k < K; k += dimStep[dim.k]) + for (unsigned nonKTile = 0; nonKTile < numNonKTiles; ++nonKTile) + for (unsigned nonK = 0; nonK < sizeNonKPerThread; + nonK += dimStep[dim.nonK]) { + Value offset = i32_val(0); + Indexes idx = {bTile, b, k, nonKTile, nonK}; + + // swizzled variant is more general, but it limits optimization of + // address computation, + if (swizzlePath) { + offset = computeSwizzledOffset( + rewriter, loc, idx, dim, bTileOffset, nonKTileOffset, + shapePerCTABTile, shapePerCTANonKTile, sharedLayout, + opTensorShape, strides); + } else { + offset = computeNonSwizzledOffset(rewriter, loc, idx, dim, + opTensorShape, shapePerCTABTile, + shapePerCTANonKTile, strides); + } + + Value elemAddr = gep(ptrTy, elemTy, basePtr, offset); + Value vec = load(vecTy, elemAddr); + storeValuesInLinearVector( + rewriter, loc, opValues, vec, perThreadShape, /*kIdx*/ k, + /*nonKIdx*/ nonKTile * sizeNonKPerThread + nonK, + /*bIdx*/ bTile * sizeBPerThread + b, dim, sharedOrder[0], + opOrder); + } + + return getStructFromValueTable(opValues, rewriter, loc, typeConverter, + elemTy); } namespace SharedToDotOperandFMA { @@ -225,9 +362,7 @@ Value convertLayout(int opIdx, Value val, Value llVal, BlockedEncodingAttr dLayout, Value thread, Location loc, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter) { - if (opIdx == 0) - return loadAFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); - else - return loadBFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); + return loadFMAOp(val, llVal, dLayout, thread, loc, typeConverter, rewriter, + opIdx); } } // namespace SharedToDotOperandFMA diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index afb5bf01d4..e32b3e0d6e 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -1,28 +1,36 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; using namespace mlir::triton; +using namespace ::mlir::triton::gpu; -using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::expandMatrixOrderWithBatch; +using ::mlir::triton::gpu::expandMatrixShapeWithBatch; using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::getSizePerThread; -using ValueTableFMA = std::map, Value>; +using ValueTableFMA = std::map, Value>; static ValueTableFMA -getValueTableFromStructFMA(Value val, int K, int n0, int shapePerCTATile, - int sizePerThread, +getValueTableFromStructFMA(Value val, ArrayRef perTileShape, + unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter, Type type) { + ArrayRef order) { ValueTableFMA res; auto elems = unpackLLElements(loc, val, rewriter); - int index = 0; - for (unsigned k = 0; k < K; ++k) { - for (unsigned m = 0; m < n0; m += shapePerCTATile) - for (unsigned mm = 0; mm < sizePerThread; ++mm) { - res[{m + mm, k}] = elems[index++]; - } + assert(perTileShape.size() == 3); + assert(elems.size() == product(perTileShape)); + assert(kDim == 1 || kDim == 2); + assert(nonKDim == 1 || nonKDim == 2); + const unsigned bDim = 0; + + for (unsigned idx = 0; idx < elems.size(); ++idx) { + auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order); + res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx]; } return res; } @@ -34,68 +42,60 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto loc = op.getLoc(); auto A = op.getA(); - auto B = op.getB(); - auto C = op.getC(); auto D = op.getResult(); auto aTensorTy = cast(A.getType()); - auto bTensorTy = cast(B.getType()); auto dTensorTy = cast(D.getType()); - auto aShapePerCTA = getShapePerCTA(aTensorTy); - auto bShapePerCTA = getShapePerCTA(bTensorTy); + SmallVector aShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); + auto dShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); BlockedEncodingAttr dLayout = cast(dTensorTy.getEncoding()); - auto order = dLayout.getOrder(); + auto order = expandMatrixOrderWithBatch(dLayout.getOrder()); auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); Value llA = adaptor.getA(); Value llB = adaptor.getB(); - auto sizePerThread = getSizePerThread(dLayout); - auto shapePerCTATile = getShapePerCTATile(dLayout); - - int K = aShapePerCTA[1]; - int M = aShapePerCTA[0]; - int N = bShapePerCTA[1]; - - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nSizePerThread = - order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - - auto has = - getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread, - rewriter, loc, typeConverter, aTensorTy); - auto hbs = - getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread, - rewriter, loc, typeConverter, bTensorTy); - - SmallVector ret = cc; - bool isCRow = order[0] == 1; - - for (unsigned k = 0; k < K; k++) { - for (unsigned m = 0; m < M; m += mShapePerCTATile) - for (unsigned n = 0; n < N; n += nShapePerCTATile) - for (unsigned mm = 0; mm < mSizePerThread; ++mm) - for (unsigned nn = 0; nn < nSizePerThread; ++nn) { - int mIdx = m / mShapePerCTATile * mSizePerThread + mm; - int nIdx = n / nShapePerCTATile * nSizePerThread + nn; - - int z = isCRow - ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx - : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx; - ret[z] = rewriter.create(loc, has[{m + mm, k}], - hbs[{n + nn, k}], ret[z]); - } + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); + auto shapePerCTATile = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); + + unsigned K = aShapePerCTA[2]; + + unsigned perThreadShape[3]; + for (int i = 0; i < 3; ++i) { + unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i]; + numRep = std::max(static_cast(1), numRep); + perThreadShape[i] = numRep * sizePerThread[i]; } - auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy); + auto has = getValueTableFromStructFMA( + llA, {perThreadShape[0], perThreadShape[1], K}, + /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order); + auto hbs = getValueTableFromStructFMA( + llB, {perThreadShape[0], K, perThreadShape[2]}, + /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order); + + SmallVector acc = cc; + + for (unsigned b = 0; b < perThreadShape[0]; ++b) + for (unsigned m = 0; m < perThreadShape[1]; ++m) + for (unsigned n = 0; n < perThreadShape[2]; ++n) { + SmallVector multiDimAccumIdx = {b, m, n}; + unsigned linearAccumIdx = + linearize(multiDimAccumIdx, perThreadShape, order); + for (unsigned k = 0; k < K; ++k) { + acc[linearAccumIdx] = rewriter.create( + loc, has[{b, m, k}], hbs[{b, n, k}], acc[linearAccumIdx]); + } + } + + auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); rewriter.replaceOp(op, res); return success(); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 554810fa4e..05dcbe1c2d 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -637,6 +637,19 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, return multiDim; } +SmallVector delinearize(unsigned linear, ArrayRef shape, + ArrayRef order) { + auto rank = shape.size(); + assert(order.size() == rank); + SmallVector multiDim(rank); + for (auto dim : order) { + multiDim[dim] = linear % shape[dim]; + linear /= shape[dim]; + } + assert(linear == 0); + return multiDim; +} + Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, ArrayRef shape, ArrayRef order) { return linearize(rewriter, loc, applyPermutation(multiDim, order), @@ -658,6 +671,14 @@ Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, return linear; } +size_t linearize(ArrayRef multiDim, ArrayRef shape, + ArrayRef order) { + size_t linear = 0; + for (unsigned dim : llvm::reverse(order)) + linear = linear * shape[dim] + multiDim[dim]; + return linear; +} + Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, StringRef content) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); @@ -895,4 +916,23 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, }; } // namespace LLVM + +SharedMemoryObject +getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, + SharedMemoryObject smemObj, + ArrayRef shape) { + assert(shape.size() == 2 || shape.size() == 3); + auto strides = smemObj.getStrides(); + auto offsets = smemObj.getOffsets(); + auto rank = strides.size(); + assert(rank == shape.size()); + if (rank == 3) + return smemObj; + strides.insert(strides.begin(), i32_val(shape[0] * shape[1])); + offsets.insert(offsets.begin(), i32_val(0)); + auto expandedSmemObj = SharedMemoryObject( + smemObj.getBase(), smemObj.getBaseElemType(), strides, offsets); + return expandedSmemObj; +} + } // namespace mlir diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 464b150dc1..67ab63beb7 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -227,6 +227,11 @@ struct TritonDotPattern : public OpConversionPattern { retSizePerThread[rank - 1] = 4; retSizePerThread[rank - 2] = 4; } + retSizePerThread[rank - 1] = std::min( + retSizePerThread[rank - 1], static_cast(origShape[rank - 1])); + retSizePerThread[rank - 2] = std::min( + retSizePerThread[rank - 2], static_cast(origShape[rank - 2])); + SmallVector retOrder(rank); for (unsigned i = 0; i < rank; ++i) retOrder[i] = rank - 1 - i; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index adac291307..3a984f8bd8 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1085,29 +1085,26 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, } } if (auto blockedLayout = mlir::dyn_cast(getParent())) { - auto shapePerCTA = getShapePerCTA(*this, shape); - auto shapePerCTATile = getShapePerCTATile(blockedLayout); - auto order = blockedLayout.getOrder(); - auto sizePerThread = blockedLayout.getSizePerThread(); - - int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; - int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; - - bool isM = getOpIdx() == 0; - - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nSizePerThread = - order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int sizePerThreadMN = isM ? mSizePerThread : nSizePerThread; - - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int shapePerCTAMNTile = isM ? mShapePerCTATile : nShapePerCTATile; - - return K * std::max(otherDim / shapePerCTAMNTile, 1) * sizePerThreadMN; + auto shapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(*this, shape))); + auto shapePerCTATile = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(blockedLayout))); + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(blockedLayout.getSizePerThread())); + + int batchDim = 0; + int kDim = getOpIdx() == 0 ? 2 : 1; + int nonKDim = getOpIdx() == 0 ? 1 : 2; + + int batchSize = + std::max(shapePerCTA[batchDim] / shapePerCTATile[batchDim], 1) * + sizePerThread[batchDim]; + int kSize = shapePerCTA[kDim]; + int nonKSize = + std::max(shapePerCTA[nonKDim] / shapePerCTATile[nonKDim], 1) * + sizePerThread[nonKDim]; + + return batchSize * kSize * nonKSize; } llvm_unreachable("unknown dot operand parent layout"); return 0; @@ -3382,6 +3379,36 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType, return layoutStr; } +template +llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch(llvm::ArrayRef s) { + auto rank = s.size(); + assert(rank == 2 || rank == 3); + if (rank == 3) + return llvm::SmallVector(s); + return {1, s[0], s[1]}; +} + +template llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch( + llvm::ArrayRef s); + +template llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch( + llvm::ArrayRef s); + +llvm::SmallVector +mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef o) { + int rank = o.size(); + assert(rank == 2 || rank == 3); + if (rank == 3) + return llvm::SmallVector(o); + llvm::SmallVector expanded(3, 0); + for (int i = 0; i < rank; ++i) + expanded[i] += o[i] + 1; + return expanded; +} + std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView) { auto layout = tensorType.getEncoding(); diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index f0ab578cbd..30537a462e 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -389,14 +389,22 @@ def test_min_dot_size(dtype): else: error_msg = "M >= 16, N >= 16 and K >= 16" elif is_hip_mi300(): - if dtype.is_int8(): + if dtype == tl.float16: + pytest.skip("fp16 FMA path supports all sizes") + elif dtype == tl.int8: error_msg += "M >= 16, N >= 16 and K >= 16" else: error_msg += "M >= 16, N >= 16 and K >= 8" elif is_hip_mi200(): - error_msg += "M >= 16, N >= 16 and K >= 8" + if dtype == tl.float16: + pytest.skip("fp16 FMA path supports all sizes") + else: + error_msg += "M >= 16, N >= 16 and K >= 8" elif is_hip(): - error_msg = "M >= 16, N >= 16 and K >= 16" + if dtype == tl.float16: + pytest.skip("fp16 FMA path supports all sizes") + else: + error_msg = "M >= 16, N >= 16 and K >= 16" else: pytest.skip("Test only supported on CUDA and HIP") diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 345178ccf6..4c852fb717 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3297,13 +3297,21 @@ def convert_fp8_to_fp32(x, device, dtype_str): ([(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1), (32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1)] if "gfx9" in get_arch() else []) + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) - for float8_type in ["float8e5", "float8e4nv"]]) + for float8_type in ["float8e5", "float8e4nv"]] + + [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1) + for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, num_ctas, device): if is_interpreter(): + if M < 16 or N < 16 or K < 16: + pytest.skip("small dots are supported only on HIP at the moment") if in_dtype == 'bfloat16': pytest.xfail("bfloat16 is not supported in the interpreter") else: + if not is_hip() and (M < 16 or N < 16 or K < 16): + pytest.skip("small dots are supported only on HIP at the moment") if is_cuda(): capability = torch.cuda.get_device_capability() @@ -3749,7 +3757,14 @@ def make_finite(x, dtype): for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]] + # Large block sizes - [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')]) + [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] + + # Small block sizes + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 8] + for num_warps in [1, 2, 4] + for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)] + for M, N, K in [(32, 32, 32)] + for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]]) def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device): if is_hip(): # hip does not support tf32 precision, so use ieee for all tests @@ -3762,6 +3777,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") else: input_precision = "tf32" if (is_cuda() or is_xpu()) and in_dtype_str == 'float32' else "ieee" + if BLOCK_M < 16 or BLOCK_N < 16: + pytest.skip("small dots are supported only on HIP at the moment") if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": if not is_interpreter() and torch.cuda.is_available( @@ -3822,6 +3839,10 @@ def kernel( if in_dtype_str == 'int8': out = numpy_random((B, M, N), dtype_str='int32', rs=rs) else: + if is_hip() and (BLOCK_M < 16 or BLOCK_N < 16) and out_dtype_str == 'float16': + # float16 accumulator in FMA dot loose precision too fast + x *= 0.1 + y *= 0.1 out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) x_tri = to_triton(x, device=device) diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir index 260dddb954..76fbe584cb 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir @@ -18,3 +18,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ tt.return } } + +// ----- + +// Verify that we use FMA when the N dimension is too small for any mma. +// CHECK-NOT: #ttg.amd_mfma +// CHECK-LABEL: small_n_size +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 64], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @small_n_size( + %a: tensor<4x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %b: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) + -> tensor<4x128xf32, #blocked> { + %zero_f32 = arith.constant dense<0.000000e+00> : tensor<4x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<4x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<4x128xf32, #blocked> + tt.return %result : tensor<4x128xf32, #blocked> + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 77a1233dbb..0d5eae5c80 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -13,17 +13,30 @@ def min_dot_size(target: GPUTarget): + + def is_fma_supported(lhsType, rhsType): + return lhsType == rhsType and (lhsType.is_fp16() or lhsType.is_fp32()) + + def get_gfx94_limits(lhsType, rhsType): + if is_fma_supported(lhsType.scalar, rhsType.scalar): + return (1, 1, 1) + # CDNA 3.0 supports k==8 in all mfma variants except for int8 + # (where the smallest `k` supported is 16) + return (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else (16, 16, 8) + + def get_gfx9_limits(lhsType, rhsType): + if is_fma_supported(lhsType.scalar, rhsType.scalar): + return (1, 1, 1) + # CDNA 2.0 always supports `k==8` + return (16, 16, 8) + arch_str = target.arch - # CDNA 3.0 supports k==8 in all mfma variants except for int8 - # (where the smallest `k` supported is 16) if "gfx94" in arch_str: - return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else ( - 16, 16, 8) - # CDNA 2.0 always supports `k==8` + return get_gfx94_limits if "gfx9" in arch_str: - return lambda lhsType, rhsType: (16, 16, 8) - # Other architectures will only support 16,16,16 - return lambda lhsType, rhsType: (16, 16, 16) + return get_gfx9_limits + # gfx11 and gfx12 architectures will only support 16,16,16 with wmma instructions + return lambda lhsType, rhsType: (1, 1, 1) if is_fma_supported(lhsType.scalar, rhsType.scalar) else (16, 16, 16) @dataclass(frozen=True) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index d4a6eb09fd..89de36cad5 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -350,8 +350,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); SmallVector sharedOrder; int rank = order.size(); - // TODO rework this when shared -> dotOp conversions support arbitrary - // shared memory ordering + // TODO rework this when shared -> dotOperand conversions support + // arbitrary shared memory ordering if (rank == 3) { // Move the batch dimension (dim #0) to be the last so that it will be // the slowest varying dimension. diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 4b82755877..b042e38f6f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -797,23 +797,6 @@ MemDescType getExpandedDesc(MemDescType descTy) { return expandedDesc; } -SharedMemoryObject -getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, - SharedMemoryObject smemObj, - ArrayRef shape) { - auto strides = smemObj.getStrides(); - auto offsets = smemObj.getOffsets(); - auto rank = strides.size(); - if (rank == 3) - return smemObj; - auto expandedStrides = insertValue(strides, 0, i32_val(shape[0] * shape[1])); - auto expandedOffsets = insertValue(offsets, 0, i32_val(0)); - auto expandedSmemObj = - SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(), - expandedStrides, expandedOffsets); - return expandedSmemObj; -} - namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding,