-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
333 additions
and
2 deletions.
There are no files selected for viewing
169 changes: 169 additions & 0 deletions
169
compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIENoAliasFunctionArguments.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
// Copyright 2025 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree-amd-aie/IR/AMDAIEDialect.h" | ||
#include "iree-amd-aie/IR/AMDAIEOps.h" | ||
#include "iree-amd-aie/Transforms/Passes.h" | ||
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h" | ||
#include "iree-amd-aie/aie_runtime/AMDAIEEnums.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
|
||
#define DEBUG_TYPE "iree-amdaie-no-alias-function-arguments" | ||
|
||
namespace mlir::iree_compiler::AMDAIE { | ||
|
||
namespace { | ||
|
||
/// Collect every `func.func` nested under `rootOp`, each paired with the
/// `func.call` ops (also nested under `rootOp`) whose callee symbol name
/// equals that function's symbol name. Functions appear in walk order. Calls
/// whose callee does not name any collected function are not recorded.
SmallVector<std::pair<func::FuncOp, SmallVector<func::CallOp>>>
getFunctionsAndTheirCallers(Operation *rootOp) {
  using FuncAndCallers = std::pair<func::FuncOp, SmallVector<func::CallOp>>;
  SmallVector<FuncAndCallers> result;

  // Symbol name -> position of the corresponding entry in `result`.
  DenseMap<StringRef, uint32_t> positionOf;

  // First walk: register each function with an (initially empty) caller list.
  rootOp->walk([&](func::FuncOp fn) {
    uint32_t position = result.size();
    positionOf.insert({fn.getSymName(), position});
    result.emplace_back(fn, SmallVector<func::CallOp>{});
  });

  // Second walk: attach each call to the function it names, when known.
  rootOp->walk([&](func::CallOp call) {
    auto found = positionOf.find(call.getCallee());
    if (found == positionOf.end()) return;
    result[found->second].second.push_back(call);
  });

  return result;
}
|
||
/// Starting at `initial`, repeatedly step to the op that defines the single
/// operand of the current op, and return the first amdaie.buffer or
/// memref.alloc op reached (which may be `initial` itself).
///
/// An error is emitted on `initial` and failure is returned if either
///  - an op whose operand count is not exactly one is reached before an
///    allocation is found (a note with that op is attached), or
///  - the chain ends at a value that has no defining op.
FailureOr<Operation *> getDefiningAllocation(Operation *initial) {
  // The loop increment is only reached via `continue`, i.e. when `cursor` is
  // known to have exactly one operand, so `getOperand(0)` is always valid.
  for (Operation *cursor = initial; cursor != nullptr;
       cursor = cursor->getOperand(0).getDefiningOp()) {
    if (isa<AMDAIE::BufferOp, memref::AllocOp>(cursor)) return cursor;

    unsigned operandCount = cursor->getNumOperands();
    if (operandCount == 1) continue;

    InFlightDiagnostic diag = initial->emitOpError();
    diag << "could not be traced back to an allocation operation, "
         << "an operation with " << operandCount
         << " operands was encountered while traversing defining ops.";
    diag.attachNote() << "The operation with multiple operands is: "
                      << *cursor;
    return diag;
  }

  return initial->emitOpError()
         << "could not be traced back to an allocation operation.";
}
|
||
/// For each operand of `callOp`, compute a flag that is true iff the operand
/// is a memref whose defining allocation (as found by
/// `getDefiningAllocation`) is not shared with any other memref operand of the
/// same call. Non-memref operands always get `false`.
///
/// Fails (after emitting an error) if a memref operand has no defining op, or
/// if its defining allocation cannot be determined.
FailureOr<SmallVector<bool>> nonAliasingMemrefArguments(func::CallOp callOp) {
  // One entry per memref operand: (allocation it derives from, operand index).
  SmallVector<std::pair<Operation *, uint32_t>> roots;
  const uint32_t operandCount = callOp.getNumOperands();
  for (uint32_t position = 0; position < operandCount; ++position) {
    Value operand = callOp->getOperand(position);
    if (!isa<MemRefType>(operand.getType())) continue;

    Operation *producer = operand.getDefiningOp();
    if (producer == nullptr) {
      return callOp->emitOpError(
          "has an operand with no defining op, failed to find allocation");
    }

    FailureOr<Operation *> maybeRoot = getDefiningAllocation(producer);
    if (failed(maybeRoot)) return failure();
    roots.emplace_back(maybeRoot.value(), position);
  }

  // An operand aliases if a *different* operand traces to the same allocation.
  SmallVector<bool> isNonAliasing(operandCount, false);
  for (const std::pair<Operation *, uint32_t> &entry : roots) {
    bool sharesRoot = false;
    for (const std::pair<Operation *, uint32_t> &other : roots) {
      sharesRoot |=
          (other.second != entry.second && other.first == entry.first);
    }
    isNonAliasing[entry.second] = !sharesRoot;
  }
  return isNonAliasing;
}
|
||
class AMDAIENoAliasFunctionArgumentsPass | ||
: public impl::AMDAIENoAliasFunctionArgumentsBase< | ||
AMDAIENoAliasFunctionArgumentsPass> { | ||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<AMDAIEDialect>(); | ||
} | ||
AMDAIENoAliasFunctionArgumentsPass() = default; | ||
AMDAIENoAliasFunctionArgumentsPass( | ||
const AMDAIENoAliasFunctionArgumentsPass &pass){}; | ||
void runOnOperation() override; | ||
}; | ||
|
||
void AMDAIENoAliasFunctionArgumentsPass::runOnOperation() { | ||
Operation *op = getOperation(); | ||
|
||
// TODO(newling): resolve numerical issues on strix. | ||
{ | ||
std::optional<AMDAIEDevice> device = getConfigAMDAIEDeviceFromAncestor(op); | ||
if (device.has_value() && isAie2P(device.value())) return; | ||
} | ||
|
||
IRRewriter rewriter(op); | ||
auto functionsAndCallers = getFunctionsAndTheirCallers(op); | ||
for (auto [func, callers] : functionsAndCallers) { | ||
auto numOperands = func.getNumArguments(); | ||
SmallVector<bool> nonAliasingMemref(numOperands, true); | ||
for (auto caller : callers) { | ||
assert(numOperands == caller.getNumOperands() && | ||
"Number of operands in caller and callee do not match"); | ||
FailureOr<SmallVector<bool>> maybeNonAliasingArguments = | ||
nonAliasingMemrefArguments(caller); | ||
if (failed(maybeNonAliasingArguments)) { | ||
return signalPassFailure(); | ||
} | ||
SmallVector<bool> nonAliasings = maybeNonAliasingArguments.value(); | ||
for (uint32_t i = 0; i < nonAliasingMemref.size(); ++i) { | ||
nonAliasingMemref[i] = nonAliasingMemref[i] && nonAliasings[i]; | ||
} | ||
} | ||
|
||
StringRef noAliasAttrName = LLVM::LLVMDialect::getNoAliasAttrName(); | ||
ArrayRef<BlockArgument> args = func.getArguments(); | ||
for (auto iter : llvm::enumerate(args)) { | ||
if (nonAliasingMemref[iter.index()]) { | ||
func.setArgAttr(iter.index(), noAliasAttrName, rewriter.getUnitAttr()); | ||
} | ||
} | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<Pass> createAMDAIENoAliasFunctionArgumentsPass() { | ||
return std::make_unique<AMDAIENoAliasFunctionArgumentsPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler::AMDAIE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
137 changes: 137 additions & 0 deletions
137
...iler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/no_alias_function_arguments.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-no-alias-function-arguments)" %s | FileCheck %s | ||
|
||
|
||
// ----- | ||
|
||
// check the most basic case: a single argument should always be 'noalias'. | ||
module { | ||
// CHECK: func.func @unary(%arg0: memref<8xi8, 2 : i32> {llvm.noalias}) { | ||
func.func @unary(%arg0: memref<8xi8, 2 : i32>){ | ||
return | ||
} | ||
func.func @main(){ | ||
%c2 = arith.constant 2 : index | ||
%tile_2_2 = amdaie.tile(%c2, %c2) | ||
%buffer = amdaie.buffer(%tile_2_2) : memref<8xi8, 2 : i32> | ||
call @unary(%buffer) : (memref<8xi8, 2 : i32>) -> () | ||
return | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
|
||
// check that non-memref arguments are ignored. | ||
module { | ||
// CHECK: func.func @unary_with_index(%arg0: memref<8xi8, 2 : i32> {llvm.noalias}, %arg1: index) { | ||
func.func @unary_with_index(%arg0: memref<8xi8, 2 : i32>, %arg1: index){ | ||
return | ||
} | ||
func.func @main(){ | ||
%c2 = arith.constant 2 : index | ||
%tile_2_2 = amdaie.tile(%c2, %c2) | ||
%buffer = amdaie.buffer(%tile_2_2) : memref<8xi8, 2 : i32> | ||
call @unary_with_index(%buffer, %c2) : (memref<8xi8, 2 : i32>, index) -> () | ||
return | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
// check that the caller operand needn't be a buffer/allocation, that | ||
// back-tracking to the allocation/buffer works. | ||
module { | ||
// CHECK: func.func @unary_with_chain(%arg0: memref<8xi8> {llvm.noalias}) { | ||
func.func @unary_with_chain(%arg0: memref<8xi8>){ | ||
return | ||
} | ||
func.func @main(){ | ||
%c2 = arith.constant 2 : index | ||
%alloc = memref.alloc() : memref<8x1x1xi8> | ||
%collapse = memref.collapse_shape %alloc [[0,1],[2]] : memref<8x1x1xi8> into memref<8x1xi8> | ||
%reinterpret = memref.reinterpret_cast %collapse to offset: [0], sizes: [8], strides: [1] : memref<8x1xi8> to memref<8xi8> | ||
call @unary_with_chain(%reinterpret) : (memref<8xi8>) -> () | ||
return | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
// check a basic case for a function with 2 arguments that do not alias. | ||
module { | ||
// CHECK: func.func @binary(%arg0: memref<8xi8> {llvm.noalias}, %arg1: memref<8xi8> {llvm.noalias}) { | ||
func.func @binary(%arg0: memref<8xi8>, %arg1: memref<8xi8>){ | ||
return | ||
} | ||
func.func @main(){ | ||
%alloc = memref.alloc() : memref<8xi8> | ||
%alloc_0 = memref.alloc() : memref<8xi8> | ||
call @binary(%alloc, %alloc_0) : (memref<8xi8>, memref<8xi8>) -> () | ||
return | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
|
||
// Check a case where two of the operands are aliases at one of the call sites. | ||
module { | ||
// CHECK: func.func @ternary_multicall(%arg0: memref<8xi8>, %arg1: memref<8xi8> {llvm.noalias}, %arg2: memref<8xi8>) { | ||
func.func @ternary_multicall(%arg0: memref<8xi8>, %arg1: memref<8xi8>, %arg2: memref<8xi8>){ | ||
return | ||
} | ||
func.func @main(){ | ||
%alloc = memref.alloc() : memref<8xi8> | ||
%alloc_0 = memref.alloc() : memref<8xi8> | ||
%alloc_1 = memref.alloc() : memref<8xi8> | ||
// This call has no aliasing operands. | ||
call @ternary_multicall(%alloc, %alloc_0, %alloc_1) : (memref<8xi8>, memref<8xi8>, memref<8xi8>) -> () | ||
// But this call has operand #0 and #2 aliased. | ||
call @ternary_multicall(%alloc_0, %alloc_1, %alloc_0) : (memref<8xi8>, memref<8xi8>, memref<8xi8>) -> () | ||
return | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
|
||
module { | ||
// CHECK: func.func @soak(%arg0: memref<8xi8>, %arg1: index, %arg2: memref<8xi8>, %arg3: memref<8xi8>, %arg4: memref<3xi32> {llvm.noalias}) { | ||
func.func @soak(%arg0: memref<8xi8>, %arg1: index, %arg2: memref<8xi8>, %arg3: memref<8xi8>, %arg4: memref<3xi32>){ | ||
return | ||
} | ||
func.func @main(){ | ||
%c2 = arith.constant 2 : index | ||
%alloc = memref.alloc() : memref<8xi8> | ||
%alloc_0 = memref.alloc() : memref<2x4xi8> | ||
%reinterpret = memref.reinterpret_cast %alloc_0 to offset: [0], sizes: [8], strides: [1] : memref<2x4xi8> to memref<8xi8> | ||
%alloc_1 = memref.alloc() : memref<8xi8> | ||
%alloc_2 = memref.alloc() : memref<3xi32> | ||
// All non-aliasing: | ||
call @soak(%alloc, %c2, %reinterpret, %alloc_1, %alloc_2) : (memref<8xi8>, index, memref<8xi8>, memref<8xi8>, memref<3xi32>) -> () | ||
// Operands #0 and #2 alias (both are %alloc): | ||
call @soak(%alloc, %c2, %alloc, %reinterpret, %alloc_2) : (memref<8xi8>, index, memref<8xi8>, memref<8xi8>, memref<3xi32>) -> () | ||
// Operands #0 and #3 alias (both are %alloc): | ||
call @soak(%alloc, %c2, %alloc_1, %alloc, %alloc_2) : (memref<8xi8>, index, memref<8xi8>, memref<8xi8>, memref<3xi32>) -> () | ||
// All non-aliasing: | ||
call @soak(%alloc_1, %c2, %alloc, %reinterpret, %alloc_2) : (memref<8xi8>, index, memref<8xi8>, memref<8xi8>, memref<3xi32>) -> () | ||
return | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
// TODO(newling) fix numerical issue on strix and remove the constraint that this | ||
// pass does not run if the device is determined to be aie2p (npu4). | ||
#t = #hal.executable.target<"", "", {target_device = "npu4", ukernels = "none"}> | ||
module attributes {hal.executable.target = #t} { | ||
// CHECK: func.func @unary(%arg0: memref<i8>) { | ||
func.func @unary(%arg0: memref<i8>){ | ||
return | ||
} | ||
func.func @main(){ | ||
%a = memref.alloc() : memref<i8> | ||
call @unary(%a) : (memref<i8>) -> () | ||
return | ||
} | ||
} |