Skip to content

Commit

Permalink
squash commits
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Feb 7, 2025
1 parent 2751586 commit 1b88de4
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEDialect.h"
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/AMDAIEEnums.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

#define DEBUG_TYPE "iree-amdaie-no-alias-function-arguments"

namespace mlir::iree_compiler::AMDAIE {

namespace {

SmallVector<std::pair<func::FuncOp, SmallVector<func::CallOp>>>
getFunctionsAndTheirCallers(Operation *rootOp) {
// A mapping from all the function ops in the root op, to their callers.
SmallVector<std::pair<func::FuncOp, SmallVector<func::CallOp>>>
functionsAndCallers;

// A mapping from function symbol names, to their index in
// `functionsAndCallers`.
DenseMap<StringRef, uint32_t> funcOpIndex;

// Find all the function ops
rootOp->walk([&](func::FuncOp funcOp) {
funcOpIndex.insert({funcOp.getSymName(), functionsAndCallers.size()});
SmallVector<func::CallOp> callers;
functionsAndCallers.push_back({funcOp, callers});
});

// Add the callers to the mapping `functionsAndCallers`
rootOp->walk([&](func::CallOp callOp) {
StringRef callee = callOp.getCallee();
auto it = funcOpIndex.find(callee);
if (it != funcOpIndex.end()) {
functionsAndCallers[it->second].second.push_back(callOp);
}
});
return functionsAndCallers;
}

/// Traverse backwards through the definition chain of first operands
/// starting from `initial`, until either a memref.alloc or a amdaie.buffer
/// operation is found. If neither is found, return a failure.
FailureOr<Operation *> getDefiningAllocation(Operation *initial) {
Operation *current = initial;
while (current) {
if (isa<AMDAIE::BufferOp>(current) || isa<memref::AllocOp>(current)) {
return current;
}
if (current->getNumOperands() != 1) {
InFlightDiagnostic message =
initial->emitOpError()
<< "could not be traced back to an allocation operation, "
<< "an operation with " << current->getNumOperands()
<< " operands was encountered while traversing defining ops.";
message.attachNote() << "The operation with multiple operands is: "
<< *current;
return message;
}

current = current->getOperand(0).getDefiningOp();
}
return initial->emitOpError()
<< "could not be traced back to an allocation operation.";
};

/// Return a vector containing for every operand of `callOp`, a bool that is
/// true if the operand is an alias of a memref that does not alias with any
/// other operand.
FailureOr<SmallVector<bool>> nonAliasingMemrefArguments(func::CallOp callOp) {
// Find the allocations that define the memref operands of the call op.
// This vector contains, for each memref operand, a pair containing
// the allocation that defines it, and the index of the operand.
SmallVector<std::pair<Operation *, uint32_t>> memrefs;
for (auto enOperand : llvm::enumerate(callOp.getOperands())) {
Value operand = enOperand.value();
if (!isa<MemRefType>(operand.getType())) continue;
if (operand.getDefiningOp() == nullptr) {
return callOp->emitOpError(
"has an operand with no defining op, failed to find allocation");
}
FailureOr<Operation *> maybeAllocation =
getDefiningAllocation(operand.getDefiningOp());
if (failed(maybeAllocation)) return failure();
Operation *allocation = maybeAllocation.value();
memrefs.push_back({allocation, enOperand.index()});
}

SmallVector<bool> nonAliasingMemref(callOp.getNumOperands(), false);
for (auto memref : memrefs) {
bool isAliasing = false;
for (auto otherMemref : memrefs) {
if (memref.second != otherMemref.second &&
memref.first == otherMemref.first) {
isAliasing = true;
}
}
uint32_t operandIndex = memref.second;
nonAliasingMemref[operandIndex] = !isAliasing;
}
return nonAliasingMemref;
}

class AMDAIENoAliasFunctionArgumentsPass
: public impl::AMDAIENoAliasFunctionArgumentsBase<
AMDAIENoAliasFunctionArgumentsPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect>();
}
AMDAIENoAliasFunctionArgumentsPass() = default;
AMDAIENoAliasFunctionArgumentsPass(
const AMDAIENoAliasFunctionArgumentsPass &pass){};
void runOnOperation() override;
};

void AMDAIENoAliasFunctionArgumentsPass::runOnOperation() {
Operation *op = getOperation();

// TODO(newling): resolve numerical issues on strix.
{
std::optional<AMDAIEDevice> device = getConfigAMDAIEDeviceFromAncestor(op);
if (device.has_value() && isAie2P(device.value())) return;
}

IRRewriter rewriter(op);
auto functionsAndCallers = getFunctionsAndTheirCallers(op);
for (auto [func, callers] : functionsAndCallers) {
auto numOperands = func.getNumArguments();
SmallVector<bool> nonAliasingMemref(numOperands, true);
for (auto caller : callers) {
assert(numOperands == caller.getNumOperands() &&
"Number of operands in caller and callee do not match");
FailureOr<SmallVector<bool>> maybeNonAliasingArguments =
nonAliasingMemrefArguments(caller);
if (failed(maybeNonAliasingArguments)) {
return signalPassFailure();
}
SmallVector<bool> nonAliasings = maybeNonAliasingArguments.value();
for (uint32_t i = 0; i < nonAliasingMemref.size(); ++i) {
nonAliasingMemref[i] = nonAliasingMemref[i] && nonAliasings[i];
}
}

StringRef noAliasAttrName = LLVM::LLVMDialect::getNoAliasAttrName();
ArrayRef<BlockArgument> args = func.getArguments();
for (auto iter : llvm::enumerate(args)) {
if (nonAliasingMemref[iter.index()]) {
func.setArgAttr(iter.index(), noAliasAttrName, rewriter.getUnitAttr());
}
}
}
}

} // namespace

std::unique_ptr<Pass> createAMDAIENoAliasFunctionArgumentsPass() {
return std::make_unique<AMDAIENoAliasFunctionArgumentsPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ iree_cc_library(
"AMDAIETemporaryAllocBufferization.cpp"
"AMDAIETile.cpp"
"AMDAIETileAndFuse.cpp"
"AMDAIENoAliasFunctionArguments.cpp"
"AMDAIEVectorization.cpp"
"BridgeToAIRPass.cpp"
"DecomposeLinalgExtPackUnPackToAIR.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIETEMPORARYALLOCBUFFERIZATION
#define GEN_PASS_DEF_AMDAIETILE
#define GEN_PASS_DEF_AMDAIETILEANDFUSE
#define GEN_PASS_DEF_AMDAIENOALIASFUNCTIONARGUMENTS
#define GEN_PASS_DEF_AMDAIEVECTORIZATION
#include "iree-amd-aie/Transforms/Passes.h.inc"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ void addAMDAIEToAIEPasses(OpPassManager &passManager,
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIESinkIntoCorePass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIENoAliasFunctionArgumentsPass());
passManager.addPass(createAMDAIELowerToAIEPass());
passManager.addPass(createAMDAIERemoveMemorySpacePass());
passManager.addPass(createCanonicalizerPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ std::unique_ptr<Pass> createAMDAIETilePass(AMDAIETileOptions options = {});
std::unique_ptr<Pass> createAMDAIETileAndFusePass(
AMDAIETileAndFuseOptions options = {});

/// Create pass to add the llvm.noalias attribute to function arguments
/// where it is safe to do so.
std::unique_ptr<Pass> createAMDAIENoAliasFunctionArgumentsPass();

/// Create pass to propagate pack/unpack ops using upstream patterns.
std::unique_ptr<Pass> createAMDAIEPropagateDataLayoutPass();

Expand Down
21 changes: 19 additions & 2 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
include "iree-amd-aie/IR/AMDAIEDialect.td"
include "mlir/Pass/PassBase.td"

def AMDAIENoAliasFunctionArguments :
Pass<"iree-amdaie-no-alias-function-arguments", ""> {
let summary = "TODO";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIENoAliasFunctionArgumentsPass()";
let description = [{
Where it is safe to do so, give function arguments the `llvm.noalias`
attribute. Motivation: peano can generate much faster code if it knows
that the arguments to a function do not alias.

Basic algorithm: for each function, for each call site, for each operand:
check if the operand's defining memory allocation (currently `memref.alloc`
or `amdaie.buffer`) is the same as any other operand's. If it is the same,
then the function signature cannot have the `llvm.noalias` attribute for
the corresponding argument.
}];
}

def AMDAIEAccessToAcquireRelease :
Pass<"iree-amdaie-access-to-acquire-release", ""> {
let summary = "Convert logical objectFifo access operations to acquire/release "
Expand Down Expand Up @@ -123,13 +140,13 @@ def AMDAIECanonicalizeDoublyStridedOp :
def AMDAIECanonicalizeNpuDmaCpyNd :
Pass<"iree-amdaie-canonicalize-npu-dma-cpy-nd", "ModuleOp"> {
let summary = "Canonicalize npu.dma_cpy_nd operations.";
let description = [{
let description = [{
Canonicalize the offsets/sizes/strides of npu.dma_cpy_nd operations on the L3
side of the data movement, to make them more representative of the DMA in hardware.
This pass ensures the offsets/sizes/strides are of size `nbDimensions`, and that no
dimensions with size>1 have stride=0 except for dimension zero (outer dimension).
This is a HW constraint.
}];
}];
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIECanonicalizeNpuDmaCpyNdPass()";
let options = [
Option<"nbDimensions", "nb-dimensions", "uint64_t", /*default=*/"4",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ iree_lit_test_suite(
"tile_and_fuse_matmul_using_scf_forall.mlir"
"tile_and_fuse_convolution_using_scf_forall.mlir"
"tile_copy_using_scf_for.mlir"
"no_alias_function_arguments.mlir"
"vectorization.mlir"
TOOLS
${IREE_LLD_TARGET}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-no-alias-function-arguments)" %s | FileCheck %s


// -----

// check the most basic case: a single argument should always be 'noalias'.
module {
// CHECK: func.func @unary(%arg0: memref<8xi8, 2 : i32> {llvm.noalias}) {
func.func @unary(%arg0: memref<8xi8, 2 : i32>){
return
}
func.func @main(){
%c2 = arith.constant 2 : index
%tile_2_2 = amdaie.tile(%c2, %c2)
%buffer = amdaie.buffer(%tile_2_2) : memref<8xi8, 2 : i32>
call @unary(%buffer) : (memref<8xi8, 2 : i32>) -> ()
return
}
}

// -----


// check that non-memref arguments are ignored.
module {
// CHECK: func.func @unary_with_index(%arg0: memref<8xi8, 2 : i32> {llvm.noalias}, %arg1: index) {
func.func @unary_with_index(%arg0: memref<8xi8, 2 : i32>, %arg1: index){
return
}
func.func @main(){
%c2 = arith.constant 2 : index
%tile_2_2 = amdaie.tile(%c2, %c2)
%buffer = amdaie.buffer(%tile_2_2) : memref<8xi8, 2 : i32>
call @unary_with_index(%buffer, %c2) : (memref<8xi8, 2 : i32>, index) -> ()
return
}
}

// -----

// check that the caller operand needn't be a buffer/allocation, that
// back-tracking to the allocation/buffer works.
module {
// CHECK: func.func @unary_with_chain(%arg0: memref<8xi8> {llvm.noalias}) {
func.func @unary_with_chain(%arg0: memref<8xi8>){
return
}
func.func @main(){
%c2 = arith.constant 2 : index
%alloc = memref.alloc() : memref<8x1x1xi8>
%collapse = memref.collapse_shape %alloc [[0,1],[2]] : memref<8x1x1xi8> into memref<8x1xi8>
%reinterpret = memref.reinterpret_cast %collapse to offset: [0], sizes: [8], strides: [1] : memref<8x1xi8> to memref<8xi8>
call @unary_with_chain(%reinterpret) : (memref<8xi8>) -> ()
return
}
}

// -----

// check a basic case for a function with 2 arguments that no do not alias.
module {
// CHECK: func.func @binary(%arg0: memref<8xi8> {llvm.noalias}, %arg1: memref<8xi8> {llvm.noalias}) {
func.func @binary(%arg0: memref<8xi8>, %arg1: memref<8xi8>){
return
}
func.func @main(){
%alloc = memref.alloc() : memref<8xi8>
%alloc_0 = memref.alloc() : memref<8xi8>
call @binary(%alloc, %alloc_0) : (memref<8xi8>, memref<8xi8>) -> ()
return
}
}

// -----


// Check a case where two of the operands are aliases at one of the call sites.
module {
// CHECK: func.func @ternary_multicall(%arg0: memref<8xi8>, %arg1: memref<8xi8> {llvm.noalias}, %arg2: memref<8xi8>) {
func.func @ternary_multicall(%arg0: memref<8xi8>, %arg1: memref<8xi8>, %arg2: memref<8xi8>){
return
}
func.func @main(){
%alloc = memref.alloc() : memref<8xi8>
%alloc_0 = memref.alloc() : memref<8xi8>
%alloc_1 = memref.alloc() : memref<8xi8>
// This call has no aliasing operands.
call @ternary_multicall(%alloc, %alloc_0, %alloc_1) : (memref<8xi8>, memref<8xi8>, memref<8xi8>) -> ()
// But this call has operand #0 and #2 aliased.
call @ternary_multicall(%alloc_0, %alloc_1, %alloc_0) : (memref<8xi8>, memref<8xi8>, memref<8xi8>) -> ()
return
}
}

// -----


module {
// CHECK: func.func @soak(%arg0: memref<8xi8>, %arg1: index, %arg2: memref<8xi8>, %arg3: memref<8xi8>, %arg4: memref<3xi32> {llvm.noalias}) {
func.func @soak(%arg0: memref<8xi8>, %arg1: index, %arg2: memref<8xi8>, %arg3: memref<8xi8>, %arg4: memref<3xi32>){
return
}
func.func @main(){
%c2 = arith.constant 2 : index
%alloc = memref.alloc() : memref<8xi8>
%alloc_0 = memref.alloc() : memref<2x4xi8>
%reinterpret = memref.reinterpret_cast %alloc_0 to offset: [0], sizes: [8], strides: [1] : memref<2x4xi8> to memref<8xi8>
%alloc_1 = memref.alloc() : memref<8xi8>
%alloc_2 = memref.alloc() : memref<3xi32>
// All non-aliasing:
call @soak(%alloc, %c2, %reinterpret, %alloc_1, %alloc_2) : (memref<8xi8>, index, memref<8xi8>, memref<8xi8>, memref<3xi32>) -> ()
// Operands #1 and #2 alias:
call @soak(%alloc, %c2, %alloc, %reinterpret, %alloc_2) : (memref<8xi8>, index, memref<8xi8>, memref<8xi8>, memref<3xi32>) -> ()
// Operands #1 and #3 alias:
call @soak(%alloc, %c2, %alloc_1, %alloc, %alloc_2) : (memref<8xi8>, index, memref<8xi8>, memref<8xi8>, memref<3xi32>) -> ()
// All non-aliasing:
call @soak(%alloc_1, %c2, %alloc, %reinterpret, %alloc_2) : (memref<8xi8>, index, memref<8xi8>, memref<8xi8>, memref<3xi32>) -> ()
return
}
}

// -----

// TODO(newling) fix numerical issue on strix and remove the constraint that this
// pass does not run if the device is determined to be aie2p (npu4).
#t = #hal.executable.target<"", "", {target_device = "npu4", ukernels = "none"}>
module attributes {hal.executable.target = #t} {
// CHECK: func.func @unary(%arg0: memref<i8>) {
func.func @unary(%arg0: memref<i8>){
return
}
func.func @main(){
%a = memref.alloc() : memref<i8>
call @unary(%a) : (memref<i8>) -> ()
return
}
}

0 comments on commit 1b88de4

Please sign in to comment.