-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][linalg] Genericize MapOp #162742
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Genericize MapOp #162742
Conversation
|
@rengolin this is a followup on our discussion on my other PR #144922 (comment). Just want to make sure this is in the right direction and doesn't conflict with what you or others have in mind for the linalg refactor before I spend time updating all the tests. update: it wasn't that much work, so just did it anyway |
| const NamedAttrList &payloadOpAttrs, | ||
| ArrayRef<Value> operands, | ||
| bool initFirst = false) { | ||
| bool initFirst = false, bool mapInit = true) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was just the option with the least amount of changes, but I can refactor if necessary.
|
|
||
| static bool canUseShortForm(Block *body, bool initFirst = false) { | ||
| static bool canUseShortForm(Block *body, bool initFirst = false, | ||
| bool mapInit = true) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, this was the option with the least amount of changes. I will refactor if desired.
|
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: None (srcarroll) Changes: This PR modifies the definition of `linalg::MapOp` so that it has the same structure as `linalg::GenericOp` and all other linalg ops. Full diff: https://github.com/llvm/llvm-project/pull/162742.diff 12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3674c3eecfe6..ecd036d452b27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
- SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
- return getDpsInputOperands();
- }
-
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
if (isDpsInit(opOperand)) return false;
return !getMatchingBlockArgument(opOperand).use_empty();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a23b3e3b..7ccba6143637e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
+ for (Value v : getRegionOutputArgs())
+ setNameFn(v, "init");
}
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, /*outputs=*/{}, bodyBuild);
+ inputs, /*outputs=*/{init}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
- bool initFirst = false) {
+ bool initFirst = false, bool mapInit = true) {
OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
- payloadOpOperands.push_back(block.getArguments().back());
+ if (mapInit)
+ payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
- block.getArguments().end()};
+ block.getArguments().end() - int(!mapInit)};
}
Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
- payloadOpAttrs,
- ArrayRef(result.operands).drop_back());
+ payloadOpAttrs, ArrayRef(result.operands), false,
+ false);
else
result.addRegion();
} else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+ bool mapInit = true) {
+ // `initFirst == true` implies that we want to map the init arg
+ if (initFirst && !mapInit)
+ return false;
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
+ payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
return false;
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
}
} else {
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
+ llvm::zip(payload.getOperands(),
+ body->getArguments().drop_back(int(!mapInit)))) {
if (bbArg != operand)
return false;
}
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- bool useShortForm = canUseShortForm(mapper);
+ bool useShortForm =
+ canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();
- // Checks if the number of `inputs` match the arity of the `mapper` region.
- if (getInputs().size() != blockArgs.size())
+ // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+ // region.
+ if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
- << getInputs().size() << " and " << blockArgs.size();
+ << getInputs().size() + 1 << " and "
+ << blockArgs.size();
// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393fd51ed..75bb1757a55f5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
using namespace mlir::linalg;
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
- // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
- // trivially generalize a `linalg.map`, as it does not use the output as
- // region arguments in the block.
- if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+ // Bailout if `linalgOp` is already a generic.
+ if (isa<GenericOp>(linalgOp))
return failure();
// Check if the operation has exactly one region.
if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e47a3be..c607ece418dff 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
/*init=*/tensorDestination);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
/*inputs=*/ValueRange(),
/*init=*/*tensorAlloc);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 26d2d98572f47..f4020ede4854e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
func.func @recursive_effect(%arg : tensor<1xf32>) {
%init = arith.constant dense<0.0> : tensor<1xf32>
%mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
- (%in : f32) {
+ (%in : f32, %out: f32) {
vector.print %in : f32
linalg.yield %in : f32
}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index ae07b1b82228c..dcdd6c8db4b21 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
// -----
-// CHECK-LABEL: generalize_linalg_map
-func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
%cst = arith.constant 0.000000e+00 : f32
- // CHECK: linalg.map
- // CHECK-NOT: linalg.generic
- linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
- () {
- linalg.yield %cst : f32
- }
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
return
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: @generalize_linalg_map
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
// -----
func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40bf4d19d6b91..fabc8e610612d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
// expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
func.func @map_input_mapper_arity_mismatch(
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
-> tensor<64xf32> {
- // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+ // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f64, %rhs_elem: f64) {
+ (%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f64
linalg.yield %0: f64
}
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
%add = linalg.map
ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%init:tensor<32xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 9616a3e32a064..28d7fdc041766 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 563013d4083af..74928920c695a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -341,7 +341,7 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
%add = linalg.map
outs(%init:tensor<64xf32>)
- () {
+ (%out: f32) {
%0 = arith.constant 0.0: f32
linalg.yield %0: f32
}
@@ -349,7 +349,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
}
// CHECK-LABEL: func @map_no_inputs
// CHECK: linalg.map outs
-// CHECK-NEXT: () {
+// CHECK-NEXT: (%[[OUT:.*]]: f32) {
// CHECK-NEXT: arith.constant
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
linalg.map
ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
outs(%init:memref<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
%abs = linalg.map
ins(%input:tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%input_elem: f32) {
+ (%input_elem: f32, %out: f32) {
%0 = math.absf %input_elem: f32
linalg.yield %0: f32
}
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
linalg.map
ins(%input:memref<64xf32>)
outs(%init:memref<64xf32>)
- (%input_elem: f32) {
+ (%input_elem: f32, %out: f32) {
%0 = math.absf %input_elem: f32
linalg.yield %0: f32
}
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
linalg.yield %0: f32
}
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
- (%in_1: f32, %in_2: f32) {
+ (%in_1: f32, %in_2: f32, %out: f32) {
%1 = arith.maximumf %in_1, %in_2 : f32
linalg.yield %in_1 : f32
}
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
-// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 35f520a9f22a8..704ad10130fc8 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -381,7 +381,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
%arg1: memref<64xf32>, %arg2: memref<64xf32>) {
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
outs(%arg2 : memref<64xf32>)
- (%in: f32, %in_0: f32) {
+ (%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 296ca02564e35..5eb2360a29b8f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
-// CHECK: () {
+// CHECK: (%[[INIT:.*]]: f32) {
// CHECK: linalg.yield %[[F]] : f32
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<?x3x?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 8cbee3cbb758b..aa8882d21698c 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @map(%lhs: memref<64xf32>,
- %rhs: memref<64xf32>, %out: memref<64xf32>) {
+ %rhs: memref<64xf32>, %init: memref<64xf32>) {
linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
- outs(%out : memref<64xf32>)
- (%in: f32, %in_0: f32) {
+ outs(%init : memref<64xf32>)
+ (%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
|
|
If the Adding @javedabsar1 who wrote the generalization code and @matthias-springer who knows more about the bufferization part. |
Hmm, good point, but that same issue would exist for ALL linalg ops that don't use init, like Also, not sure if that's an actual issue. It should allocate because the operation needs to write results to something. I just figured that was part of the semantics of linalg ops. It either reads and writes an already allocated init, or you create a new one to write to. Edit: Eh this isn't actually right. I forgot about contexts like when init is the result of Furthermore, as you can see, my changes don't affect the behavior of any current tests (except for |
|
I confirmed that the following two examples are unaffected using Obviously @matthias-springer should confirm or correct, as I have not looked deeply into the bufferization framework. I would doubt that it depends on linalg body block arguments, but rather on the op operands. If that's true, it makes sense that the changes here would not affect it because they are only at body block level. Edit: Moreover, memory effects haven't changed since |
|
pinging @matthias-springer @javedabsar1 @rengolin |
|
A high level question - instead of changing linalg.map definition, would it not be less ruffling if instead you extended linalg-morph-ops with a pattern to rewrite linalg.map into generic? |
In my opinion what you are suggesting is even more "ruffling", and I actually disagree this is ruffling anything anyway. These changes don't change the definition of map op. Nor does it have a negative impact on anything else as far as I can tell. A better question would be why have this specialization for the bbargs in the map op when all it does is make it harder to apply generic patterns that already exist? I'd like to hear one good reason to keep this specialization as is. |
It was just a suggestion - I am not married to that approach. I am happy with your change if @matthias-springer is too. Also, maybe what I was suggesting was misunderstood. So let me rephrase: changing the op definition -- although in your change it's not the op definition but its lowering via the body -- often meets resistance from folks because of the impact on their pipelines. So what I was proposing was a e.g. I |
In the current design, if you want a This usually triggers a RaW in the bufferization. But there is one special case: Whether you expose the "init" operand as a bbArg or not does not matter. The bufferizes-to-memwrite was already there for the init bbArg, the only thing you're possibly adding here is a bufferizes-to-memread. But due to Long story short: I think this PR does not pessimize bufferization. |
|
Whether |
@matthias-springer thanks for the thorough explanation |
As far as I can tell, the only effect on any pipeline these changes would have is whether Any frontend that is targeting the linalg dialect for conversion can choose to use Now on the topic of converting to If there is some reason beyond my understanding to avoid converting to Also I think @nicolasvasilache would agree with me that downstream uses shouldn't be a reason to avoid making changes if they benefit the MLIR project itself. |
|
Anything I can do to help move this forward? |
@javedabsar1 This is actually not true. What I meant was that there is no change to the builder function signature. Obviously the builder has been modified to add the extra block arg. However, if the Nevertheless, I think this is a really trivial update for downstream users. Way more trivial than what most downstream users have to deal with when updating upstream. I'd be happy to modify the builders further to make this as convenient as possible for others' needs. |
|
@javedabsar1 @rengolin any objections to the PR? I'm about to get someone else to review and will merge if I don't hear anything. I assume @matthias-springer doesn't have any objections, but since you are deferring linalg design to others then I'll wait to hear back from them. |
rengolin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without further context, I think this PR makes sense. The map operation isn't widely used, so I guess there's no big harm in giving this a go, at the very least to match its syntax / semantics to the other operations.
Let's keep an eye out for weird issues and post-commit feedback. I guess we won't get more before merging.
|
@rengolin thanks! |
This PR modifies the definition of `linalg::MapOp` so that it has the same structure of `linalg::GenericOp` and all other linalg ops. Mainly, it adds an `out` bbarg for the body of the op. Although the `out` arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact it only makes things more complicated because it doesn't align with the `GenericOp` structure. For example, `linalg-generalize-named-ops` avoided converting `linalg.map` purely because it didn't have the structure to do so. Moreover, although some fusion patterns are applied explicitly to `GenericOp`, we can change them to be applied to the base `LinalgOp` which will enable fusion for any fusion-compatible linalg op, but that requires the op having a generic structure. So these changes will enable us to use existing generic transformation patterns on `MapOp` that weren't possible before. They can either be applied to `MapOp` directly or applied after converting to `GenericOp`.
This PR modifies the definition of `linalg::MapOp` so that it has the same structure of `linalg::GenericOp` and all other linalg ops. Mainly, it adds an `out` bbarg for the body of the op. Although the `out` arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact it only makes things more complicated because it doesn't align with the `GenericOp` structure. For example, `linalg-generalize-named-ops` avoided converting `linalg.map` purely because it didn't have the structure to do so. Moreover, although some fusion patterns are applied explicitly to `GenericOp`, we can change them to be applied to the base `LinalgOp` which will enable fusion for any fusion-compatible linalg op, but that requires the op having a generic structure. So these changes will enable us to use existing generic transformation patterns on `MapOp` that weren't possible before. They can either be applied to `MapOp` directly or applied after converting to `GenericOp`.
This PR modifies the definition of `linalg::MapOp` so that it has the same structure of `linalg::GenericOp` and all other linalg ops. Mainly, it adds an `out` bbarg for the body of the op. Although the `out` arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact it only makes things more complicated because it doesn't align with the `GenericOp` structure. For example, `linalg-generalize-named-ops` avoided converting `linalg.map` purely because it didn't have the structure to do so. Moreover, although some fusion patterns are applied explicitly to `GenericOp`, we can change them to be applied to the base `LinalgOp` which will enable fusion for any fusion-compatible linalg op, but that requires the op having a generic structure. So these changes will enable us to use existing generic transformation patterns on `MapOp` that weren't possible before. They can either be applied to `MapOp` directly or applied after converting to `GenericOp`.
This PR modifies the definition of `linalg::MapOp` so that it has the same structure of `linalg::GenericOp` and all other linalg ops. Mainly, it adds an `out` bbarg for the body of the op. Although the `out` arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact it only makes things more complicated because it doesn't align with the `GenericOp` structure. For example, `linalg-generalize-named-ops` avoided converting `linalg.map` purely because it didn't have the structure to do so. Moreover, although some fusion patterns are applied explicitly to `GenericOp`, we can change them to be applied to the base `LinalgOp` which will enable fusion for any fusion-compatible linalg op, but that requires the op having a generic structure. So these changes will enable us to use existing generic transformation patterns on `MapOp` that weren't possible before. They can either be applied to `MapOp` directly or applied after converting to `GenericOp`.