Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WSLowering] consumer release on each thread instead of master thread #10

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 7 additions & 59 deletions lib/Dialect/TritonGPU/Transforms/WSLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,44 +149,8 @@ void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op,
Value bufferEmpty, int numCTAs) {
auto loc = op.getLoc();
Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32);
Value _4 = builder.create<arith::ConstantIntOp>(loc, 4, 32);
Value _8 = builder.create<arith::ConstantIntOp>(loc, 8, 32);
Value _32 = builder.create<arith::ConstantIntOp>(loc, 32, 32);
Value _threadPerTask =
builder.create<arith::ConstantIntOp>(loc, THREADS_PER_TASK, 32);

// threadId = threadId % THREADS_PER_TASK
Value threadId = builder.create<arith::RemUIOp>(
loc, createThreadIdOp(builder, loc), _threadPerTask);
// k = threadId / 8
Value k = builder.create<arith::DivUIOp>(loc, threadId, _8);
// row = k / 4
Value row = builder.create<arith::DivUIOp>(loc, k, _4);
// col = k % 4
Value col = builder.create<arith::RemUIOp>(loc, k, _4);
// remoteCTAId = (col ^ row) * 4 + col
Value remoteCTAId = builder.create<arith::AddIOp>(
loc,
Value{builder.create<arith::MulIOp>(
loc, Value{builder.create<arith::XOrIOp>(loc, col, row)}, _4)},
col);

// pred0 = threadId % 8 == 0
Value pred0 = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq,
builder.create<arith::RemUIOp>(loc, threadId, _8), _0);
// pred1 = remoteCTAId < numCTAs
Value pred1 = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, remoteCTAId,
builder.create<arith::ConstantIntOp>(loc, numCTAs, 32));

// pred = pred0 & pred1
Value pred = builder.create<arith::AndIOp>(loc, pred0, pred1);
// bufferEmpty arrive
auto arriveOp = builder.create<ttng::MBarrierArriveOp>(loc, bufferEmpty, pred,
remoteCTAId, false, 0);

auto arriveOp = builder.create<ttng::MBarrierArriveOp>(
loc, bufferEmpty, nullptr, nullptr, false, 0);
assert(op.getOperation()->hasAttr("async_task_id"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't quite get this logic around remote-cta. It seems this change gets rid of the remote-cta mode.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this gets rid of remote-cta. Do you think it is still useful? I'm not sure how it would be used.

setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation()));
}
Expand Down Expand Up @@ -230,17 +194,12 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs,

Value barrierEmptyView = builder.create<ttg::MemDescSubviewOp>(
loc, singleBarrierMemDescType, bufferEmptyArray, idx);
unsigned bufferEmptyCount = numCTAs;
builder.create<ttng::InitBarrierOp>(loc, barrierEmptyView, numCTAs);
builder.create<ttng::InitBarrierOp>(loc, barrierEmptyView,
THREADS_PER_TASK);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this changes from a barrier across CTAs to a barrier within the warp group of 128 threads?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes from expecting a master thread form a WG running the barrier arrival to all threads within the WG running it.

htyu marked this conversation as resolved.
Show resolved Hide resolved
}

if (numCTAs == 1) {
builder.create<mlir::gpu::BarrierOp>(loc);
} else {
// Make sure that MBarriers are initialized in all CTAs.
builder.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
builder.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
}
assert(numCTAs == 1 && "remote CTA is not supported yet");
builder.create<mlir::gpu::BarrierOp>(loc);

// Helper function for extracting one index from bufferFullArray.
auto extractBufferFull = [&](Location loc, Value idx) -> Value {
Expand Down Expand Up @@ -291,18 +250,7 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs,
op->erase();
}

// Insert a cluster barrier before the kernel exits. Without this barrier,
// mbarrier_remote_arrive will fail if the remote CTA already exits.
if (numCTAs > 1) {
parentOp->walk([&](triton::FuncOp funcOp) {
Block *block = &funcOp.getBody().front();
auto returnOp = llvm::cast<triton::ReturnOp>(block->getTerminator());
OpBuilder builder(returnOp);
auto loc = returnOp.getLoc();
builder.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
builder.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
});
}
assert(numCTAs == 1 && "remote CTA is not supported yet");
}

#define GEN_PASS_DEF_TRITONGPUWSLOWERING
Expand Down
Loading