Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update
  • Loading branch information
Jokeren committed Aug 9, 2024
1 parent 232574c commit dd60c90
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 196 deletions.
6 changes: 1 addition & 5 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// TODO(jlebar): Remove this function; it's subsumed by the linear-layout case
// in cvtNeedsSharedMemory.
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void decomposeSplatOpToSharedLayoutConversion(ModuleOp module);
/// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the
/// given |module| op, but bypass the decomposition if |shortcutFn| returns
/// true.
using ShortcutFn = std::function<bool(RankedTensorType &, RankedTensorType &)>;
using ShortcutFn = std::function<bool(RankedTensorType, RankedTensorType)>;
template <typename TensorCoreEncodingAttr>
void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
ShortcutFn shortcutFn);
Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
auto dstMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout);
auto dstDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(dstLayout);

assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) &&
assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() &&
!srcMmaLayout.isHopper()) &&
"mma -> mma layout conversion is only supported on Ampere");

// mma or dot layout does not have an order, so the order depends on the
Expand Down
30 changes: 8 additions & 22 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
Expand All @@ -543,21 +543,6 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}

static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
auto src = dyn_cast<NvidiaMmaEncodingAttr>(srcEncoding);
auto dst = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding);
if (!src || !dst)
return false;
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
return src && dst && src.getVersionMajor() == 3 &&
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
dst.getWarpsPerCTA()[1] == 1;
}

bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}

// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
Expand All @@ -567,10 +552,12 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
return false;
}
int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
auto ans =
mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
isMmaToMmaShortcut(dotOperandLayout.getParent(), srcTy.getEncoding()) &&
(elementTypeSize == 16 || elementTypeSize == 8);
auto parentTy = RankedTensorType::get(
srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent());
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
return ans;
}

Expand Down Expand Up @@ -605,8 +592,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {

// TODO(jlebar): Remove these special cases once they're fully subsumed by the
// linear-layout check above.
return !isMmaToMmaShortcut(srcTy, dstTy) &&
!isMmaToDotShortcut(srcTy, dstTy) &&
return !isMmaToDotShortcut(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
}

Expand Down
63 changes: 48 additions & 15 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// data in different CTAs and we know we're not in case 4.
LinearLayout conversion = srcLayout->invertAndCompose(*dstLayout);

LinearLayout invertConversion = dstLayout->invertAndCompose(*srcLayout);

int numLanes = conversion.getInDimSize(str_attr("lane"));
int numWarps = conversion.getInDimSize(str_attr("warp"));
int numBlocks = conversion.getInDimSize(str_attr("block"));

StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");
Expand All @@ -291,12 +290,37 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// stronger than this, checking also that the choice of lane/warp/block does
// not affect the permutation of registers. If we allow different
// lane/warp/blocks to have different permutations, we can generalize this.
if (std::optional<LinearLayout> c = invertConversion.divideRight(

// There are three possible cases:
// 1. The `src_layout` has the same number of registers as the `dst_layout`.
// 2. The `src_layout` has fewer registers than the `dst_layout`.
// 3. The `src_layout` has more registers than the `dst_layout`.
// In the second case, we may generate a conversion that is not surjective
// because not all lanes are covered. Instead, we could use the inverse of
// the conversion, mapping from `dst_layout` to `src_layout`, which is
// surjective. This inverse layout indicates that multiple destination
// registers may come from the same source register.
//
if (std::optional<LinearLayout> srcToDst = conversion.divideRight(
LinearLayout::identity1D(numLanes, kLane, kLane) *
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
c.has_value()) {
return transferWithinThread(*c, op, adaptor, rewriter);
srcToDst.has_value()) {
auto inRegSize = srcLayout->getInDimSize(kRegister);
auto outRegSize = dstLayout->getInDimSize(kRegister);
if (inRegSize <= outRegSize) {
LinearLayout inverseConversion =
dstLayout->invertAndCompose(*srcLayout);
auto dstToSrc = inverseConversion.divideRight(
LinearLayout::identity1D(numLanes, kLane, kLane) *
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
return transferWithinThread(*dstToSrc, op, adaptor, rewriter,
/*srcToDst=*/false);
} else {
return transferWithinThread(*srcToDst, op, adaptor, rewriter,
/*srcToDst=*/true);
}
}

if (std::optional<LinearLayout> c = conversion.divideRight(
Expand All @@ -310,10 +334,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
adaptor, rewriter);
}

LogicalResult
transferWithinThread(const LinearLayout &conversion, ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
LogicalResult transferWithinThread(const LinearLayout &conversion,
ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
bool srcToDst) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
StringAttr kRegister = str_attr("register");
Expand All @@ -325,10 +349,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
ArrayRef{kRegister});

auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> outVals(conversion.getInDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto srcIdx = conversion.apply({{kRegister, i}});
outVals[i] = inVals[srcIdx.begin()->second];
SmallVector<Value> outVals;
if (srcToDst) {
outVals.resize(conversion.getOutDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto dstIdx = conversion.apply({{kRegister, i}});
outVals[dstIdx.begin()->second] = inVals[i];
}
} else {
outVals.resize(conversion.getInDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto srcIdx = conversion.apply({{kRegister, i}});
outVals[i] = inVals[srcIdx.begin()->second];
}
}
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
Expand All @@ -355,8 +388,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();

// TODO(jlebar): For now we handle only blocked/slice -> blocked/slice
// conversions. Once we have ldmatrix support in
// TODO(jlebar): For now we handle only blocked/slice ->
// blocked/slice conversions. Once we have ldmatrix support in
// load/storeDistributedToShared, we can remove this constraint.
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
if (isa<BlockedEncodingAttr>(layout)) {
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ LogicalResult MakeRangeOp::verify() {

//-- ReduceOp --
static LogicalResult
inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy,
int axis, SmallVectorImpl<Type> &inferredReturnTypes) {
inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto retShape = argTy.getShape().vec();
retShape.erase(retShape.begin() + axis);
if (retShape.empty()) {
Expand Down
Loading

0 comments on commit dd60c90

Please sign in to comment.