diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 671cc05e963b4..f75e311645426 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -153,7 +153,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ The `assume_alignment` operation takes a memref and an integer alignment value. It returns a new SSA value of the same memref type, but associated with the assumption that the underlying buffer is aligned to the given - alignment. + alignment. If the buffer isn't aligned to the given alignment, its result is poison. This operation doesn't affect the semantics of a program where the @@ -168,7 +168,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; let extraClassDeclaration = [{ MemRefType getType() { return ::llvm::cast(getResult().getType()); } - + Value getViewSource() { return getMemref(); } }]; @@ -176,6 +176,41 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// DistinctObjectsOp +//===----------------------------------------------------------------------===// + +def DistinctObjectsOp : MemRef_Op<"distinct_objects", [ + Pure, + DeclareOpInterfaceMethods + // ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument + ]> { + let summary = "assumption that acesses to specific memrefs will never alias"; + let description = [{ + The `distinct_objects` operation takes a list of memrefs and returns the same + memrefs, with the additional assumption that accesses to them will never + alias with each other. This means that loads and stores to different + memrefs in the list can be safely reordered. + + If the memrefs do alias, the load/store behavior is undefined. This + operation doesn't affect the semantics of a valid program. It is + intended for optimization purposes, allowing the compiler to generate more + efficient code based on the non-aliasing assumption. The optimization is + best-effort. + + Example: + + ```mlir + %1, %2 = memref.distinct_objects %a, %b : memref, memref + ``` + }]; + let arguments = (ins Variadic:$operands); + let results = (outs Variadic:$results); + + let assemblyFormat = "$operands attr-dict `:` type($operands)"; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 262e0e7a30c63..c62137721a2b9 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering } }; +struct DistinctObjectsOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + memref::DistinctObjectsOp>::ConvertOpToLLVMPattern; + explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter) {} + + LogicalResult + matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange operands = adaptor.getOperands(); + if (operands.size() <= 1) { + // Fast path. + rewriter.replaceOp(op, operands); + return success(); + } + + Location loc = op.getLoc(); + SmallVector ptrs; + for (auto [origOperand, newOperand] : + llvm::zip_equal(op.getOperands(), operands)) { + auto memrefType = cast(origOperand.getType()); + MemRefDescriptor memRefDescriptor(newOperand); + Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), + memrefType); + ptrs.push_back(ptr); + } + + auto cond = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1); + // Generate separate_storage assumptions for each pair of pointers. + for (auto i : llvm::seq(ptrs.size() - 1)) { + for (auto j : llvm::seq(i + 1, ptrs.size())) { + Value ptr1 = ptrs[i]; + Value ptr2 = ptrs[j]; + LLVM::AssumeOp::create(rewriter, loc, cond, + LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2); + } + } + + rewriter.replaceOp(op, operands); + return success(); + } +}; + // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. @@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns( patterns.add< AllocaOpLowering, AllocaScopeOpLowering, - AtomicRMWOpLowering, AssumeAlignmentOpLowering, + AtomicRMWOpLowering, ConvertExtractAlignedPointerAsIndex, DimOpLowering, + DistinctObjectsOpLowering, ExtractStridedMetadataOpLowering, GenericAtomicRMWOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, - MemorySpaceCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, + MemorySpaceCastOpLowering, PrefetchOpLowering, RankOpLowering, - ReassociatingReshapeOpConversion, ReassociatingReshapeOpConversion, + ReassociatingReshapeOpConversion, StoreOpLowering, SubViewOpLowering, TransposeOpLowering, diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 5d15d5f6e3de4..0bca922b0c804 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -542,6 +542,29 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) { return getMemref(); } +//===----------------------------------------------------------------------===// +// DistinctObjectsOp +//===----------------------------------------------------------------------===// + +LogicalResult DistinctObjectsOp::verify() { + if (getOperandTypes() != getResultTypes()) + return emitOpError("operand types and result types must match"); + + if (getOperandTypes().empty()) + return emitOpError("expected at least one operand"); + + return success(); +} + +LogicalResult DistinctObjectsOp::inferReturnTypes( + MLIRContext * /*context*/, std::optional /*location*/, + ValueRange operands, DictionaryAttr /*attributes*/, + OpaqueProperties /*properties*/, RegionRange /*regions*/, + SmallVectorImpl &inferredReturnTypes) { + llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes)); + return success(); +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 45b1a1f1ca40c..0cbe064572911 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -195,6 +195,36 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) { // ----- +// ALL-LABEL: func @distinct_objects +// ALL-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref) +func.func @distinct_objects(%arg0: memref, %arg1: memref, %arg2: memref) -> (memref, memref, memref) { +// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1 + %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref, memref, memref + return %1, %2, %3 : memref, memref, memref +} + +// ----- + +// ALL-LABEL: func @distinct_objects_noop +// ALL-SAME: (%[[ARG0:.*]]: memref) +func.func @distinct_objects_noop(%arg0: memref) -> memref { +// 1-operand version is noop +// ALL-NEXT: return %[[ARG0]] + %1 = memref.distinct_objects %arg0 : memref + return %1 : memref +} + +// ----- + // CHECK-LABEL: func @assume_alignment_w_offset // CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) { diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 3f96d907632b7..5ff292058ccc1 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1169,3 +1169,19 @@ func.func @expand_shape_invalid_output_shape( into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>> return } + +// ----- + +func.func @distinct_objects_types_mismatch(%arg0: memref, %arg1: memref) -> (memref, memref) { + // expected-error @+1 {{operand types and result types must match}} + %0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref, memref) -> (memref, memref) + return %0, %1 : memref, memref +} + +// ----- + +func.func @distinct_objects_0_operands() { + // expected-error @+1 {{expected at least one operand}} + "memref.distinct_objects"() : () -> () + return +} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 6c2298a3f8acb..a90c9505a8405 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) { return } +// CHECK-LABEL: func @distinct_objects +// CHECK-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref) +func.func @distinct_objects(%arg0: memref, %arg1: memref, %arg2: memref) -> (memref, memref, memref) { + // CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref, memref, memref + %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref, memref, memref + // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref, memref, memref + return %1, %2, %3 : memref, memref, memref +} + // CHECK-LABEL: func @expand_collapse_shape_static func.func @expand_collapse_shape_static( %arg0: memref<3x4x5xf32>,