Make logic more general
ggengnv committed Nov 21, 2024
1 parent 1f42b49 commit 9a9fcb0
Showing 3 changed files with 65 additions and 39 deletions.
4 changes: 4 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1128,6 +1128,10 @@ SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset);
 
+// TODO document
+std::optional<LinearLayout> getRegToSharedLayout(MLIRContext* ctx,
+    ArrayRef<int64_t> shape, Attribute srcEnc, Attribute dstEnc, int elemBitWidth);
+
 // Emits IR to load data from shared memory into registers, or to store data
 // from registers into shared memory.
 //
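For orientation, a minimal usage sketch of the new helper (hypothetical caller; ctx, the tensor/memdesc types, and the element type are assumed to come from the surrounding conversion code, much like the CoalesceAsyncCopy.cpp change below):

    // Map the register layout onto the shared-memory layout and ask for the
    // longest contiguous run of elements that can be copied in one go.
    std::optional<LinearLayout> regToShared = getRegToSharedLayout(
        ctx, tensorTy.getShape(), tensorTy.getEncoding(),
        memDescTy.getEncoding(), elemTy.getIntOrFloatBitWidth());
    if (!regToShared.has_value())
      return false;                                    // layout pair not supported
    int contig = regToShared->getNumConsecutiveInOut(); // max contiguous elements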
64 changes: 38 additions & 26 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -158,36 +158,25 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   return ret;
 }
 
-bool emitTransferBetweenRegistersAndShared(
-    RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, Value shmemBase,
-    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  MLIRContext *ctx = rewriter.getContext();
-
-  auto shape = registerTy.getShape();
-  int rank = shape.size();
-
+std::optional<LinearLayout> getRegToSharedLayout(MLIRContext* ctx,
+    ArrayRef<int64_t> shape, Attribute srcEnc, Attribute dstEnc, int elemBitWidth) {
   StringAttr kBlock = str_attr("block");
-  StringAttr kRegister = str_attr("register");
-  StringAttr kLane = str_attr("lane");
-  StringAttr kWarp = str_attr("warp");
+  int rank = shape.size();
 
   std::optional<LinearLayout> regLayout =
-      triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
+      triton::gpu::toLinearLayout(shape, srcEnc);
   std::optional<LinearLayout> sharedLayout = triton::gpu::toLinearLayout(
-      shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
+      shape, dstEnc, elemBitWidth);
   if (!regLayout.has_value() || !sharedLayout.has_value()) {
-    return false;
+    return std::nullopt;
   }
-  auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
+  auto sharedOrder = triton::gpu::getOrder(dstEnc);
 
   // sharedLayout's in-dims are currently (offset, block). Reshape to
   // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
   // shmem strides. (The offsetX's appear in minor-to-major order.)
   auto sharedLegacy =
-      cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
+      cast<triton::gpu::SharedEncodingAttr>(dstEnc);
   SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
   for (int i = 0; i < rank; i++) {
     int dim = sharedOrder[i];
@@ -202,13 +191,35 @@ bool emitTransferBetweenRegistersAndShared(
 
   // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
   // ..., offsetXN, block), where the offsetX's are in minor-to-major order.
-  LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout);
+  return regLayout->invertAndCompose(*sharedLayout);
+}
+
+bool emitTransferBetweenRegistersAndShared(
+    RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
+    std::optional<int32_t> maxVecElems, Value shmemBase,
+    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
+    const TargetInfoBase &target,
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
+  MLIRContext *ctx = rewriter.getContext();
+
+  auto shape = registerTy.getShape();
+  int rank = shape.size();
+
+  StringAttr kBlock = str_attr("block");
+  StringAttr kRegister = str_attr("register");
+  StringAttr kLane = str_attr("lane");
+  StringAttr kWarp = str_attr("warp");
+
+  auto regToSharedLayout = getRegToSharedLayout(ctx, shape, registerTy.getEncoding(),
+      sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
+  if (!regToSharedLayout.has_value())
+    return false;
 
   // TODO(jlebar): We don't currently support loading from shared memory in a
   // different CTA. We'd need to emit `mapa.shared::cluster` instructions.
-  for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
+  for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock);
        inBlock *= 2) {
-    auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
+    auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply(
         {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
     // offsetX1, ..., offsetXN must all be 0.
     if (!llvm::all_of(ArrayRef(idx).drop_back(1),
@@ -234,15 +245,15 @@ bool emitTransferBetweenRegistersAndShared(
   // which have known strides. This would allow us to vectorize across multiple
   // shmem out dimensions where possible.
   const int vecElems =
-      std::min(regToSharedLayout.getNumConsecutiveInOut(),
+      std::min(regToSharedLayout->getNumConsecutiveInOut(),
               maxVecElems.value_or(std::numeric_limits<int>::max()));
 
   Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
+  Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
   Value laneId = urem(threadId, threadsPerWarp);
   Value warpId = udiv(threadId, threadsPerWarp);
 
-  int numElems = regToSharedLayout.getInDimSize(kRegister);
+  int numElems = regToSharedLayout->getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
   auto ptrTy = shmemBase.getType();
   Value zero = i32_val(0);
@@ -253,14 +264,15 @@
     // we drop_end to drop block, which we know from above will be 0.
     auto multiDimShmemOffset =
         llvm::to_vector(llvm::drop_end(llvm::make_second_range(
-            applyLinearLayout(loc, rewriter, regToSharedLayout,
+            applyLinearLayout(loc, rewriter, *regToSharedLayout,
                               {{kRegister, i32_val(i * vecElems)},
                                {kLane, laneId},
                                {kWarp, warpId},
                                {kBlock, zero}}))));
 
     // Reorder strides according to `order`. This way they match the
     // multi-dimensional offsets in regToSharedLayout.
+    auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
     Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
                             applyPermutation(shmemStrides, sharedOrder));
     auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
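The net effect of the Utility.cpp change: the layout computation now lives in a reusable helper that returns std::optional, and emitTransferBetweenRegistersAndShared bails out early when it is empty, with every later use switching from `.` to `->`. A standalone sketch of that pattern (plain C++ with hypothetical names, not the Triton API):

    #include <cstdio>
    #include <optional>

    struct Layout {
      int numConsecutive() const { return 4; }
    };

    // The helper returns std::nullopt instead of the caller returning false inline.
    std::optional<Layout> computeLayout(bool supported) {
      if (!supported)
        return std::nullopt;
      return Layout{};
    }

    bool emitTransfer(bool supported) {
      auto layout = computeLayout(supported);
      if (!layout.has_value())
        return false;
      // Remaining uses go through operator-> as in the diff above.
      std::printf("vec elems = %d\n", layout->numConsecutive());
      return true;
    }

    int main() { return emitTransfer(true) ? 0 : 1; }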
36 changes: 23 additions & 13 deletions lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp
@@ -4,6 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
 #include <memory>
 
@@ -44,29 +45,38 @@ struct ClipAsyncCopySizePerThread
     Value src = copyOp.getSrc();
     Value mask = copyOp.getMask();
     Value other = copyOp.getOther();
 
-    auto inputTy = cast<RankedTensorType>(src.getType());
-    auto blockEnc = cast<BlockedEncodingAttr>(inputTy.getEncoding());
-    auto resultTy = cast<tt::MemDescType>(copyOp.getResult().getType());
-    auto sharedEnc = cast<SharedEncodingAttr>(resultTy.getEncoding());
+    auto srcTy = cast<RankedTensorType>(src.getType());
+    auto blockEnc = cast<BlockedEncodingAttr>(srcTy.getEncoding());
+    auto dstTy = cast<tt::MemDescType>(copyOp.getResult().getType());
+    auto sharedEnc = cast<SharedEncodingAttr>(dstTy.getEncoding());
-    auto sharedVec = sharedEnc.getVec();
 
-    // clip each dim of sizePerThread by its respective dim in vec
-    SmallVector<unsigned> newSizePerThread;
-    llvm::transform(blockEnc.getSizePerThread(),
-                    std::back_inserter(newSizePerThread),
-                    [&](auto size) { return std::min(size, sharedVec); });
+    // obtain max contiguous copy size
+    // Note this can be further optimized, as copyContigSize can be even
+    // smaller when lowering, depending on contiguity and mask alignment
+    // (see AsyncCopyGlobalToLocalOpConversion)
+    auto elemBitWidth = dstTy.getElementTypeBitWidth();
+    auto regToSharedLayout = getRegToSharedLayout(rewriter.getContext(),
+        srcTy.getShape(), blockEnc, sharedEnc, elemBitWidth);
+    auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut();
+
+    // obtain block sizePerThread along contig dim
+    auto sizePerThread = blockEnc.getSizePerThread();
+    auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]];
 
-    if (newSizePerThread == blockEnc.getSizePerThread())
+    if (blockContigSize <= copyContigSize)
       return rewriter.notifyMatchFailure(copyOp,
-          "at least one dimension of blocked sizePerThread must be greater than shared vec");
+          "blocked sizePerThread along contiguous dim must be greater than the "
+          "max contiguous copy size ");
 
+    sizePerThread[blockEnc.getOrder()[0]] = copyContigSize;
+
     // obtain new blockedEnc based on clipped sizePerThread
     auto mod = copyOp->getParentOfType<ModuleOp>();
     int numWarps = TritonGPUDialect::getNumWarps(mod);
     int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
     auto newBlockEnc = BlockedEncodingAttr::get(
-        copyOp.getContext(), inputTy.getShape(), newSizePerThread,
+        copyOp.getContext(), srcTy.getShape(), sizePerThread,
         blockEnc.getOrder(), numWarps, threadsPerWarp,
         blockEnc.getCTALayout());
 
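To make the new clipping rule concrete, a standalone numeric sketch (plain C++ with made-up values, not the pass itself): a blocked encoding with sizePerThread = {1, 8} whose contiguous dimension is order[0] = 1, copying to shared memory where at most 4 elements can be written contiguously, is rewritten to sizePerThread = {1, 4}.

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<unsigned> sizePerThread = {1, 8}; // hypothetical blocked encoding
      unsigned contigDim = 1;                       // order[0]: fastest-varying dim
      unsigned copyContigSize = 4;                  // e.g. getNumConsecutiveInOut()

      if (sizePerThread[contigDim] <= copyContigSize) {
        std::puts("pattern does not apply");        // notifyMatchFailure case
      } else {
        sizePerThread[contigDim] = copyContigSize;  // clip along the contig dim
        std::printf("new sizePerThread = {%u, %u}\n",
                    sizePerThread[0], sizePerThread[1]); // prints {1, 4}
      }
      return 0;
    }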
