Skip to content

Commit

Permalink
[XPU][TritonIntelGPUToLLVM] Add support for more shuffle kinds (#2799)
Browse files Browse the repository at this point in the history
Add support for layout conversion shuffles in which rows managed by a
single thread are contiguous in the output matrix.

Step 2/2 to
#2749

---------

Signed-off-by: victor-eds <victor.perez@codeplay.com>
  • Loading branch information
victor-eds authored Nov 26, 2024
1 parent 67ea90d commit 2a72ba2
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 18 deletions.
18 changes: 18 additions & 0 deletions test/Conversion/intel/intel-allocate-shared-memory.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check no scratch memory is allocated for sub-group shuffle-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 0 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK: tt.func @test_sub_group_shuffle
// CHECK-NOT: llvm.ptr<3>
tt.func @test_sub_group_shuffle(%arg0: tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
// COM: #blocked holds 2 contiguous rows per work-item (sizePerThread = [2, 1]);
// COM: the conversion to #blocked1 is shuffle-like, so the CHECK-NOT above
// COM: verifies no shared-memory (ptr<3>) scratch buffer is materialized.
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return %0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

Expand Down
35 changes: 35 additions & 0 deletions test/Conversion/intel/sub-group-shuffle.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
tt.return %0 : tensor<128xi32, #sliced1>
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// Case of more than one contiguous element per work-item.

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16)>)
tt.func @test_contiguous(%arg0: tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
// COM: Each work-item holds two f16 elements (struct<(f16, f16)>), coming
// COM: from sizePerThread = [2, 1] in #blocked with 16 lanes per sub-group.
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16, f16)>
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f16, f16)>
// COM: Check the shuffles are "coalesced"
// COM: i.e., both values held by a work-item are shuffled from lane 0 before
// COM: moving on to the next lane (alternating VAL_1 / VAL_2 below).
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return %0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}
}
57 changes: 52 additions & 5 deletions third_party/intel/lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
return bases;
}

// Build the register bases for a shuffle conversion whose rows handled by a
// single work-item are contiguous in the output:
// [[1, 0], [2, 0], ..., [registerSize / laneSize / 2, 0], [0, 1], [0, 2],
//  ..., [0, laneSize / 2]]
// i.e., the first log2(registerSize / laneSize) bases map registers to
// contiguous registers and the remaining ones map registers to lanes.
std::vector<std::vector<int32_t>>
buildContiguousSubGroupShuffleRegisterBases(int32_t registerSize,
                                            int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  // Contiguous-register part: one basis per power of two strictly below
  // registerSize / laneSize.
  const int32_t contiguousLimit = registerSize / laneSize;
  int32_t stride = 1;
  while (stride < contiguousLimit) {
    bases.push_back({stride, 0});
    stride *= 2;
  }
  // Lane part: one basis per power of two up to laneSize / 2.
  for (int32_t lane = 1; stride < registerSize; stride *= 2, lane *= 2)
    bases.push_back({0, lane});
  return bases;
}

// Return a vector such as:
// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
// i.e., mapping lanes to registers.
Expand Down Expand Up @@ -138,25 +161,49 @@ bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
// ...
// - register=2**i -> (0, 2**i)
// ...
// - register=M -> (0, 2**M)
// - register=M -> (0, 2**(M-1))
// - register=M+1 -> (1, 0)
// ...
// - register=2**k -> (2**(k-M), 0)
// - register=2**k -> (2**(K-M), 0)
// ...
// - register=2**N -> (2**(N-M), 0)
// - lane=1 -> (0, 0)
// ...
// - lane=2**j -> (0, 0)
// ...
// lane=2**M -> (0, 0)
// where out dims are: [register (size 2**N), lane (size 2**M)]
//
// With N >= M.
//
// Or, when the elements managed by a given work-item are in contiguous
// positions:
// - register=1 -> (1, 0)
// ...
// - register=2**i -> (2**i, 0)
// ...
// - register=M -> (2**(N - M), 0)
// ...
// - register=2**k -> (0, 1)
// ...
// - register=2**N -> (0, 2**(M-1))
// - lane=1 -> (0, 0)
// ...
// - lane=2**j -> (0, 0)
// ...
// lane=2**M -> (0, 0)
// where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
//
// With N >= M.
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
conversion->getBases().lookup(kRegister) ==
buildSubGroupShuffleRegisterBases(registerInDimSize,
laneOutDimSize);
(conversion->getBases().lookup(kRegister) ==
buildSubGroupShuffleRegisterBases(registerInDimSize,
laneOutDimSize) ||
conversion->getBases().lookup(kRegister) ==
buildContiguousSubGroupShuffleRegisterBases(registerInDimSize,
laneOutDimSize));
}

bool isValidElementTypeForSubGroupTranspose(Type type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return success();
}

int getNumContiguousRowsForShuffle(const LinearLayout &srcLayout,
const LinearLayout &dstLayout) const {
MLIRContext *ctx = getContext();

StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");
LinearLayout comp =
*dstLayout.invertAndCompose(srcLayout).quotient({kWarp, kBlock});
// Basic case: the number of contiguous rows is 1.
if (comp.getBasis(kRegister, 0)[1] == 1)
return 1;
// In other case, we only allow all threads handled by a single element to
// be contiguous, so we can simply:
return comp.getOutDimSize(kRegister);
}

void performSubGroupShuffle(ConvertLayoutOp op, const LinearLayout &srcLayout,
const LinearLayout &dstLayout, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -605,8 +623,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
});
});

SmallVector<Value> outVals =
performSubGroupShuffle(loc, inVals, subGroupSize, rewriter);
SmallVector<Value> outVals = performSubGroupShuffle(
loc, inVals, subGroupSize, rewriter,
getNumContiguousRowsForShuffle(srcLayout, dstLayout));

// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
// upstream level. We are not enabling support for all types here as that
Expand Down Expand Up @@ -636,19 +655,41 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
rewriter.replaceOp(op, result);
}

SmallVector<Value>
performSubGroupShuffle(Location loc, ArrayRef<Value> inVals,
int32_t subGroupSize,
ConversionPatternRewriter &rewriter) const {
// Emit the index-mode sub-group shuffles realizing the layout conversion.
//
// `numContiguousRows` selects the emission order for work-items holding more
// than one element (see getNumContiguousRowsForShuffle).
SmallVector<Value> performSubGroupShuffle(Location loc,
                                          ArrayRef<Value> inVals,
                                          int32_t subGroupSize,
                                          ConversionPatternRewriter &rewriter,
                                          int numContiguousRows) const {
  SmallVector<Value> res;
  Value width = i32_val(subGroupSize);
  // Append the result of shuffling `val` from lane `lane`.
  auto shuffleFrom = [&](Value val, int32_t lane) {
    res.push_back(rewriter
                      .create<mlir::gpu::ShuffleOp>(
                          loc, val, i32_val(lane), width,
                          mlir::gpu::ShuffleMode::IDX)
                      .getShuffleResult());
  };
  // A work-item may handle more than one element. There are two cases we
  // support:
  if (numContiguousRows == 1) {
    // 1. Elements held by a work-item are strided rows in the abstract slice
    // matrix: output element `i` takes the `i / subGroupSize`th value from
    // the `i % subGroupSize`th thread.
    for (Value val : inVals)
      for (int32_t lane = 0; lane < subGroupSize; ++lane)
        shuffleFrom(val, lane);
  } else {
    // 2. Elements held by a work-item are contiguous rows in the abstract
    // slice matrix: output element `i` takes the `i % subGroupSize`th value
    // from the `i / subGroupSize`th thread.
    for (int32_t lane = 0; lane < subGroupSize; ++lane)
      for (Value val : inVals)
        shuffleFrom(val, lane);
  }
  return res;
}
Expand Down

0 comments on commit 2a72ba2

Please sign in to comment.