Skip to content

Commit

Permalink
[BACKEND] Enable slice layout in fast/slow path reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
zahimoud committed Apr 29, 2023
1 parent 65fb36e commit bd36694
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 12 deletions.
3 changes: 2 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
SmallVector<unsigned> getSizePerThread(Attribute layout);

SmallVector<unsigned> getContigPerThread(Attribute layout);

// Per-dimension threads-per-warp for `layout`. Blocked layouts report their
// own threadsPerWarp; MMA layouts report {4, 8} (Volta) or {8, 4} (Ampere).
// For a slice layout this is the parent's result with the sliced dimension
// removed; unlike getThreadsPerWarp, the removed dimension's thread count is
// NOT multiplied into the remaining dimensions.
SmallVector<unsigned> getThreadsPerWarpWithUniqueData(Attribute layout);
// Per-dimension warps-per-CTA for `layout`. Blocked and MMA layouts report
// their own warpsPerCTA. For a slice layout this is the parent's result with
// the sliced dimension removed; unlike getWarpsPerCTA, the removed
// dimension's warp count is NOT multiplied into the remaining dimensions.
SmallVector<unsigned> getWarpsPerCTAWithUniqueData(Attribute layout);
SmallVector<unsigned> getUniqueContigPerThread(Type type);

SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
Expand Down
17 changes: 11 additions & 6 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@ bool ReduceOpHelper::isFastReduction() {
// Returns the size of the inter-warp stage of the reduction along `axis`:
// the smaller of
//   * the reduced dimension's extent divided by the intra-warp size, and
//   * the number of warps per CTA along `axis`.
// The warp count comes from getWarpsPerCTAWithUniqueData rather than
// getWarpsPerCTA so that, for a slice source layout, the sliced-away parent
// dimension is not multiplied into the count.
unsigned ReduceOpHelper::getInterWarpSize() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned sizeIntraWarps = getIntraWarpSize();
  return std::min(
      srcReduceDimSize / sizeIntraWarps,
      triton::gpu::getWarpsPerCTAWithUniqueData(getSrcLayout())[axis]);
}

// Returns the size of the intra-warp stage of the reduction along `axis`:
// the smaller of the reduced dimension's extent and the number of threads
// per warp along `axis`.
// The thread count comes from getThreadsPerWarpWithUniqueData rather than
// getThreadsPerWarp so that, for a slice source layout, the sliced-away
// parent dimension is not multiplied into the count.
unsigned ReduceOpHelper::getIntraWarpSize() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  return std::min(
      srcReduceDimSize,
      triton::gpu::getThreadsPerWarpWithUniqueData(getSrcLayout())[axis]);
}

// Returns the number of threads along the reduction axis: threads per warp
// times warps per CTA, both taken along `axis`.
// Both factors use the *WithUniqueData helpers so that, for a slice source
// layout, the sliced-away parent dimension is not multiplied into either.
unsigned ReduceOpHelper::getThreadsReductionAxis() {
  auto srcLayout = getSrcLayout();
  return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout)[axis] *
         triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout)[axis];
}

SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
Expand Down Expand Up @@ -88,6 +90,9 @@ bool ReduceOpHelper::isSupportedLayout() {
return true;
}
}
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return true;
}
return false;
}

Expand Down
14 changes: 9 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ struct ReduceOpConversion
// to map every `axisSizePerThread` to 1 value in smem as:
// writeIdx[axis] = index[axis] / axisSizePerThread
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
if (mmaLayout && mmaLayout.isAmpere()) {
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (!mmaLayout.isAmpere()) {
llvm::report_fatal_error("Unsupported layout");
}
if (axis == 0) {
// Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8
// rows in smem would correspond to a warp. The mapping
Expand All @@ -113,8 +114,11 @@ struct ReduceOpConversion
// Same as BlockedEncodingAttr case
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
}
if (mmaLayout && !mmaLayout.isAmpere()) {
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints,
axis);
} else {
llvm::report_fatal_error("Unsupported layout");
}
}
Expand Down
60 changes: 60 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,37 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
if (mmaLayout.isAmpere())
return {8, 4};
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentThreadsPerWarp = getThreadsPerWarp(parent);
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < threadsPerWarp.size(); i++)
threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
return threadsPerWarp;
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}

// Returns the per-dimension threads-per-warp for `layout`.
//   * Blocked layout: the layout's own threadsPerWarp.
//   * MMA layout: {4, 8} for Volta, {8, 4} for Ampere.
//   * Slice layout: the parent's result with the sliced dimension removed.
//     The removed dimension's thread count is dropped, not folded into the
//     remaining dimensions.
// Any other layout (including an MMA layout that is neither Volta nor Ampere)
// trips the assert below; with asserts compiled out an empty vector is
// returned.
SmallVector<unsigned> getThreadsPerWarpWithUniqueData(Attribute layout) {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
                                 blockedLayout.getThreadsPerWarp().end());
  }
  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    if (mmaLayout.isVolta())
      return {4, 8};
    if (mmaLayout.isAmpere())
      return {8, 4};
  }
  if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    // Work directly on the parent's result; no second copy is needed.
    SmallVector<unsigned> threadsPerWarp =
        getThreadsPerWarpWithUniqueData(sliceLayout.getParent());
    // erase() with an out-of-range iterator is undefined behaviour.
    assert(sliceLayout.getDim() < threadsPerWarp.size() &&
           "slice dim out of range for parent layout");
    threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
    return threadsPerWarp;
  }
  assert(false && "getThreadsPerWarpWithUniqueData not implemented");
  return {};
}
Expand All @@ -94,6 +125,35 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
mmaLayout.getWarpsPerCTA().end());
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentWarpsPerCTA = getWarpsPerCTA(parent);
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < warpsPerCTA.size(); i++)
warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
return warpsPerCTA;
}
assert(0 && "getWarpsPerCTA not implemented");
return {};
}

// Returns the per-dimension warps-per-CTA for `layout`.
//   * Blocked layout: the layout's own warpsPerCTA.
//   * MMA layout: the layout's own warpsPerCTA.
//   * Slice layout: the parent's result with the sliced dimension removed.
//     The removed dimension's warp count is dropped, not folded into the
//     remaining dimensions.
// Any other layout trips the assert below; with asserts compiled out an
// empty vector is returned.
SmallVector<unsigned> getWarpsPerCTAWithUniqueData(Attribute layout) {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
                                 blockedLayout.getWarpsPerCTA().end());
  }
  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
                                 mmaLayout.getWarpsPerCTA().end());
  }
  if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    // Work directly on the parent's result; no second copy is needed.
    SmallVector<unsigned> warpsPerCTA =
        getWarpsPerCTAWithUniqueData(sliceLayout.getParent());
    // erase() with an out-of-range iterator is undefined behaviour.
    assert(sliceLayout.getDim() < warpsPerCTA.size() &&
           "slice dim out of range for parent layout");
    warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
    return warpsPerCTA;
  }
  assert(false && "getWarpsPerCTAWithUniqueData not implemented");
  return {};
}
Expand Down

0 comments on commit bd36694

Please sign in to comment.