Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 37 additions & 35 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {

```mlir
%new = memref.realloc %old : memref<64xf32> to memref<124xf32>
%4 = memref.load %new[%index] // ok
%5 = memref.load %old[%index] // undefined behavior
%4 = memref.load %new[%index] : memref<124xf32> // ok
%5 = memref.load %old[%index] : memref<64xf32> // undefined behavior
```
}];

Expand Down Expand Up @@ -445,9 +445,10 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
operation:

```mlir
%result = memref.alloca_scope {
%result = memref.alloca_scope -> f32 {
%value = arith.constant 1.0 : f32
...
memref.alloca_scope.return %value
memref.alloca_scope.return %value : f32
}
```

Expand Down Expand Up @@ -478,7 +479,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
to indicate which values are going to be returned. For example:

```mlir
memref.alloca_scope.return %value
memref.alloca_scope.return %value : f32
```
}];

Expand Down Expand Up @@ -543,11 +544,11 @@ def MemRef_CastOp : MemRef_Op<"cast", [
Example:

```mlir
Cast to concrete shape.
%4 = memref.cast %1 : memref<*xf32> to memref<4x?xf32>
// Cast to concrete shape.
%4 = memref.cast %1 : memref<*xf32> to memref<4x?xf32>

Erase rank information.
%5 = memref.cast %1 : memref<4x?xf32> to memref<*xf32>
// Erase rank information.
%5 = memref.cast %1 : memref<4x?xf32> to memref<*xf32>
```
}];

Expand Down Expand Up @@ -613,8 +614,8 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
Example:

```mlir
%0 = memref.alloc() : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1), 1>>
memref.dealloc %0 : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1), 1>>
%0 = memref.alloc() : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
memref.dealloc %0 : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
```
}];

Expand Down Expand Up @@ -728,22 +729,22 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
space 1 at indices [%k, %l], would be specified as follows:

```mlir
%num_elements = arith.constant 256
%num_elements = arith.constant 256 : index
Copy link
Contributor Author

@FruitClover FruitClover Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: `index` is used here because of `Index:$numElements`

docs.mlir:120:50: error: use of value '%num_elements' expects different type than prior uses: 'index' vs 'i64'
    memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
                                                 ^
docs.mlir:117:5: note: prior use here
    %num_elements = arith.constant 256
    ^

%idx = arith.constant 0 : index
%tag = memref.alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
memref<40 x 128 x f32>, affine_map<(d0) -> (d0)>, 0>,
memref<2 x 1024 x f32>, affine_map<(d0) -> (d0)>, 1>,
memref<1 x i32>, affine_map<(d0) -> (d0)>, 2>
%tag = memref.alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 2>
memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
memref<40 x 128 x f32, affine_map<(d0, d1) -> (d0, d1)>, 0>,
memref<2 x 1024 x f32, affine_map<(d0, d1) -> (d0, d1)>, 1>,
memref<1 x i32, affine_map<(d0) -> (d0)>, 2>
```

If %stride and %num_elt_per_stride are specified, the DMA is expected to
transfer %num_elt_per_stride elements every %stride elements apart from
memory space 0 until %num_elements are transferred.

```mlir
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
%num_elt_per_stride :
memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
%num_elt_per_stride :
```

* TODO: add additional operands to allow source and destination striding, and
Expand Down Expand Up @@ -891,10 +892,10 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
Example:

```mlir
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
memref<2048 x f32>, affine_map<(d0) -> (d0)>, 0>,
memref<256 x f32>, affine_map<(d0) -> (d0)>, 1>
memref<1 x i32>, affine_map<(d0) -> (d0)>, 2>
memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
memref<2048 x f32, affine_map<(d0) -> (d0)>, 0>,
memref<256 x f32, affine_map<(d0) -> (d0)>, 1>,
memref<1 x i32, affine_map<(d0) -> (d0)>, 2>
...
...
dma_wait %tag[%index], %num_elements : memref<1 x i32, affine_map<(d0) -> (d0)>, 2>
Expand Down Expand Up @@ -1004,16 +1005,16 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [

```mlir
%base, %offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %memref :
memref<10x?xf32>, index, index, index, index, index
memref.extract_strided_metadata %memref : memref<10x?xf32>
-> memref<f32>, index, index, index, index, index

// After folding, the type of %m2 can be memref<10x?xf32> and further
// folded to %memref.
%m2 = memref.reinterpret_cast %base to
offset: [%offset],
sizes: [%sizes#0, %sizes#1],
strides: [%strides#0, %strides#1]
: memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
: memref<f32> to memref<?x?xf32, strided<[?, ?], offset:?>>
```
}];

Expand Down Expand Up @@ -1182,10 +1183,10 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {

```mlir
// Private variable with an initial value.
memref.global "private" @x : memref<2xf32> = dense<0.0,2.0>
memref.global "private" @x : memref<2xf32> = dense<[0.0, 2.0]>

// Private variable with an initial value and an alignment (power of 2).
memref.global "private" @x : memref<2xf32> = dense<0.0,2.0> {alignment = 64}
memref.global "private" @x : memref<2xf32> = dense<[0.0, 2.0]> {alignment = 64}

// Declaration of an external variable.
memref.global "private" @y : memref<4xi32>
Expand All @@ -1194,7 +1195,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
memref.global @z : memref<3xf16> = uninitialized

// Externally visible constant variable.
memref.global constant @c : memref<2xi32> = dense<1, 4>
memref.global constant @c : memref<2xi32> = dense<[1, 4]>
```
}];

Expand Down Expand Up @@ -1555,7 +1556,8 @@ def MemRef_ReinterpretCastOp
%dst = memref.reinterpret_cast %src to
offset: [%offset],
sizes: [%sizes],
strides: [%strides]
strides: [%strides] :
memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
```
means that `%dst`'s descriptor will be:
```mlir
Expand Down Expand Up @@ -1695,12 +1697,12 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
```mlir
// Reshape statically-shaped memref.
%dst = memref.reshape %src(%shape)
: (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
: (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
%dst0 = memref.reshape %src(%shape0)
: (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
: (memref<4x1xf32>, memref<2xi32>) -> memref<2x2xf32>
// Flatten unranked memref.
%dst = memref.reshape %src(%shape)
: (memref<*xf32>, memref<1xi32>) to memref<?xf32>
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
```

b. Source type is ranked or unranked. Shape argument has dynamic size.
Expand All @@ -1709,10 +1711,10 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
```mlir
// Reshape dynamically-shaped 1D memref.
%dst = memref.reshape %src(%shape)
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
: (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
// Reshape unranked memref.
%dst = memref.reshape %src(%shape)
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
: (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
```
}];

Expand Down