Skip to content

Commit

Permalink
dialects: (bufferization) fix materialize_in_destination (#3495)
Browse files Browse the repository at this point in the history
Fixes #3489 

In order to make `materialize_in_destination` compatible with mlir, it
needs to be able to take a memref type as the destination, and have an
optional result.
  • Loading branch information
alexarice authored Dec 4, 2024
1 parent 24ab775 commit f33fbde
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 17 deletions.
24 changes: 13 additions & 11 deletions tests/filecheck/dialects/bufferization/bufferization_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP

// CHECK: builtin.module {

// CHECK-NEXT: %i0, %i1 = "test.op"() : () -> (index, index)
// CHECK-NEXT: %t0 = "test.op"() : () -> tensor<10x20x30xf64>
// CHECK: %i0, %i1 = "test.op"() : () -> (index, index)
// CHECK-NEXT: %t0, %m0 = "test.op"() : () -> (tensor<10x20x30xf64>, memref<10x20x30xf64>)
%i0, %i1 = "test.op"() : () -> (index, index)
%t0 = "test.op"() : () -> tensor<10x20x30xf64>
%t0, %m0 = "test.op"() : () -> (tensor<10x20x30xf64>, memref<10x20x30xf64>)

// CHECK-NEXT: %t1 = bufferization.alloc_tensor(%i0, %i1) {"hello" = "world"} : tensor<10x20x?x?xf64>
// CHECK-NEXT: %t2 = bufferization.alloc_tensor() copy(%t0) : tensor<10x20x30xf64>
// CHECK-NEXT: %t3 = bufferization.alloc_tensor(%i0, %i1) size_hint = %i1 : tensor<10x20x?x?xf64>
// CHECK-NEXT: %t1 = bufferization.alloc_tensor(%i0, %i1) {"hello" = "world"} : tensor<10x20x?x?xf64>
// CHECK-NEXT: %t2 = bufferization.alloc_tensor() copy(%t0) : tensor<10x20x30xf64>
// CHECK-NEXT: %t3 = bufferization.alloc_tensor(%i0, %i1) size_hint = %i1 : tensor<10x20x?x?xf64>
%t1 = bufferization.alloc_tensor(%i0, %i1) {"hello"="world"}: tensor<10x20x?x?xf64>
%t2 = bufferization.alloc_tensor() copy(%t0) : tensor<10x20x30xf64>
%t3 = bufferization.alloc_tensor(%i0, %i1) size_hint = %i1: tensor<10x20x?x?xf64>
Expand All @@ -20,15 +18,19 @@
%m = "test.op"() : () -> memref<30x20x10xf32>
%c = bufferization.clone %m : memref<30x20x10xf32> to memref<30x20x10xf32>

// CHECK-NEXT: }

// CHECK-NEXT: bufferization.materialize_in_destination %t0 in writable %m0 : (tensor<10x20x30xf64>, memref<10x20x30xf64>) -> ()
// CHECK-NEXT: bufferization.materialize_in_destination %t0 in %t0 : (tensor<10x20x30xf64>, tensor<10x20x30xf64>) -> tensor<10x20x30xf64>
bufferization.materialize_in_destination %t0 in writable %m0 : (tensor<10x20x30xf64>, memref<10x20x30xf64>) -> ()
bufferization.materialize_in_destination %t0 in %t0 : (tensor<10x20x30xf64>, tensor<10x20x30xf64>) -> tensor<10x20x30xf64>

// CHECK-GENERIC: "builtin.module"() ({
// CHECK-GENERIC-NEXT: %i0, %i1 = "test.op"() : () -> (index, index)
// CHECK-GENERIC-NEXT: %t0 = "test.op"() : () -> tensor<10x20x30xf64>
// CHECK-GENERIC-NEXT: %t0, %m0 = "test.op"() : () -> (tensor<10x20x30xf64>, memref<10x20x30xf64>)
// CHECK-GENERIC-NEXT: %t1 = "bufferization.alloc_tensor"(%i0, %i1) <{"operandSegmentSizes" = array<i32: 2, 0, 0>}> {"hello" = "world"} : (index, index) -> tensor<10x20x?x?xf64>
// CHECK-GENERIC-NEXT: %t2 = "bufferization.alloc_tensor"(%t0) <{"operandSegmentSizes" = array<i32: 0, 1, 0>}> : (tensor<10x20x30xf64>) -> tensor<10x20x30xf64>
// CHECK-GENERIC-NEXT: %t3 = "bufferization.alloc_tensor"(%i0, %i1, %i1) <{"operandSegmentSizes" = array<i32: 2, 0, 1>}> : (index, index, index) -> tensor<10x20x?x?xf64>
// CHECK-GENERIC-NEXT: %m = "test.op"() : () -> memref<30x20x10xf32>
// CHECK-GENERIC-NEXT: %c = "bufferization.clone"(%m) : (memref<30x20x10xf32>) -> memref<30x20x10xf32>
// CHECK-GENERIC-NEXT: "bufferization.materialize_in_destination"(%t0, %m0) <{"writable"}> : (tensor<10x20x30xf64>, memref<10x20x30xf64>) -> ()
// CHECK-GENERIC-NEXT: %0 = "bufferization.materialize_in_destination"(%t0, %t0) : (tensor<10x20x30xf64>, tensor<10x20x30xf64>) -> tensor<10x20x30xf64>
// CHECK-GENERIC-NEXT: }) : () -> ()
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

%tensor1 = "test.op"() : () -> tensor<2x2xf64>
%tensor2 = "test.op"() : () -> tensor<2x2xf64>
%memref = "test.op"() : () -> memref<2x2xf64>
bufferization.materialize_in_destination %tensor1 in writable %memref : (tensor<2x2xf64>, memref<2x2xf64>) -> ()
%b = bufferization.materialize_in_destination %tensor1 in %tensor2 : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64>


Expand All @@ -20,5 +22,7 @@
// CHECK-NEXT: %4 = bufferization.clone %1 : memref<30x20x10xf32> to memref<30x20x10xf32>
// CHECK-NEXT: %5 = "test.op"() : () -> tensor<2x2xf64>
// CHECK-NEXT: %6 = "test.op"() : () -> tensor<2x2xf64>
// CHECK-NEXT: %7 = bufferization.materialize_in_destination %5 in %6 : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64>
// CHECK-NEXT: %7 = "test.op"() : () -> memref<2x2xf64>
// CHECK-NEXT: bufferization.materialize_in_destination %5 in writable %7 : (tensor<2x2xf64>, memref<2x2xf64>) -> ()
// CHECK-NEXT: %8 = bufferization.materialize_in_destination %5 in %6 : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64>
// CHECK-NEXT: }
11 changes: 6 additions & 5 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
operand_def,
opt_operand_def,
opt_prop_def,
opt_result_def,
result_def,
var_operand_def,
)
Expand Down Expand Up @@ -208,15 +209,15 @@ class ToMemrefOp(IRDLOperation):
class MaterializeInDestinationOp(IRDLOperation):
name = "bufferization.materialize_in_destination"

T: ClassVar = VarConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr)
source = operand_def(T)
dest = operand_def(T)
result = result_def(T)
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
source = operand_def(TensorFromMemrefConstraint(T))
dest = operand_def(T | TensorFromMemrefConstraint(T))
result = opt_result_def(TensorFromMemrefConstraint(T))

restrict = opt_prop_def(UnitAttr)
writable = opt_prop_def(UnitAttr)

assembly_format = "$source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest attr-dict `:` `(` type($source) `,` type($dest) `)` `->` type($result)"
assembly_format = "$source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest attr-dict `:` functional-type(operands, results)"


Bufferization = Dialect(
Expand Down

0 comments on commit f33fbde

Please sign in to comment.