Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren committed Aug 6, 2024
1 parent 194a00f commit 232574c
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// data in different CTAs and we know we're not in case 4.
LinearLayout conversion = srcLayout->invertAndCompose(*dstLayout);

LinearLayout invertConversion = dstLayout->invertAndCompose(*srcLayout);

int numLanes = conversion.getInDimSize(str_attr("lane"));
int numWarps = conversion.getInDimSize(str_attr("warp"));
int numBlocks = conversion.getInDimSize(str_attr("block"));
Expand All @@ -289,7 +291,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// stronger than this, checking also that the choice of lane/warp/block does
// not affect the permutation of registers. If we allow different
// lane/warp/blocks to have different permutations, we can generalize this.
if (std::optional<LinearLayout> c = conversion.divideRight(
if (std::optional<LinearLayout> c = invertConversion.divideRight(
LinearLayout::identity1D(numLanes, kLane, kLane) *
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
Expand Down Expand Up @@ -323,10 +325,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
ArrayRef{kRegister});

auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> outVals(conversion.getOutDimSize(kRegister));
SmallVector<Value> outVals(conversion.getInDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto dstIdx = conversion.apply({{kRegister, i}});
outVals[dstIdx.begin()->second] = inVals[i];
auto srcIdx = conversion.apply({{kRegister, i}});
outVals[i] = inVals[srcIdx.begin()->second];
}
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
Expand All @@ -353,8 +355,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();

assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

// TODO(jlebar): For now we handle only blocked/slice -> blocked/slice
// conversions. Once we have ldmatrix support in
// load/storeDistributedToShared, we can remove this constraint.
Expand All @@ -372,6 +372,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return failure();
}

assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

SmallVector<Value> inVals =
unpackLLElements(loc, adaptor.getSrc(), rewriter);
assert(!inVals.empty());
Expand Down

0 comments on commit 232574c

Please sign in to comment.