diff --git a/test/Conversion/intel/intel-allocate-shared-memory.mlir b/test/Conversion/intel/intel-allocate-shared-memory.mlir
index 5fad77531e..81cbcf4a31 100644
--- a/test/Conversion/intel/intel-allocate-shared-memory.mlir
+++ b/test/Conversion/intel/intel-allocate-shared-memory.mlir
@@ -18,6 +18,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 
 // -----
 
+#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+// Check no scratch memory is allocated for sub-group shuffle-like layout conversions.
+
+// CHECK-LABEL: module attributes
+// CHECK-SAME: triton_gpu.shared = 0 : i32
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK: tt.func @test_sub_group_shuffle
+  // CHECK-NOT: llvm.ptr<3>
+  tt.func @test_sub_group_shuffle(%arg0: tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    tt.return %0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+  }
+}
+
+// -----
+
 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
 
diff --git a/test/Conversion/intel/sub-group-shuffle.mlir b/test/Conversion/intel/sub-group-shuffle.mlir
index 8bcd1b57dc..0d033f17a4 100644
--- a/test/Conversion/intel/sub-group-shuffle.mlir
+++ b/test/Conversion/intel/sub-group-shuffle.mlir
@@ -360,3 +360,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
     tt.return %0 : tensor<128xi32, #sliced1>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+// Case of more than one contiguous element per work-item.
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_contiguous(
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16)>)
+  tt.func @test_contiguous(%arg0: tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
+    // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16, f16)>
+    // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f16, f16)>
+    // COM: Check the shuffles are "coalesced"
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
+    // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    tt.return %0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+  }
+}
diff --git a/third_party/intel/lib/Analysis/Utility.cpp b/third_party/intel/lib/Analysis/Utility.cpp
index 13b8607192..cba7f77398 100644
--- a/third_party/intel/lib/Analysis/Utility.cpp
+++ b/third_party/intel/lib/Analysis/Utility.cpp
@@ -71,6 +71,29 @@ buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
   return bases;
 }
 
+// Return a vector such as:
+// [[1, 0], [2, 0], [4, 0], ..., [registerSize / (2 * laneSize), 0], [0, 1],
+// ..., [0, laneSize / 2]],
+// i.e., mapping registers to contiguous registers up to registerSize /
+// laneSize and then on to lanes.
+std::vector<std::vector<int32_t>>
+buildContiguousSubGroupShuffleRegisterBases(int32_t registerSize,
+                                            int32_t laneSize) {
+  std::vector<std::vector<int32_t>> bases;
+  std::vector<int32_t> curr(2);
+  int32_t i = 1;
+  for (; i < registerSize / laneSize; i *= 2) {
+    curr[0] = i;
+    bases.push_back(curr);
+  }
+  curr[0] = 0;
+  for (int32_t val = 1; i < registerSize; i *= 2, val *= 2) {
+    curr[1] = val;
+    bases.push_back(curr);
+  }
+  return bases;
+}
+
 // Return a vector such as:
 // [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
 // i.e., mapping lanes to registers.
@@ -138,9 +161,10 @@ bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
   // ...
   // - register=2**i -> (0, 2**i)
   // ...
-  // - register=M -> (0, 2**M)
+  // - register=2**(M-1) -> (0, 2**(M-1))
+  // - register=2**M -> (1, 0)
   // ...
   // - register=2**k -> (2**(k-M), 0)
   // ...
   // - register=2**N -> (2**(N-M), 0)
   // - lane=1 -> (0, 0)
@@ -148,15 +172,38 @@ bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
   // - lane=2**j -> (0, 0)
   // ...
   // lane=2**M -> (0, 0)
+  // where out dims are: [register (size 2**N), lane (size 2**M)]
+  //
+  // With N >= M.
+  //
+  // Or, when the elements managed by a given work-item are in contiguous
+  // positions:
+  // - register=1 -> (1, 0)
+  // ...
+  // - register=2**i -> (2**i, 0)
+  // ...
+  // - register=2**(N-M) -> (0, 1)
+  // ...
+  // - register=2**k -> (0, 2**(k-(N-M)))
+  // ...
+  // - register=2**(N-1) -> (0, 2**(M-1))
+  // - lane=1 -> (0, 0)
+  // ...
+  // - lane=2**j -> (0, 0)
+  // ...
+  // lane=2**M -> (0, 0)
   // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
   //
   // With N >= M.
   int32_t registerInDimSize = conversion->getInDimSize(kRegister);
   int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
   return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
-         conversion->getBases().lookup(kRegister) ==
-             buildSubGroupShuffleRegisterBases(registerInDimSize,
-                                               laneOutDimSize);
+         (conversion->getBases().lookup(kRegister) ==
+              buildSubGroupShuffleRegisterBases(registerInDimSize,
+                                                laneOutDimSize) ||
+          conversion->getBases().lookup(kRegister) ==
+              buildContiguousSubGroupShuffleRegisterBases(registerInDimSize,
+                                                          laneOutDimSize));
 }
 
 bool isValidElementTypeForSubGroupTranspose(Type type) {
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 83f9042912..c3a5b8da74 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -560,6 +560,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     return success();
   }
 
+  int getNumContiguousRowsForShuffle(const LinearLayout &srcLayout,
+                                     const LinearLayout &dstLayout) const {
+    MLIRContext *ctx = getContext();
+
+    StringAttr kRegister = str_attr("register");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kBlock = str_attr("block");
+    LinearLayout comp =
+        *dstLayout.invertAndCompose(srcLayout).quotient({kWarp, kBlock});
+    // Basic case: the number of contiguous rows is 1.
+    if (comp.getBasis(kRegister, 0)[1] == 1)
+      return 1;
+    // Otherwise, we only support the case in which all the rows handled by a
+    // given work-item are contiguous, so the number of contiguous rows is:
+    return comp.getOutDimSize(kRegister);
+  }
+
   void performSubGroupShuffle(ConvertLayoutOp op, const LinearLayout &srcLayout,
                               const LinearLayout &dstLayout, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
@@ -605,8 +623,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       });
     });
 
-    SmallVector<Value> outVals =
-        performSubGroupShuffle(loc, inVals, subGroupSize, rewriter);
+    SmallVector<Value> outVals = performSubGroupShuffle(
+        loc, inVals, subGroupSize, rewriter,
+        getNumContiguousRowsForShuffle(srcLayout, dstLayout));
 
     // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
     // upstream level. We are not enabling support for all types here as that
@@ -636,19 +655,41 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     rewriter.replaceOp(op, result);
   }
 
-  SmallVector<Value>
-  performSubGroupShuffle(Location loc, ArrayRef<Value> inVals,
-                         int32_t subGroupSize,
-                         ConversionPatternRewriter &rewriter) const {
+  SmallVector<Value> performSubGroupShuffle(Location loc,
+                                            ArrayRef<Value> inVals,
+                                            int32_t subGroupSize,
+                                            ConversionPatternRewriter &rewriter,
+                                            int numContiguousRows) const {
     SmallVector<Value> res;
     Value width = i32_val(subGroupSize);
-    for (Value val : inVals) {
-      for (int32_t i = 0; i < subGroupSize; ++i)
-        res.push_back(
-            rewriter
-                .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
-                                              mlir::gpu::ShuffleMode::IDX)
-                .getShuffleResult());
+    // A work-item may handle more than one element. There are two cases we
+    // support:
+    if (numContiguousRows == 1) {
+      // 1. Elements held by a work-item are strided rows in the abstract
+      // slice matrix: output element `i` takes the `i / subGroupSize`th
+      // value from work-item `i % subGroupSize`.
+      for (Value val : inVals) {
+        for (int32_t i = 0; i < subGroupSize; ++i) {
+          res.push_back(
+              rewriter
+                  .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
+                                                mlir::gpu::ShuffleMode::IDX)
+                  .getShuffleResult());
+        }
+      }
+    } else {
+      // 2. Elements held by a work-item are contiguous rows in the abstract
+      // slice matrix: output element `i` takes the `i % numContiguousRows`th
+      // value from work-item `i / numContiguousRows`.
+      for (int32_t i = 0; i < subGroupSize; ++i) {
+        for (Value val : inVals) {
+          res.push_back(
+              rewriter
+                  .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
+                                                mlir::gpu::ShuffleMode::IDX)
+                  .getShuffleResult());
+        }
+      }
     }
     return res;
   }
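
A quick way to sanity-check the new basis construction is to evaluate it for the shapes exercised by test_contiguous above (registerSize = 32, laneSize = 16). The standalone sketch below mirrors buildContiguousSubGroupShuffleRegisterBases from Utility.cpp; it is illustrative only, and the buildBases helper and main driver are hypothetical names, not part of the patch:

#include <cstdint>
#include <iostream>
#include <vector>

// Mirror of buildContiguousSubGroupShuffleRegisterBases (Utility.cpp):
// registers first map to contiguous registers (while i < registerSize /
// laneSize) and then to lanes.
std::vector<std::vector<int32_t>> buildBases(int32_t registerSize,
                                             int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  std::vector<int32_t> curr(2);
  int32_t i = 1;
  for (; i < registerSize / laneSize; i *= 2) {
    curr[0] = i; // [i, 0]: basis staying within the same work-item.
    bases.push_back(curr);
  }
  curr[0] = 0;
  for (int32_t val = 1; i < registerSize; i *= 2, val *= 2) {
    curr[1] = val; // [0, val]: basis crossing to another lane.
    bases.push_back(curr);
  }
  return bases;
}

int main() {
  // For (32, 16) this prints [1, 0] [0, 1] [0, 2] [0, 4] [0, 8]: one pair of
  // contiguous registers per work-item, then shuffles across the 16 lanes.
  for (const auto &basis : buildBases(32, 16))
    std::cout << '[' << basis[0] << ", " << basis[1] << "] ";
  std::cout << '\n';
}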