Commit d3c45cb

Merge branch 'main' into keren/triton-gpu-to-ttg
Jokeren committed Nov 27, 2024
2 parents b35fe81 + 9e508a4 commit d3c45cb
Showing 34 changed files with 333 additions and 198 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -26,6 +26,9 @@ python/triton/language/extra
# Proton
python/triton/profiler

# Pytest
pytest.ini

# Instrumentation
python/triton/instrumentation

9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -77,6 +77,9 @@ if(NOT WIN32)
endif()

if(TRITON_BUILD_UT)
# This is an aggregate target for all unit tests.
add_custom_target(TritonUnitTests)
set_target_properties(TritonUnitTests PROPERTIES FOLDER "Triton/Tests")
include(AddTritonUnitTest)
endif()

@@ -340,4 +343,10 @@ add_subdirectory(test)

if(TRITON_BUILD_UT)
add_subdirectory(unittest)
# This target runs all the unit tests.
add_custom_target(check-triton-unit-tests
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure
DEPENDS TritonUnitTests
USES_TERMINAL
)
endif()
3 changes: 3 additions & 0 deletions cmake/AddTritonUnitTest.cmake
@@ -36,4 +36,7 @@ function(add_triton_ut)
# laptop. I think the issue may be that the very first time you run a program
# it's a bit slow.
gtest_discover_tests(${__NAME} DISCOVERY_TIMEOUT 60)

# Add the unit test to the top-level unit test target.
add_dependencies(TritonUnitTests ${__NAME})
endfunction()
7 changes: 7 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -4,6 +4,7 @@
#include "triton/Conversion/MLIRTypes.h"

namespace mlir::triton {

class TargetInfoBase {
public:
virtual bool supportMaximumMinimum() const = 0;
@@ -37,6 +38,12 @@ class TargetInfoBase {
pred);
}

virtual bool canUseStMatrix(RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize) const = 0;

virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
Value ptr, Value val) const = 0;

10 changes: 5 additions & 5 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -241,11 +241,11 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
// TODO(Keren): We should replace tensorTy with a LinearLayout and the element
// bit width of the tensor in the future to support more flexible tensor
// encodings
std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize);
LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -360,7 +360,7 @@ compared to 1*64 when the hasLeadingOffset is false.
int k = (needTrans) ? matShape[0] : matShape[2];
int vec = (order[0] == rank-1) ? k : m;
int mmaStride = (order[0] == rank-1) ? m : k;
int maxPhase = mmaStride / perPhase;
int maxPhase = std::max(mmaStride / perPhase, 1);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

@@ -373,7 +373,7 @@ compared to 1*64 when the hasLeadingOffset is false.
int k = needTrans ? matShape[1] : matShape[2];
int vec = (order[0] == rank-1) ? n : k;
int mmaStride = (order[0] == rank-1) ? k : n;
int maxPhase = mmaStride / perPhase;
int maxPhase = std::max(mmaStride / perPhase, 1);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

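The two maxPhase changes above clamp the computed phase count to at least 1: when mmaStride is smaller than perPhase, plain integer division would yield 0. A minimal standalone illustration of the arithmetic, using hypothetical values rather than anything taken from the diff:

#include <algorithm>
#include <cassert>

int main() {
  // Hypothetical sizes where the stride is smaller than the per-phase factor.
  int mmaStride = 4;
  int perPhase = 8;
  int unclamped = mmaStride / perPhase;             // integer division yields 0
  int maxPhase = std::max(mmaStride / perPhase, 1); // the diff clamps this to 1
  assert(unclamped == 0 && maxPhase == 1);
  return 0;
}
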
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
@@ -628,7 +628,8 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
dotOperandLayout.getOpIdx() == 0 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
(elementTypeSize == 16 || elementTypeSize == 8) &&
dotOperandLayout.getKWidth() == 32 / elementTypeSize;
return ans;
}

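The extra condition added to matchMmaV3AndDotOperandLayout above ties kWidth to the element width: with elementTypeSize already restricted to 16 or 8, the layouts only match when kWidth equals the number of elements that fit in a 32-bit register. A small standalone check of that arithmetic; the helper name below is made up for illustration:

// Elements of the given bit width that fit in one 32-bit register.
constexpr int expectedKWidth(int elementTypeSize) { return 32 / elementTypeSize; }

int main() {
  static_assert(expectedKWidth(16) == 2, "16-bit operands need kWidth == 2");
  static_assert(expectedKWidth(8) == 4, "8-bit operands need kWidth == 4");
  return 0;
}
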
56 changes: 25 additions & 31 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -380,24 +380,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return !useLegacyMMAConversion;
}
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto parent = dotOperand.getParent();
if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
return false;
}
if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (nvidiaMma.isAmpere()) {
return true;
}
}
if (isa<AMDMfmaEncodingAttr>(parent)) {
return true;
if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(
dotOperand.getParent())) {
return !useLegacyMMAConversion;
}
return false;
}
if (isa<BlockedEncodingAttr>(layout)) {
return true;
}
if (isa<LinearEncodingAttr>(layout)) {
if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
return true;
}
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -408,6 +397,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
return failure();
}
// FIXME [Dot LL] Remove this once we implement this trick in LLs
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
return failure();
}

assert(cvtNeedsSharedMemory(srcTy, dstTy));

@@ -498,34 +491,35 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// don't need to avoid duplicate writes.
// Input dims: [reg, lane, warp]
// Output dims: [offset, iteration]
std::optional<LinearLayout> shmemStoreLayout =
chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
bool isStMatrix = shmemStoreLayout.has_value();
if (!isStMatrix) {
shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
}
assert(shmemStoreLayout.has_value());
bool isStMatrix = targetInfo.canUseStMatrix(
op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
LinearLayout shmemStoreLayout =
isStMatrix ? chooseStMatrixLayout(
ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0)
: srcLayout.invertAndCompose(sharedLayout);

const int shmemAllocatedNumElems =
getNumScratchElements(scratchConfig.paddedRepShape);
assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems);
assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);

// Layout for the load from shmem to registers.
LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);

// Check that the `register` fully determines the `iteration`. That is,
// each thread does exactly the same reads and writes to shmem on each
// iteration, just with different input/output registers.
assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock},
{kIteration}));
assert(
shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
assert(
shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));

// iteration -> registers
SmallVector<SmallVector<int>> inRegsForIter =
collectRegsForIter(ctx, *shmemStoreLayout);
collectRegsForIter(ctx, shmemStoreLayout);
SmallVector<SmallVector<int>> outRegsForIter =
collectRegsForIter(ctx, shmemLoadLayout);

Expand Down Expand Up @@ -582,7 +576,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return vecAddr;
};

auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout,
auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
{{kRegister, i32_val(0)},
{kLane, laneId},
{kWarp, warpId},
@@ -605,11 +599,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

// When using `stmatrix`, we can store `inVec` elements even if they are
// not contiguous
auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut()
auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
: scratchConfig.inVec;
for (int j = 0; j < inVals.size() / iterations; j += inVec) {
auto inRegSlice = inRegs[j];
Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice);
Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
SmallVector<Value> inValsVec;
for (int k = 0; k < inVec; k++)
inValsVec.push_back(inVals[inRegSlice + k]);
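The rewrite above splits the old optional-returning chooser in two: the target answers the yes/no question through TargetInfoBase::canUseStMatrix, and chooseStMatrixLayout is only called once the answer is yes, so it can return a LinearLayout unconditionally. A standalone sketch of that refactoring pattern, with made-up names standing in for the Triton types:

#include <cassert>

struct Layout { int consecutive; };

// Capability query, analogous to TargetInfoBase::canUseStMatrix.
bool canUseFastStore(int elemBits) { return elemBits == 16; }

// Choosers, analogous to chooseStMatrixLayout vs. invertAndCompose; neither
// needs to signal failure through std::optional anymore.
Layout chooseFastStoreLayout() { return Layout{8}; }
Layout chooseFallbackLayout() { return Layout{1}; }

int main() {
  int elemBits = 16;
  bool fast = canUseFastStore(elemBits);
  Layout store = fast ? chooseFastStoreLayout() : chooseFallbackLayout();
  assert(store.consecutive == 8);
  return 0;
}

Making the query a virtual on TargetInfoBase also lets each backend veto stmatrix on its own, without the layout chooser having to know about target details.
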
44 changes: 29 additions & 15 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -138,18 +138,34 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
static bool isSupportedDotOpLayout(RankedTensorType type) {
auto layout = type.getEncoding();
auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
static bool isSupportedDotOpLayout(MemDescType srcTy,
RankedTensorType dstTy) {
auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
auto bitwidth = dstTy.getElementTypeBitWidth();
auto rank = dstTy.getRank();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
auto vecWidth = 32 / bitwidth;
auto kWidth = dot.getKWidth();
// Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
// - kWidth == 8
// - kWidth == 4, bitwidth = 32
auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
auto needTrans = kOrder != srcLayout.getOrder()[0];
auto canUseLdmatrix =
(bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
if (mma.isHopper()) {
// I think we should be able to remove this condition, but it's here
// as the legacy ldmatrix path does not support it
canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32;
}
// If we remove this one, ldmatrix will IMA. It can probably be relaxed
// though
canUseLdmatrix &=
srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
// To be removed in https://github.com/triton-lang/triton/pull/5154
bool legacyLoweringIsBuggy =
kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
return legacyLoweringIsBuggy && mma.isAmpere();
(kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
return (mma.isHopper() && !canUseLdmatrix) ||
(mma.isAmpere() && legacyLoweringIsBuggy);
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
return true;
@@ -162,12 +178,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
ConversionPatternRewriter &rewriter) const override {
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(dstTy))) {
if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(srcTy, dstTy)) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
@@ -206,7 +220,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
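The new isSupportedDotOpLayout above folds several ldmatrix eligibility checks into one place. The sketch below paraphrases those conditions as a standalone predicate; the function name and driver values are illustrative, not Triton APIs, and it assumes the source and destination element widths match:

#include <cassert>
#include <cstdint>

bool canUseLdmatrixSketch(unsigned bitwidth, unsigned kWidth, bool needTrans,
                          bool isHopper, int64_t rows, int64_t cols) {
  unsigned vecWidth = 32 / bitwidth;       // elements packed per 32-bit register
  bool ok = (bitwidth == 16 || !needTrans) && (kWidth == vecWidth);
  if (isHopper)
    ok &= bitwidth * kWidth == 32;         // legacy ldmatrix path wants 32-bit packs
  ok &= rows >= 8 && cols >= 4 * kWidth;   // smaller tiles would IMA
  return ok;
}

int main() {
  // 16-bit operand, kWidth 2, no transpose, 64x64 tile: eligible.
  assert(canUseLdmatrixSketch(16, 2, false, false, 64, 64));
  // 8-bit operand that needs a transpose is not eligible.
  assert(!canUseLdmatrixSketch(8, 4, true, false, 64, 64));
  return 0;
}
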
5 changes: 5 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -199,6 +199,11 @@ OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
return getResult();
}

// Eliminate splat constant transpose ops.
if (auto attr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSrc()))
return attr.reshape(getType());

return {};
}

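The new fold above rewrites a transpose of a splat constant into the same splat reshaped to the result type: a transpose only permutes positions, and a splat holds the same value at every position. A trivial standalone illustration of why that is sound:

#include <cassert>
#include <vector>

int main() {
  // A 2x3 "splat": every element holds the same value.
  int rows = 2, cols = 3, value = 7;
  std::vector<int> splat(rows * cols, value);

  // Explicit 3x2 transpose in row-major storage.
  std::vector<int> transposed(cols * rows);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      transposed[c * rows + r] = splat[r * cols + c];

  // The transposed data is still the same splat; only the shape changed.
  assert(transposed == std::vector<int>(cols * rows, value));
  return 0;
}
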
58 changes: 8 additions & 50 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -241,12 +241,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
llvm::report_fatal_error("Illegal shared layout");
}

int vec = 8 * 16 / elemBitWidth;
if (vec != shared.getVec()) {
llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec
<< ": " << shared << "\n";
llvm::report_fatal_error("Illegal shared layout");
}
int vec = shared.getVec();

StringAttr colDimName = outDimNames[colDim];
StringAttr rowDimName = outDimNames[rowDim];
@@ -858,40 +853,7 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
}

namespace {

// TODO (Keren): Currently, we have more restrictions than necessary when using
// stmatrix. These restrictions are retained from legacy code, and we could
// relax some of them in the future.
bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
auto mmaLayout =
mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
if (!mmaLayout || !mmaLayout.isHopper())
return false;
if (isa<PointerType>(tensorTy.getElementType()))
return false;
if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16)
return false;
if (order[0] != 1)
return false;

auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape());
if (tensorShapePerCTA.size() != 2)
return false;
auto numIterations = ceil<unsigned>(tensorShapePerCTA[1], repShape[1]) *
ceil<unsigned>(tensorShapePerCTA[0], repShape[0]);
if (numIterations > 1)
return false;
if (paddedRepShape[1] % 8 != 0)
return false;
if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
swizzleByteSize != 128)
return false;
return true;
}

std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
LinearLayout chooseStMatrixLayoutLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
@@ -962,7 +924,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
.reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
LinearLayout chooseStMatrixLayoutNoLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
StringAttr kReg = S("register");
Expand Down Expand Up @@ -1002,15 +964,11 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(

} // anonymous namespace

std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize) {
if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
swizzleByteSize))
return std::nullopt;

LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize) {
if (swizzleByteSize == 0)
return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
paddedRepShape, order);
(Diffs for the remaining changed files are not shown.)