diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index eb83e55086a1..3ec77a416dc2 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -275,6 +275,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // data in different CTAs and we know we're not in case 4.
     LinearLayout conversion = srcLayout->invertAndCompose(*dstLayout);
 
+    LinearLayout invertConversion = dstLayout->invertAndCompose(*srcLayout);
+
     int numLanes = conversion.getInDimSize(str_attr("lane"));
     int numWarps = conversion.getInDimSize(str_attr("warp"));
     int numBlocks = conversion.getInDimSize(str_attr("block"));
@@ -289,7 +291,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // stronger than this, checking also that the choice of lane/warp/block does
     // not affect the permutation of registers. If we allow different
     // lane/warp/blocks to have different permutations, we can generalize this.
-    if (std::optional<LinearLayout> c = conversion.divideRight(
+    if (std::optional<LinearLayout> c = invertConversion.divideRight(
            LinearLayout::identity1D(numLanes, kLane, kLane) *
            LinearLayout::identity1D(numWarps, kWarp, kWarp) *
            LinearLayout::identity1D(numBlocks, kBlock, kBlock));
@@ -323,10 +325,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                             ArrayRef{kRegister});
 
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    SmallVector<Value> outVals(conversion.getOutDimSize(kRegister));
+    SmallVector<Value> outVals(conversion.getInDimSize(kRegister));
     for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
-      auto dstIdx = conversion.apply({{kRegister, i}});
-      outVals[dstIdx.begin()->second] = inVals[i];
+      auto srcIdx = conversion.apply({{kRegister, i}});
+      outVals[i] = inVals[srcIdx.begin()->second];
     }
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
@@ -353,8 +355,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
 
-    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
-
     // TODO(jlebar): For now we handle only blocked/slice -> blocked/slice
     // conversions. Once we have ldmatrix support in
     // load/storeDistributedToShared, we can remove this constraint.
@@ -372,6 +372,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return failure();
     }
 
+    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
+
     SmallVector<Value> inVals =
         unpackLLElements(loc, adaptor.getSrc(), rewriter);
     assert(!inVals.empty());
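
Note on the register loop in the third hunk: for a pure per-thread register permutation, the scatter form used before the patch (outVals[dstIdx] = inVals[i]) and the gather form used after it (outVals[i] = inVals[srcIdx]) produce the same result, provided the two index maps are inverses of each other. The sketch below is a minimal, self-contained illustration of that equivalence only; it uses plain std::vector and hypothetical dstToSrc/srcToDst permutation tables rather than Triton's LinearLayout, so it mirrors the indexing pattern of the patched loop, not the actual layout machinery.

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Gather form, as in the patched loop: each output register i reads
    // from the source register dstToSrc[i].
    std::vector<int> gatherRegisters(const std::vector<int> &inVals,
                                     const std::vector<std::size_t> &dstToSrc) {
      std::vector<int> outVals(inVals.size());
      for (std::size_t i = 0; i < outVals.size(); ++i)
        outVals[i] = inVals[dstToSrc[i]];
      return outVals;
    }

    // Scatter form, as in the loop before the patch: each input register i
    // writes to the destination register srcToDst[i].
    std::vector<int> scatterRegisters(const std::vector<int> &inVals,
                                      const std::vector<std::size_t> &srcToDst) {
      std::vector<int> outVals(inVals.size());
      for (std::size_t i = 0; i < inVals.size(); ++i)
        outVals[srcToDst[i]] = inVals[i];
      return outVals;
    }

    int main() {
      std::vector<int> in = {10, 11, 12, 13};
      // dstToSrc and srcToDst describe the same permutation, inverted.
      std::vector<std::size_t> dstToSrc = {2, 0, 3, 1};
      std::vector<std::size_t> srcToDst = {1, 3, 0, 2};
      assert(gatherRegisters(in, dstToSrc) == scatterRegisters(in, srcToDst));
      return 0;
    }

The gather form has the practical advantage that the loop variable indexes outVals directly, so every output slot is written exactly once in order, which is presumably why the patch also sizes outVals by the conversion's input register dimension.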