Zahi/slice reduce rebased (#1594)
[BACKEND] Enable slice layout support for reduce op
zahimoud authored May 2, 2023
1 parent 26d80f0 commit 3449a9d
Showing 5 changed files with 187 additions and 16 deletions.
27 changes: 27 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -31,10 +31,37 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);

SmallVector<unsigned> getSizePerThread(Attribute layout);

// Returns the number of contiguous elements that each thread
// has access to, on each dimension of the tensor. E.g.
// for a blocked layout with sizePerThread = [1, 4], returns [1, 4],
// regardless of the shape of the tensor.
SmallVector<unsigned> getContigPerThread(Attribute layout);

// Returns the number of non-replicated contiguous elements that each thread
// has access to, on each dimension of the tensor. For a blocked layout
// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements
// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}] (four
// replicas of the same element), so this function returns [1, 1]; whereas
// for a tensor shape of [128, 128], the elements for thread 0 would be
// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], so it returns [1, 4].
SmallVector<unsigned> getUniqueContigPerThread(Type type);

// Returns the number of threads per warp that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17
// have access to the full tensor, whereas the other threads have access to
// replicated elements, so this function returns [2, 2].
SmallVector<unsigned>
getThreadsPerWarpWithUniqueData(Attribute layout,
ArrayRef<int64_t> tensorShape);

// Returns the number of warps per CTA that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
// returns [1, 1], since the first warp has access to the full tensor, whereas
// the other warps have access to replicated elements.
SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

SmallVector<unsigned> getThreadsPerCTA(Attribute layout);

SmallVector<unsigned>
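The *WithUniqueData variants declared above exist so that reduction sizing only counts lanes that actually hold distinct data: along each dimension, at most as many threads (or warps) as there are elements can be unique. Below is a minimal standalone sketch of that clamping idea, using plain std::vector and a hypothetical clampToShape helper instead of the Triton attribute types; it illustrates the rule and is not the library API.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: clamp per-dimension thread counts to the tensor shape,
// since a dimension of size S can hold at most S threads with unique data.
std::vector<unsigned> clampToShape(std::vector<unsigned> counts,
                                   const std::vector<int64_t> &shape) {
  for (size_t i = 0; i < counts.size(); ++i)
    counts[i] = std::min<unsigned>(counts[i], static_cast<unsigned>(shape[i]));
  return counts;
}

int main() {
  // Example from the comment above: threadsPerWarp = [2, 16], shape = [2, 2].
  std::vector<unsigned> threadsPerWarp = {2, 16};
  std::vector<int64_t> shape = {2, 2};
  for (unsigned t : clampToShape(threadsPerWarp, shape))
    std::cout << t << ' '; // prints "2 2"
  std::cout << '\n';
  return 0;
}

getWarpsPerCTAWithUniqueData follows the same idea but first divides the shape by the elements a single warp already covers (sizePerThread * threadsPerWarp), as the Dialect.cpp hunk further down shows.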
15 changes: 11 additions & 4 deletions lib/Analysis/Utility.cpp
@@ -17,19 +17,23 @@ unsigned ReduceOpHelper::getInterWarpSize() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
triton::gpu::getWarpsPerCTAWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}

unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
triton::gpu::getThreadsPerWarpWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}

unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = getSrcLayout();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
auto srcShape = getSrcShape();
return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout,
srcShape)[axis] *
triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
}

SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
@@ -88,6 +92,9 @@ bool ReduceOpHelper::isSupportedLayout() {
return true;
}
}
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return true;
}
return false;
}

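With the unique-data counts in place, the reduction sizes in ReduceOpHelper fall out directly: the intra-warp size is bounded by the unique threads on the reduction axis, and the inter-warp size by the unique warps. A standalone sketch with assumed example values (tensor shape [2, 2], reduced along axis 0, matching the example in Dialect.h), not the actual ReduceOpHelper class:

#include <algorithm>
#include <iostream>

int main() {
  // Assumed inputs: shape[axis] = 2, unique threadsPerWarp on the axis = 2,
  // unique warpsPerCTA on the axis = 1.
  unsigned srcReduceDimSize = 2;
  unsigned uniqueThreadsPerWarpOnAxis = 2;
  unsigned uniqueWarpsPerCTAOnAxis = 1;

  // Mirrors getIntraWarpSize / getInterWarpSize above.
  unsigned intraWarp = std::min(srcReduceDimSize, uniqueThreadsPerWarpOnAxis);
  unsigned interWarp =
      std::min(srcReduceDimSize / intraWarp, uniqueWarpsPerCTAOnAxis);

  std::cout << "intra-warp: " << intraWarp   // prints 2
            << ", inter-warp: " << interWarp // prints 1
            << '\n';
  return 0;
}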
19 changes: 14 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -87,6 +87,15 @@ struct ReduceOpConversion
Attribute layout, SmallVector<Value> &index,
SmallVector<Value> &writeIdx,
std::map<int, Value> &ints, unsigned axis) const {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto dim = sliceLayout.getDim();
assert(dim != axis && "Reduction axis cannot be sliced");
auto parentLayout = sliceLayout.getParent();
getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints,
axis);
return;
}

writeIdx = index;
auto sizePerThread = triton::gpu::getSizePerThread(layout);
Value axisSizePerThread = ints[sizePerThread[axis]];
@@ -100,9 +109,10 @@ struct ReduceOpConversion
// to map every `axisSizePerThread` to 1 value in smem as:
// writeIdx[axis] = index[axis] / axisSizePerThread
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
if (mmaLayout && mmaLayout.isAmpere()) {
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (!mmaLayout.isAmpere()) {
llvm::report_fatal_error("Unsupported layout");
}
if (axis == 0) {
// Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8
// rows in smem would correspond to a warp. The mapping
@@ -113,8 +123,7 @@ struct ReduceOpConversion
// Same as BlockedEncodingAttr case
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
}
if (mmaLayout && !mmaLayout.isAmpere()) {
} else {
llvm::report_fatal_error("Unsupported layout");
}
}
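The only change a slice layout needs in getWriteIndexBasic is to defer to its parent layout, since the sliced dimension can never be the reduction axis. The sketch below mimics that dispatch with hypothetical Layout/Blocked/Slice structs standing in for the MLIR attributes; the index bookkeeping is reduced to the blocked case and is not the actual lowering code.

#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

struct Layout { virtual ~Layout() = default; }; // stand-in for mlir::Attribute
struct Blocked : Layout { std::vector<unsigned> sizePerThread; };
struct Slice : Layout { unsigned dim; std::shared_ptr<Layout> parent; };

void getWriteIndex(const Layout &layout, std::vector<unsigned> &writeIdx,
                   const std::vector<unsigned> &index, unsigned axis) {
  if (auto *slice = dynamic_cast<const Slice *>(&layout)) {
    assert(slice->dim != axis && "Reduction axis cannot be sliced");
    getWriteIndex(*slice->parent, writeIdx, index, axis); // defer to parent
    return;
  }
  // Blocked case, as in the hunk above: every axisSizePerThread elements
  // map to one slot in shared memory.
  const auto &blocked = dynamic_cast<const Blocked &>(layout);
  writeIdx = index;
  writeIdx[axis] = index[axis] / blocked.sizePerThread[axis];
}

int main() {
  auto parent = std::make_shared<Blocked>();
  parent->sizePerThread = {1, 4};

  Slice slice;
  slice.dim = 0;
  slice.parent = parent;

  std::vector<unsigned> writeIdx;
  std::vector<unsigned> index = {0, 8};
  getWriteIndex(slice, writeIdx, index, /*axis=*/1);
  std::cout << writeIdx[0] << ' ' << writeIdx[1] << '\n'; // prints "0 2"
  return 0;
}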
78 changes: 71 additions & 7 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -81,10 +81,41 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
if (mmaLayout.isAmpere())
return {8, 4};
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentThreadsPerWarp = getThreadsPerWarp(parent);
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < threadsPerWarp.size(); i++)
threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
return threadsPerWarp;
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}

SmallVector<unsigned>
getThreadsPerWarpWithUniqueData(Attribute layout,
ArrayRef<int64_t> tensorShape) {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(tensorShape);
auto parentThreadsPerWarp =
getThreadsPerWarpWithUniqueData(parentLayout, parentShape);
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
return threadsPerWarp;
}
auto threadsPerWarp = getThreadsPerWarp(layout);
assert(threadsPerWarp.size() == tensorShape.size() &&
"layout and tensor shape must have the same rank");
for (unsigned i = 0; i < threadsPerWarp.size(); i++) {
threadsPerWarp[i] = std::min<unsigned>(threadsPerWarp[i], tensorShape[i]);
}

return threadsPerWarp;
}

SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
@@ -94,10 +125,43 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
mmaLayout.getWarpsPerCTA().end());
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentWarpsPerCTA = getWarpsPerCTA(parent);
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < warpsPerCTA.size(); i++)
warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
return warpsPerCTA;
}
assert(0 && "getWarpsPerCTA not implemented");
return {};
}

SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape) {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(tensorShape);
auto parentWarpsPerCTA =
getWarpsPerCTAWithUniqueData(parentLayout, parentShape);
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
return warpsPerCTA;
}
auto warpsPerCTA = getWarpsPerCTA(layout);
assert(warpsPerCTA.size() == tensorShape.size() &&
"layout and tensor shape must have the same rank");
for (unsigned i = 0; i < warpsPerCTA.size(); i++) {
auto sizePerWarp =
getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i];
auto maxWarpsPerDim = ceil<unsigned>(tensorShape[i], sizePerWarp);
warpsPerCTA[i] = std::min<unsigned>(warpsPerCTA[i], maxWarpsPerDim);
}

return warpsPerCTA;
}

SmallVector<unsigned> getSizePerThread(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
@@ -189,7 +253,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersionMajor() == 2) {
if (mmaLayout.isAmpere()) {
threads = {8 * mmaLayout.getWarpsPerCTA()[0],
4 * mmaLayout.getWarpsPerCTA()[1]};
} else
@@ -1074,9 +1138,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
// Ensure that the new insert_slice op is placed in the same place as
// the old insert_slice op. Otherwise, the new insert_slice op may be
// placed after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -1104,9 +1168,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
// Ensure that the new extract_slice op is placed in the same place as
// the old extract_slice op. Otherwise, the new extract_slice op may be
// placed after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
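For slice layouts, getThreadsPerWarp and getWarpsPerCTA are now derived from the parent by dropping the sliced dimension and folding its count into the surviving dimensions, so a 2-D parent's 32 threads per warp still show up as 32 on the 1-D slice. A standalone sketch of that rule with a hypothetical sliceCounts helper; it mirrors the loop in the hunk above rather than the real attribute machinery.

#include <iostream>
#include <vector>

std::vector<unsigned> sliceCounts(std::vector<unsigned> parentCounts,
                                  unsigned slicedDim) {
  unsigned folded = parentCounts[slicedDim];
  parentCounts.erase(parentCounts.begin() + slicedDim);
  for (unsigned &c : parentCounts)
    c *= folded; // fold the sliced dimension's count into the survivors
  return parentCounts;
}

int main() {
  // Assumed example: parent blocked layout with threadsPerWarp = [8, 4],
  // sliced along dim 0 -> the slice reports [4 * 8] = [32].
  std::vector<unsigned> parentThreadsPerWarp = {8, 4};
  for (unsigned t : sliceCounts(parentThreadsPerWarp, /*slicedDim=*/0))
    std::cout << t << ' '; // prints "32"
  std::cout << '\n';
  return 0;
}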
64 changes: 64 additions & 0 deletions python/test/unit/language/test_core.py
@@ -1420,6 +1420,70 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
)


layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
]


@pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
@pytest.mark.parametrize("src_layout", layouts)
def test_reduce_2d(M, N, src_layout, device='cuda'):
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
%11 = "tt.reduce"(%10) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%12 = "tt.reduce"(%11) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)

rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('int32')
x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')

z = np.zeros((1,)).astype('int32')

x_tri = torch.tensor(x, device=device)
z_tri = torch.tensor(z, device=device)

pgm = kernel[(1, 1, 1)](x_tri, z_tri)

z_ref = np.sum(x)

np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


def test_generic_reduction(device='cuda'):

@triton.jit