From c077ab6558c91327afde907f5fdb58f9c286f2e7 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Sat, 14 Dec 2024 00:59:27 -0800 Subject: [PATCH 1/2] optimize consumer release --- .../TritonGPU/Transforms/WSLowering.cpp | 44 ++----------------- 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp index c2bf31fc5..adc009954 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp @@ -149,44 +149,8 @@ void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op, void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op, Value bufferEmpty, int numCTAs) { auto loc = op.getLoc(); - Value _0 = builder.create(loc, 0, 32); - Value _4 = builder.create(loc, 4, 32); - Value _8 = builder.create(loc, 8, 32); - Value _32 = builder.create(loc, 32, 32); - Value _threadPerTask = - builder.create(loc, THREADS_PER_TASK, 32); - - // threadId = threadId % THREADS_PER_TASK - Value threadId = builder.create( - loc, createThreadIdOp(builder, loc), _threadPerTask); - // k = threadId / 8 - Value k = builder.create(loc, threadId, _8); - // row = k / 4 - Value row = builder.create(loc, k, _4); - // col = k % 4 - Value col = builder.create(loc, k, _4); - // remoteCTAId = (col ^ row) * 4 + col - Value remoteCTAId = builder.create( - loc, - Value{builder.create( - loc, Value{builder.create(loc, col, row)}, _4)}, - col); - - // pred0 = threadId % 8 == 0 - Value pred0 = builder.create( - loc, arith::CmpIPredicate::eq, - builder.create(loc, threadId, _8), _0); - // pred1 = remoteCTAId < numCTAs - Value pred1 = builder.create( - loc, arith::CmpIPredicate::ult, remoteCTAId, - builder.create(loc, numCTAs, 32)); - - // pred = pred0 & pred1 - Value pred = builder.create(loc, pred0, pred1); - // bufferEmpty arrive - auto arriveOp = builder.create(loc, bufferEmpty, pred, - remoteCTAId, false, 0); - + auto arriveOp = builder.create( + loc, bufferEmpty, nullptr, nullptr, false, 0); assert(op.getOperation()->hasAttr("async_task_id")); setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); } @@ -230,8 +194,8 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, Value barrierEmptyView = builder.create( loc, singleBarrierMemDescType, bufferEmptyArray, idx); - unsigned bufferEmptyCount = numCTAs; - builder.create(loc, barrierEmptyView, numCTAs); + builder.create(loc, barrierEmptyView, + THREADS_PER_TASK); } if (numCTAs == 1) { From c2c95c8cdd48d2c24cae6dee48add8555b270fbe Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Thu, 9 Jan 2025 09:55:50 -0800 Subject: [PATCH 2/2] Remove remote CTA support --- .../TritonGPU/Transforms/WSLowering.cpp | 22 +++---------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp index adc009954..d8b9c76f5 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp @@ -198,13 +198,8 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, THREADS_PER_TASK); } - if (numCTAs == 1) { - builder.create(loc); - } else { - // Make sure that MBarriers are initialized in all CTAs. - builder.create(loc, false); - builder.create(loc); - } + assert(numCTAs == 1 && "remote CTA is not supported yet"); + builder.create(loc); // Helper function for extracting one index from bufferFullArray. auto extractBufferFull = [&](Location loc, Value idx) -> Value { @@ -255,18 +250,7 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, op->erase(); } - // Insert a cluster barrier before the kernel exits. Without this barrier, - // mbarrier_remote_arrive will fail if the remote CTA already exits. - if (numCTAs > 1) { - parentOp->walk([&](triton::FuncOp funcOp) { - Block *block = &funcOp.getBody().front(); - auto returnOp = llvm::cast(block->getTerminator()); - OpBuilder builder(returnOp); - auto loc = returnOp.getLoc(); - builder.create(loc, false); - builder.create(loc); - }); - } + assert(numCTAs == 1 && "remote CTA is not supported yet"); } #define GEN_PASS_DEF_TRITONGPUWSLOWERING