Skip to content

Commit

Permalink
Forward fixes to build on newer version of llvm (triton-lang#2388)
Browse files Browse the repository at this point in the history
  • Loading branch information
vwbaker authored Sep 26, 2023
1 parent 7432fff commit 2d28b09
Showing 5 changed files with 19 additions and 18 deletions.
13 changes: 6 additions & 7 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
@@ -761,16 +761,15 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter,

// Create a new loop before the existing one, with the extra operands.
rewriter.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getIterOperands());
auto operands = llvm::to_vector<4>(loop.getInitArgs());
operands.append(newIterOperands.begin(), newIterOperands.end());
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
operands);
newLoop.getBody()->erase();

newLoop.getLoopBody().getBlocks().splice(
newLoop.getLoopBody().getBlocks().begin(),
loop.getLoopBody().getBlocks());
newLoop.getRegion().getBlocks().splice(
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
for (Value operand : newIterOperands)
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

@@ -805,9 +804,9 @@ static void rewriteSlice(SetVector<Value> &slice,
for (auto arg : forOp.getRegionIterArgs()) {
if (slice.count(arg)) {
OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg);
argMapping.push_back(
std::make_pair(*forOp.getIterArgNumberForOpOperand(initVal),
forOp.getNumIterOperands() + newOperands.size()));
argMapping.push_back(std::make_pair(
forOp.getResultForOpOperand(initVal).getResultNumber(),
forOp.getInitArgs().size() + newOperands.size()));
newOperands.push_back(mapping.lookup(initVal.get()));
}
}
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
@@ -613,7 +613,7 @@ struct ForOpDeadArgElimination : public OpRewritePattern<scf::ForOp> {
Value yieldOperand =
forOwner.getBody()->getTerminator()->getOperand(iterIdx);
markLive(yieldOperand);
markLive(forOwner.getIterOperands()[iterIdx]);
markLive(forOwner.getInitArgs()[iterIdx]);
}
}
SmallVector<unsigned> deadArg;
14 changes: 7 additions & 7 deletions lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp
Original file line number Diff line number Diff line change
@@ -523,9 +523,9 @@ class TritonGPURewriteTensorPointerPass
std::stack<Operation *> &eraser,
DenseSet<Value> &valueToRemove) {
// Generate new iteration operands and set rewritten information
SmallVector<Value> oldIterOperands = op.getIterOperands();
SmallVector<Value> newIterOperands = op.getIterOperands();
for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size;
SmallVector<Value> oldIterOperands = llvm::to_vector(op.getInitArgs());
SmallVector<Value> newIterOperands = llvm::to_vector(op.getInitArgs());
for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size;
++i, ++oldI) {
if (!tt::isTensorPointerType(newIterOperands[i].getType()))
continue;
@@ -550,7 +550,7 @@ class TritonGPURewriteTensorPointerPass
// mapping. It may refer to a value in the old loop, but we will rewrite it
// later
IRMapping mapping;
for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands();
for (unsigned i = 0, oldI = 0; oldI < op.getInitArgs().size();
++i, ++oldI) {
auto oldRegionIterArg = op.getRegionIterArg(oldI);
if (tt::isTensorPointerType(oldRegionIterArg.getType()) &&
@@ -586,7 +586,7 @@ class TritonGPURewriteTensorPointerPass
valueToRemove.insert(v);

// Replace later usages
assert(op.getNumResults() == op.getNumIterOperands());
assert(op.getNumResults() == op.getInitArgs().size());
for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) {
auto oldResult = op.getResult(oldI);
if (tt::isTensorPointerType(oldResult.getType()) &&
@@ -787,8 +787,8 @@ class TritonGPURewriteTensorPointerPass
}
}
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
SmallVector<Value> iterOperands = forOp.getIterOperands();
for (unsigned i = 0, size = forOp.getNumIterOperands(); i < size; ++i) {
SmallVector<Value> iterOperands = llvm::to_vector(forOp.getInitArgs());
for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) {
if (tt::isTensorPointerType(iterOperands[i].getType())) {
auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]);
if (shouldRemove(makeTensorPtrOp, computeCapability))
4 changes: 3 additions & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp
Original file line number Diff line number Diff line change
@@ -153,7 +153,9 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,

Value newIdx =
builder.createWithAgentIds<arith::AddIOp>(loc, pipelineIdx, curRoleId);
persistentForOp.setIterArg(persistentForOp.getNumIterOperands() - 1, newIdx);
persistentForOp.getInitArgsMutable()
.slice(persistentForOp.getInitArgs().size() - 1, 1)
.assign(newIdx);
auto yield =
llvm::cast<scf::YieldOp>(persistentForOp.getBody()->getTerminator());
auto idxPlusOneOp =
4 changes: 2 additions & 2 deletions lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp
Original file line number Diff line number Diff line change
@@ -162,7 +162,7 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,

// Copy iter operands of forOp
SmallVector<Value> newLoopArgs;
for (auto operand : forOp.getIterOperands())
for (auto operand : llvm::to_vector(forOp.getInitArgs()))
newLoopArgs.push_back(operand);

// Append initial value of pipelineIdx to newLoopArgs
@@ -302,7 +302,7 @@ DenseMap<AgentId, scf::ForOp> createForOpsForEachAgentId(scf::ForOp forOp) {
// Prepare newLoopArgs
SmallVector<Value> newLoopArgs;
for (unsigned argNumber : usedArgs)
newLoopArgs.push_back(forOp.getIterOperands()[argNumber]);
newLoopArgs.push_back(forOp.getInitArgs()[argNumber]);

// Create newForOp
builder.setAgentIdsFromArray({agentId});

0 comments on commit 2d28b09

Please sign in to comment.