-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WSLowering] consumer release on each thread instead of master thread #10
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -149,44 +149,8 @@ void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op, | |
void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op, | ||
Value bufferEmpty, int numCTAs) { | ||
auto loc = op.getLoc(); | ||
Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32); | ||
Value _4 = builder.create<arith::ConstantIntOp>(loc, 4, 32); | ||
Value _8 = builder.create<arith::ConstantIntOp>(loc, 8, 32); | ||
Value _32 = builder.create<arith::ConstantIntOp>(loc, 32, 32); | ||
Value _threadPerTask = | ||
builder.create<arith::ConstantIntOp>(loc, THREADS_PER_TASK, 32); | ||
|
||
// threadId = threadId % THREADS_PER_TASK | ||
Value threadId = builder.create<arith::RemUIOp>( | ||
loc, createThreadIdOp(builder, loc), _threadPerTask); | ||
// k = threadId / 8 | ||
Value k = builder.create<arith::DivUIOp>(loc, threadId, _8); | ||
// row = k / 4 | ||
Value row = builder.create<arith::DivUIOp>(loc, k, _4); | ||
// col = k % 4 | ||
Value col = builder.create<arith::RemUIOp>(loc, k, _4); | ||
// remoteCTAId = (col ^ row) * 4 + col | ||
Value remoteCTAId = builder.create<arith::AddIOp>( | ||
loc, | ||
Value{builder.create<arith::MulIOp>( | ||
loc, Value{builder.create<arith::XOrIOp>(loc, col, row)}, _4)}, | ||
col); | ||
|
||
// pred0 = threadId % 8 == 0 | ||
Value pred0 = builder.create<arith::CmpIOp>( | ||
loc, arith::CmpIPredicate::eq, | ||
builder.create<arith::RemUIOp>(loc, threadId, _8), _0); | ||
// pred1 = remoteCTAId < numCTAs | ||
Value pred1 = builder.create<arith::CmpIOp>( | ||
loc, arith::CmpIPredicate::ult, remoteCTAId, | ||
builder.create<arith::ConstantIntOp>(loc, numCTAs, 32)); | ||
|
||
// pred = pred0 & pred1 | ||
Value pred = builder.create<arith::AndIOp>(loc, pred0, pred1); | ||
// bufferEmpty arrive | ||
auto arriveOp = builder.create<ttng::MBarrierArriveOp>(loc, bufferEmpty, pred, | ||
remoteCTAId, false, 0); | ||
|
||
auto arriveOp = builder.create<ttng::MBarrierArriveOp>( | ||
loc, bufferEmpty, nullptr, nullptr, false, 0); | ||
assert(op.getOperation()->hasAttr("async_task_id")); | ||
setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); | ||
} | ||
|
@@ -230,17 +194,12 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, | |
|
||
Value barrierEmptyView = builder.create<ttg::MemDescSubviewOp>( | ||
loc, singleBarrierMemDescType, bufferEmptyArray, idx); | ||
unsigned bufferEmptyCount = numCTAs; | ||
builder.create<ttng::InitBarrierOp>(loc, barrierEmptyView, numCTAs); | ||
builder.create<ttng::InitBarrierOp>(loc, barrierEmptyView, | ||
THREADS_PER_TASK); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this changes from a barrier across CTAs to a barrier within the warp group of 128 threads? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This changes from expecting a master thread form a WG running the barrier arrival to all threads within the WG running it.
htyu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
if (numCTAs == 1) { | ||
builder.create<mlir::gpu::BarrierOp>(loc); | ||
} else { | ||
// Make sure that MBarriers are initialized in all CTAs. | ||
builder.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false); | ||
builder.create<triton::nvidia_gpu::ClusterWaitOp>(loc); | ||
} | ||
assert(numCTAs == 1 && "remote CTA is not supported yet"); | ||
builder.create<mlir::gpu::BarrierOp>(loc); | ||
|
||
// Helper function for extracting one index from bufferFullArray. | ||
auto extractBufferFull = [&](Location loc, Value idx) -> Value { | ||
|
@@ -291,18 +250,7 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, | |
op->erase(); | ||
} | ||
|
||
// Insert a cluster barrier before the kernel exits. Without this barrier, | ||
// mbarrier_remote_arrive will fail if the remote CTA already exits. | ||
if (numCTAs > 1) { | ||
parentOp->walk([&](triton::FuncOp funcOp) { | ||
Block *block = &funcOp.getBody().front(); | ||
auto returnOp = llvm::cast<triton::ReturnOp>(block->getTerminator()); | ||
OpBuilder builder(returnOp); | ||
auto loc = returnOp.getLoc(); | ||
builder.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false); | ||
builder.create<triton::nvidia_gpu::ClusterWaitOp>(loc); | ||
}); | ||
} | ||
assert(numCTAs == 1 && "remote CTA is not supported yet"); | ||
} | ||
|
||
#define GEN_PASS_DEF_TRITONGPUWSLOWERING | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually don't quite get this logic around remote-cta. It seems this change gets rid of the remote-cta mode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this gets rid of remote-cta. Do you think it is still useful? I'm not sure how it would be used.