Conversation

@srcarroll
Contributor

@srcarroll srcarroll commented Oct 9, 2025

This PR modifies the definition of linalg::MapOp so that it has the same structure as linalg::GenericOp and all other linalg ops. Mainly, it adds an out bbarg for the body of the op. Although the out arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact, it only makes things more complicated because it doesn't align with the GenericOp structure. For example, linalg-generalize-named-ops avoided converting linalg.map purely because it didn't have the structure to do so. Moreover, although some fusion patterns are currently applied explicitly to GenericOp, we can change them to be applied to the base LinalgOp, which would enable fusion for any fusion-compatible linalg op, but that requires the op to have the generic structure. So these changes enable us to use existing generic transformation patterns on MapOp that weren't possible before; they can either be applied to MapOp directly or applied after converting to GenericOp.
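
For illustration, the new long form looks like this (adapted from the roundtrip test updates in this PR); the mapper region now carries a trailing block argument for the init, even though the body does not use it:

   %add = linalg.map
          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
          outs(%init : tensor<64xf32>)
          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
            %0 = arith.addf %lhs_elem, %rhs_elem : f32
            linalg.yield %0 : f32
          }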

@srcarroll
Contributor Author

srcarroll commented Oct 9, 2025

@rengolin this is a follow-up to our discussion on my other PR #144922 (comment). I just want to make sure this is headed in the right direction and doesn't conflict with what you or others have in mind for the linalg refactor before I spend time updating all the tests.

Update: it wasn't that much work, so I just did it anyway.

const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
bool initFirst = false) {
bool initFirst = false, bool mapInit = true) {
Contributor Author

This was just the option with the fewest changes, but I can refactor if necessary.


static bool canUseShortForm(Block *body, bool initFirst = false) {
static bool canUseShortForm(Block *body, bool initFirst = false,
bool mapInit = true) {
Contributor Author

Again, this was the option with the fewest changes. I will refactor if desired.

@llvmbot
Member

llvmbot commented Oct 11, 2025

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (srcarroll)

Changes

This PR modifies the definition of linalg::MapOp so that it has the same structure as linalg::GenericOp and all other linalg ops. Mainly, it adds an out bbarg for the body of the op. Although the out arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact, it only makes things more complicated because it doesn't align with the GenericOp structure. For example, linalg-generalize-named-ops avoided converting linalg.map purely because it didn't have the structure to do so. If GenericOp can have unused bbargs, then ALL linalg ops should be allowed to have them as well.


Full diff: https://github.com/llvm/llvm-project/pull/162742.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (-4)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+24-13)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+2)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+14-8)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+5-5)
  • (modified) mlir/test/Dialect/Linalg/one-shot-bufferize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+9-9)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+1-1)
  • (modified) mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (+3-3)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3674c3eecfe6..ecd036d452b27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
-    SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
-      return getDpsInputOperands();
-    }
-
     bool payloadUsesValueFromOperand(OpOperand * opOperand) {
       if (isDpsInit(opOperand)) return false;
       return !getMatchingBlockArgument(opOperand).use_empty();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a23b3e3b..7ccba6143637e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
                                      OpAsmSetValueNameFn setNameFn) {
   for (Value v : getRegionInputArgs())
     setNameFn(v, "in");
+  for (Value v : getRegionOutputArgs())
+    setNameFn(v, "init");
 }
 
 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
 
   if (bodyBuild)
     buildGenericRegion(builder, result.location, *result.regions.front(),
-                       inputs, /*outputs=*/{}, bodyBuild);
+                       inputs, /*outputs=*/{init}, bodyBuild);
 }
 
 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                  const OperationName &payloadOpName,
                                  const NamedAttrList &payloadOpAttrs,
                                  ArrayRef<Value> operands,
-                                 bool initFirst = false) {
+                                 bool initFirst = false, bool mapInit = true) {
   OpBuilder b(parser.getContext());
   Region *body = result.addRegion();
   Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
   // If initFirst flag is enabled, we consider init as the first position of
   // payload operands.
   if (initFirst) {
-    payloadOpOperands.push_back(block.getArguments().back());
+    if (mapInit)
+      payloadOpOperands.push_back(block.getArguments().back());
     for (const auto &arg : block.getArguments().drop_back())
       payloadOpOperands.push_back(arg);
   } else {
     payloadOpOperands = {block.getArguments().begin(),
-                         block.getArguments().end()};
+                         block.getArguments().end() - int(!mapInit)};
   }
 
   Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   if (payloadOpName.has_value()) {
     if (!result.operands.empty())
       addBodyWithPayloadOp(parser, result, payloadOpName.value(),
-                           payloadOpAttrs,
-                           ArrayRef(result.operands).drop_back());
+                           payloadOpAttrs, ArrayRef(result.operands), false,
+                           false);
     else
       result.addRegion();
   } else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+                            bool mapInit = true) {
+  // `intFirst == true` implies that we want to map init arg
+  if (initFirst && !mapInit)
+    return false;
   // Check if the body can be printed in short form. The following 4 conditions
   // must be satisfied:
 
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
   // 2) The payload op must have the same number of operands as the number of
   //    block arguments.
   if (payload.getNumOperands() == 0 ||
-      payload.getNumOperands() != body->getNumArguments())
+      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
     return false;
 
   // 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
     }
   } else {
     for (const auto &[operand, bbArg] :
-         llvm::zip(payload.getOperands(), body->getArguments())) {
+         llvm::zip(payload.getOperands(),
+                   body->getArguments().drop_back(int(!mapInit)))) {
       if (bbArg != operand)
         return false;
     }
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
 
 void MapOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  bool useShortForm = canUseShortForm(mapper);
+  bool useShortForm =
+      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
   if (useShortForm) {
     printShortForm(p, &mapper->getOperations().front());
   }
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
   auto *bodyBlock = getBody();
   auto blockArgs = bodyBlock->getArguments();
 
-  // Checks if the number of `inputs` match the arity of the `mapper` region.
-  if (getInputs().size() != blockArgs.size())
+  // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+  // region.
+  if (getInputs().size() + 1 != blockArgs.size())
     return emitOpError() << "expects number of operands to match the arity of "
                             "mapper, but got: "
-                         << getInputs().size() << " and " << blockArgs.size();
+                         << getInputs().size() + 1 << " and "
+                         << blockArgs.size();
 
   // The parameters of mapper should all match the element type of inputs.
   for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393fd51ed..75bb1757a55f5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
-  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
-  // trivially generalize a `linalg.map`, as it does not use the output as
-  // region arguments in the block.
-  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+  // Bailout if `linalgOp` is already a generic.
+  if (isa<GenericOp>(linalgOp))
     return failure();
   // Check if the operation has exactly one region.
   if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e47a3be..c607ece418dff 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
       linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
                             /*init=*/tensorDestination);
   Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+  linalgBody.addArgument(tensorType.getElementType(), loc);
 
   // Create linalg::IndexOps.
   rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
                                           /*inputs=*/ValueRange(),
                                           /*init=*/*tensorAlloc);
     Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+    linalgBody.addArgument(tensorType.getElementType(), loc);
 
     // Create linalg::IndexOps.
     rewriter.setInsertionPointToStart(&linalgBody);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 26d2d98572f47..f4020ede4854e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
 func.func @recursive_effect(%arg : tensor<1xf32>) {
   %init = arith.constant dense<0.0> : tensor<1xf32>
   %mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
-            (%in : f32) {
+            (%in : f32, %out: f32) {
               vector.print %in : f32
               linalg.yield %in : f32
             }
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index ae07b1b82228c..dcdd6c8db4b21 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
 
 // -----
 
-// CHECK-LABEL: generalize_linalg_map
-func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
   %cst = arith.constant 0.000000e+00 : f32
-  // CHECK: linalg.map
-  // CHECK-NOT: linalg.generic
-  linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
-    () {
-      linalg.yield %cst : f32
-    }
+  linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
   return
 }
 
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: @generalize_linalg_map
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32
+
 // -----
 
 func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40bf4d19d6b91..fabc8e610612d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
    %add = linalg.map
           ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             // expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
             linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
 func.func @map_input_mapper_arity_mismatch(
     %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
     -> tensor<64xf32> {
-  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
       outs(%init:tensor<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
       outs(%init:tensor<64xf32>)
-      (%lhs_elem: f64, %rhs_elem: f64) {
+      (%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f64
         linalg.yield %0: f64
       }
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
       outs(%init:tensor<32xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 9616a3e32a064..28d7fdc041766 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             linalg.yield %0: f32
           }
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 563013d4083af..74928920c695a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -341,7 +341,7 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
 func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
    %add = linalg.map
       outs(%init:tensor<64xf32>)
-      () {
+      (%out: f32) {
         %0 = arith.constant 0.0: f32
         linalg.yield %0: f32
       }
@@ -349,7 +349,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
 }
 // CHECK-LABEL: func @map_no_inputs
 //       CHECK:   linalg.map outs
-//  CHECK-NEXT:   () {
+//  CHECK-NEXT:   (%[[OUT:.*]]: f32) {
 //  CHECK-NEXT:     arith.constant
 //  CHECK-NEXT:     linalg.yield
 //  CHECK-NEXT:   }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             linalg.yield %0: f32
           }
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
    linalg.map
       ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
       outs(%init:memref<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
    %abs = linalg.map
           ins(%input:tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%input_elem: f32) {
+          (%input_elem: f32, %out: f32) {
             %0 = math.absf %input_elem: f32
             linalg.yield %0: f32
           }
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
    linalg.map
       ins(%input:memref<64xf32>)
       outs(%init:memref<64xf32>)
-      (%input_elem: f32) {
+      (%input_elem: f32, %out: f32) {
         %0 = math.absf %input_elem: f32
         linalg.yield %0: f32
       }
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
             linalg.yield %0: f32
           }
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 
 func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
   %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
-    (%in_1: f32, %in_2: f32) {
+    (%in_1: f32, %in_2: f32, %out: f32) {
       %1 = arith.maximumf %in_1, %in_2 : f32
       linalg.yield %in_1 : f32
     }
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
 // CHECK-NOT:     linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
 // CHECK:         linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>) 
 // CHECK-SAME:               outs(%[[INIT]] : tensor<1x32xf32>)
-// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
 // CHECK-NEXT:        %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
 // CHECK-NEXT:        linalg.yield %[[IN1]] : f32
 // CHECK-NEXT:    }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 35f520a9f22a8..704ad10130fc8 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -381,7 +381,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
     %arg1: memref<64xf32>, %arg2: memref<64xf32>) {
   linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
              outs(%arg2 : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 296ca02564e35..5eb2360a29b8f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
 // CHECK-DAG:     %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
 // CHECK:         %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:         %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
-// CHECK:         () {
+// CHECK:         (%[[INIT:.*]]: f32) {
 // CHECK:           linalg.yield %[[F]] : f32
 // CHECK:         }
 // CHECK:         return %[[MAPPED]] : tensor<?x3x?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 8cbee3cbb758b..aa8882d21698c 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @map(%lhs: memref<64xf32>,
-    %rhs: memref<64xf32>, %out: memref<64xf32>) {
+    %rhs: memref<64xf32>, %init: memref<64xf32>) {
   linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
-             outs(%out : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+             outs(%init : memref<64xf32>)
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }

@rengolin
Member

If the map operation doesn't use the init, then bufferization will always create a new allocation.

Adding @javedabsar1 who wrote the generalization code and @matthias-springer who knows more about the bufferization part.

@srcarroll
Contributor Author

srcarroll commented Oct 11, 2025

If the map operation doesn't use the init, then bufferization will always create a new allocation.

Adding @javedabsar1 who wrote the generalization code and @matthias-springer who knows more about the bufferization part.

Hmm, good point, but that same issue would exist for ALL linalg ops that don't use their init, linalg.elementwise for example. So I don't see that as a good reason not to make linalg.map work like linalg.generic.

Also, I'm not sure that's an actual issue. It should allocate because the operation needs to write its results to something. I just figured that was part of the semantics of linalg ops: it either reads and writes an already allocated init, or you create a new one to write to. Edit: Eh, this isn't actually right. I forgot about contexts like when the init is the result of tensor.empty, or when it is a function arg and bufferize-function-boundaries is on, etc. But also remember that the init args are always implicitly used because that's where the results go.

Furthermore, as you can see, my changes don't affect the behavior of any current tests (except for linalg-generalize-named-ops). So either this isn't tested, or the thing you are worrying about is pre-existing anyway. I will double-check that my changes don't affect bufferization behavior, but I'm fairly certain they don't.

@srcarroll
Contributor Author

srcarroll commented Oct 12, 2025

I confirmed that the following two examples are unaffected:

func.func @map(%arg0: tensor<32x1xf32>, %arg1: tensor<32x1xf32>, %arg2: tensor<32x1xf32>) -> tensor<32x1xf32> {
    %0 = linalg.map { arith.subf } ins(%arg0, %arg2 : tensor<32x1xf32>, tensor<32x1xf32>) outs(%arg2 : tensor<32x1xf32>)
    return %0 : tensor<32x1xf32>
}

func.func @map2(%arg0: tensor<32x1xf32>, %arg1: tensor<32x1xf32>) -> tensor<32x1xf32> {
    %arg2 = tensor.empty() : tensor<32x1xf32>
    %0 = linalg.map { arith.subf } ins(%arg0, %arg2 : tensor<32x1xf32>, tensor<32x1xf32>) outs(%arg2 : tensor<32x1xf32>)
    return %0 : tensor<32x1xf32>
}

using --one-shot-bufferize="bufferize-function-boundaries" produces

  func.func @map(%arg0: memref<32x1xf32, strided<[?, ?], offset: ?>>, %arg1: memref<32x1xf32, strided<[?, ?], offset: ?>>, %arg2: memref<32x1xf32, strided<[?, ?], offset: ?>>) -> memref<32x1xf32, strided<[?, ?], offset: ?>> {
    linalg.map { arith.subf } ins(%arg0, %arg2 : memref<32x1xf32, strided<[?, ?], offset: ?>>, memref<32x1xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<32x1xf32, strided<[?, ?], offset: ?>>)
    return %arg2 : memref<32x1xf32, strided<[?, ?], offset: ?>>
  }
  func.func @map2(%arg0: memref<32x1xf32, strided<[?, ?], offset: ?>>, %arg1: memref<32x1xf32, strided<[?, ?], offset: ?>>) -> memref<32x1xf32> {
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x1xf32>
    linalg.map { arith.subf } ins(%arg0, %alloc : memref<32x1xf32, strided<[?, ?], offset: ?>>, memref<32x1xf32>) outs(%alloc : memref<32x1xf32>)
    %cast = memref.cast %alloc : memref<32x1xf32> to memref<32x1xf32, strided<[?, ?], offset: ?>>
    return %alloc : memref<32x1xf32>
  }

Obviously @matthias-springer should confirm or correct, as I have not looked deeply into the bufferization framework. I doubt that it depends on the linalg body block arguments; more likely it depends on the op operands. If that's true, it makes sense that the changes here would not affect it, because they are only at the body block level. Edit: Moreover, the memory effects haven't changed, since getGenericEffectsImpl is used.

@srcarroll
Contributor Author

pinging @matthias-springer @javedabsar1 @rengolin

@javedabsar1
Contributor

A high level question - instead of changing linalg.map definition, would it not be less ruffling if instead you extended linalg-morph-ops with a pattern to rewrite linalg.map into generic?

@srcarroll
Contributor Author

srcarroll commented Oct 19, 2025

A high level question - instead of changing linalg.map definition, would it not be less ruffling if instead you extended linalg-morph-ops with a pattern to rewrite linalg.map into generic?

In my opinion, what you are suggesting is even more "ruffling", and I actually disagree that this is ruffling anything anyway. These changes don't change the definition of the map op, nor do they have a negative impact on anything else as far as I can tell. A better question would be: why have this specialization for the bbargs in the map op when all it does is make it harder to apply generic patterns that already exist? I'd like to hear one good reason to keep this specialization as is.

@javedabsar1
Contributor

javedabsar1 commented Oct 20, 2025

A high level question - instead of changing linalg.map definition, would it not be less ruffling if instead you extended linalg-morph-ops with a pattern to rewrite linalg.map into generic?

In my opinion, what you are suggesting is even more "ruffling", and I actually disagree that this is ruffling anything anyway. These changes don't change the definition of the map op, nor do they have a negative impact on anything else as far as I can tell. A better question would be: why have this specialization for the bbargs in the map op when all it does is make it harder to apply generic patterns that already exist? I'd like to hear one good reason to keep this specialization as is.

It was just a suggestion - I am not married to that approach. I am happy with your change if @matthias-springer is too.

Also, maybe what I was suggesting was misunderstood. So let me rephrase: changing the op definition (although in your change it's not the op definition but its lowering via the body) often meets resistance from folks because of the impact on their pipelines. So what I was proposing was, e.g., a populateLinalgMapToGenericPattern rewrite that takes a linalg.map and converts it to one or more core ops (incl. linalg.generic). I have found linalg.map to be quite an odd/outlier op, as other better-supported ops could have been used in its place. But front-ends may generate linalg.map, so it cannot be removed easily.

@matthias-springer
Member

If the map operation doesn't use the init, then bufferization will always create a new allocation.

In the current design, if you want a linalg.map to bufferize in-place, you have to select one of the "ins" operands as "init" operand. (I.e., pass the same operand as "in" and "init".)
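
For example, a minimal sketch of that aliasing (names are illustrative; this mirrors the @map example earlier in the thread):

  func.func @inplace_by_aliasing(%a: tensor<64xf32>, %b: tensor<64xf32>) -> tensor<64xf32> {
    // %b is both an input and the init, so bufferization can reuse its buffer for the result.
    %0 = linalg.map { arith.addf } ins(%a, %b : tensor<64xf32>, tensor<64xf32>)
                                   outs(%b : tensor<64xf32>)
    return %0 : tensor<64xf32>
  }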

This usually triggers a RaW in the bufferization. But there is one special case: bufferizesToElementwiseAccess. This is the case for Linalg ops and makes it such that this is not considered a conflict.

Whether you expose the "init" operand as a bbArg or not does not matter. The bufferizes-to-memwrite was already there for the init bbArg, the only thing you're possibly adding here is a bufferizes-to-memread. But due to bufferizesToElementwiseAccess, that cannot cause a RaW conflict.

Long story short: I think this PR does not pessimize bufferization.

@matthias-springer
Member

matthias-springer commented Oct 20, 2025

Whether linalg.map should expose the "init" as a block argument: I'm not sure. I haven't worked with Linalg in a while. One benefit of this PR: you don't have to pass the same operand twice to ensure that the op bufferizes in-place. E.g., you can write linalg.map ins(%a) inits(%b) { ... } if you just want to add %a and %b. (Previously, you had to pass %b twice.)
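
A minimal sketch (not from the PR's tests) of what that could look like now that the init element is available as a block argument; note the printed keyword is outs:

  func.func @add_into_init(%a: tensor<64xf32>, %b: tensor<64xf32>) -> tensor<64xf32> {
    %0 = linalg.map
        ins(%a : tensor<64xf32>)
        outs(%b : tensor<64xf32>)
        (%in: f32, %out: f32) {
          // The init's element is read directly, so %b does not need to be passed as an extra input.
          %sum = arith.addf %in, %out : f32
          linalg.yield %sum : f32
        }
    return %0 : tensor<64xf32>
  }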

@srcarroll
Contributor Author

If the map operation doesn't use the init, then bufferization will always create a new allocation.

In the current design, if you want a linalg.map to bufferize in-place, you have to select one of the "ins" operands as "init" operand. (I.e., pass the same operand as "in" and "init".)

This usually triggers a RaW in the bufferization. But there is one special case: bufferizesToElementwiseAccess. This is the case for Linalg ops and makes it such that this is not considered a conflict.

Whether you expose the "init" operand as a bbArg or not does not matter. The bufferizes-to-memwrite was already there for the init bbArg, the only thing you're possibly adding here is a bufferizes-to-memread. But due to bufferizesToElementwiseAccess, that cannot cause a RaW conflict.

Long story short: I think this PR does not pessimize bufferization.

@matthias-springer thanks for the thorough explanation

@srcarroll
Contributor Author

srcarroll commented Oct 20, 2025

A high level question - instead of changing linalg.map definition, would it not be less ruffling if instead you extended linalg-morph-ops with a pattern to rewrite linalg.map into generic?

In my opinion, what you are suggesting is even more "ruffling", and I actually disagree that this is ruffling anything anyway. These changes don't change the definition of the map op, nor do they have a negative impact on anything else as far as I can tell. A better question would be: why have this specialization for the bbargs in the map op when all it does is make it harder to apply generic patterns that already exist? I'd like to hear one good reason to keep this specialization as is.

It was just a suggestion - I am not married to that approach. I am happy with your change if @matthias-springer is too.

Also, maybe what I was suggesting was misunderstood. So let me rephrase: changing the op definition (although in your change it's not the op definition but its lowering via the body) often meets resistance from folks because of the impact on their pipelines. So what I was proposing was, e.g., a populateLinalgMapToGenericPattern rewrite that takes a linalg.map and converts it to one or more core ops (incl. linalg.generic). I have found linalg.map to be quite an odd/outlier op, as other better-supported ops could have been used in its place. But front-ends may generate linalg.map, so it cannot be removed easily.

As far as I can tell, the only effect these changes would have on any pipeline is whether linalg.map is converted to linalg.generic or not via linalg-generalize-named-ops.

Any frontend that is targeting the linalg dialect for conversion can choose to use linalg.map or not. If they have already chosen to use linalg.map, then nothing will change for them after these changes because neither the definition nor the builder has changed.

Now, on the topic of converting to linalg.generic: I can't imagine why that would be a problem for anyone. The resulting loops will be exactly the same whether linalg.map is converted straight to loops or to linalg.generic and then to loops (see the sketch below). Anyone can choose to convert linalg.map to anything else before using linalg-generalize-named-ops. So I'm just not seeing a situation where this change limits choices.
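
To make the equivalence concrete, here is a hypothetical 1-D example (not taken from this PR's tests) of a linalg.map and the linalg.generic it generalizes to; the indexing maps are all identities and the iterators are all parallel:

  #id = affine_map<(d0) -> (d0)>

  // Named form.
  func.func @map_add(%a: memref<64xf32>, %b: memref<64xf32>, %c: memref<64xf32>) {
    linalg.map { arith.addf } ins(%a, %b : memref<64xf32>, memref<64xf32>)
                              outs(%c : memref<64xf32>)
    return
  }

  // Roughly what -linalg-generalize-named-ops turns the map into.
  func.func @generic_add(%a: memref<64xf32>, %b: memref<64xf32>, %c: memref<64xf32>) {
    linalg.generic {indexing_maps = [#id, #id, #id], iterator_types = ["parallel"]}
        ins(%a, %b : memref<64xf32>, memref<64xf32>) outs(%c : memref<64xf32>) {
      ^bb0(%in: f32, %in_0: f32, %out: f32):
        %0 = arith.addf %in, %in_0 : f32
        linalg.yield %0 : f32
    }
    return
  }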

If there is some reason beyond my understanding to avoid converting to linalg.generic, then I personally think the best solution for that is to add an option, say exclude-ops, in linalg-generalize-named-ops that allows people to specify which ops they don't want to convert.

Also, I think @nicolasvasilache would agree with me that downstream uses shouldn't be a reason to avoid making changes if they benefit the MLIR project itself.

@srcarroll
Contributor Author

Anything I can do to help move this forward?

@srcarroll
Contributor Author

Any frontend that is targeting the linalg dialect for conversion can choose to use linalg.map or not. If they have already chosen to use linalg.map, then nothing will change for them after these changes because neither the definition nor the builder has changed.

@javedabsar1 This is actually not true. What I meant is that there is no change to the builder function signature. Obviously the builder has been modified to add the extra block arg. However, if the bodyBuilder function is passed, then the logic for this should be unchanged for the most part. The size of the ValueRange will increase by one, but that's it. If there's any logic downstream that depends on that size in any way, then yes, there will have to be a change, but it's highly trivial. If the bodyBuilder function is not passed and the block along with its containing ops are inserted manually, then you will have to add an extra block arg, like what I did here.

Nevertheless, I think this is a really trivial update for downstream users. Way more trivial than what most downstream users have to deal with when updating upstream.

I'd be happy to modify the builders further to make this as convenient as possible for others' needs.

@srcarroll
Contributor Author

srcarroll commented Oct 27, 2025

@javedabsar1 @rengolin any objections to the PR? I'm about to get someone else to review and will merge if I don't hear anything.

I assume @matthias-springer doesn't have any objections, but since you are deferring linalg design decisions to others, I'll wait to hear back from them.

Member

@rengolin rengolin left a comment

Without further context, I think this PR makes sense. The map operation isn't widely used, so I guess there's no big harm in giving this a go, at the very least to match its syntax/semantics to the other operations.

Let's keep an eye out for weird issues and post-commit feedback. I guess we won't get more before merging.

@srcarroll
Contributor Author

@rengolin thanks!

@srcarroll srcarroll merged commit f5e175f into llvm:main Oct 30, 2025
10 checks passed
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
luciechoi pushed a commit to luciechoi/llvm-project that referenced this pull request Nov 1, 2025
DEBADRIBASAK pushed a commit to DEBADRIBASAK/llvm-project that referenced this pull request Nov 3, 2025