transformations: (csl-stencil-bufferize) Fold csl_stencil.access that have no effect (#3084)

The `csl-stencil-bufferize` pass lowers `stencil.field` and `tensor`
types to `memref`.

This type conversion can leave some `csl_stencil.access` ops with equal
input and output types, for instance
`(memref<512xf32>) -> memref<512xf32>`. This only happens for ops
accessing their own data. In that case, the access op has no effect and
can safely be folded away.
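
Concretely (taking the updated filecheck test below, where `%20` is the block argument holding the op's own buffer), the pattern folds

    %23 = csl_stencil.access %20[0, 0] : memref<512xf32>
    %24 = bufferization.to_tensor %23 restrict : memref<512xf32>

into a single conversion that reads the buffer directly:

    %23 = bufferization.to_tensor %20 restrict : memref<512xf32>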

---------

Co-authored-by: n-io <n-io@users.noreply.github.com>
n-io and n-io authored Aug 27, 2024
1 parent 03d5322 commit 1e488d7
Showing 3 changed files with 27 additions and 15 deletions.
21 changes: 10 additions & 11 deletions tests/filecheck/transforms/csl_stencil_bufferize.mlir
@@ -55,17 +55,16 @@ builtin.module {
 // CHECK-NEXT: }, {
 // CHECK-NEXT: ^1(%20 : memref<512xf32>, %21 : memref<510xf32>):
 // CHECK-NEXT: %22 = bufferization.to_tensor %21 restrict writable : memref<510xf32>
-// CHECK-NEXT: %23 = csl_stencil.access %20[0, 0] : memref<512xf32>
-// CHECK-NEXT: %24 = bufferization.to_tensor %23 restrict : memref<512xf32>
-// CHECK-NEXT: %25 = arith.constant dense<1.666600e-01> : memref<510xf32>
-// CHECK-NEXT: %26 = bufferization.to_tensor %25 restrict : memref<510xf32>
-// CHECK-NEXT: %27 = "tensor.extract_slice"(%24) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
-// CHECK-NEXT: %28 = "tensor.extract_slice"(%24) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
-// CHECK-NEXT: %29 = linalg.add ins(%22, %28 : tensor<510xf32>, tensor<510xf32>) outs(%22 : tensor<510xf32>) -> tensor<510xf32>
-// CHECK-NEXT: %30 = linalg.add ins(%29, %27 : tensor<510xf32>, tensor<510xf32>) outs(%29 : tensor<510xf32>) -> tensor<510xf32>
-// CHECK-NEXT: %31 = linalg.mul ins(%30, %26 : tensor<510xf32>, tensor<510xf32>) outs(%30 : tensor<510xf32>) -> tensor<510xf32>
-// CHECK-NEXT: %32 = bufferization.to_memref %31 : memref<510xf32>
-// CHECK-NEXT: csl_stencil.yield %32 : memref<510xf32>
+// CHECK-NEXT: %23 = bufferization.to_tensor %20 restrict : memref<512xf32>
+// CHECK-NEXT: %24 = arith.constant dense<1.666600e-01> : memref<510xf32>
+// CHECK-NEXT: %25 = bufferization.to_tensor %24 restrict : memref<510xf32>
+// CHECK-NEXT: %26 = "tensor.extract_slice"(%23) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
+// CHECK-NEXT: %27 = "tensor.extract_slice"(%23) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
+// CHECK-NEXT: %28 = linalg.add ins(%22, %27 : tensor<510xf32>, tensor<510xf32>) outs(%22 : tensor<510xf32>) -> tensor<510xf32>
+// CHECK-NEXT: %29 = linalg.add ins(%28, %26 : tensor<510xf32>, tensor<510xf32>) outs(%28 : tensor<510xf32>) -> tensor<510xf32>
+// CHECK-NEXT: %30 = linalg.mul ins(%29, %25 : tensor<510xf32>, tensor<510xf32>) outs(%29 : tensor<510xf32>) -> tensor<510xf32>
+// CHECK-NEXT: %31 = bufferization.to_memref %30 : memref<510xf32>
+// CHECK-NEXT: csl_stencil.yield %31 : memref<510xf32>
 // CHECK-NEXT: }) to <[0, 0], [1, 1]>
 // CHECK-NEXT: func.return
 // CHECK-NEXT: }
4 changes: 2 additions & 2 deletions xdsl/dialects/csl/csl_stencil.py
@@ -513,15 +513,15 @@ def parse(cls, parser: Parser):
             return cls.build(
                 operands=[temp],
                 result_types=[
-                    TensorType(res_type.element_type, res_type.get_shape()[1:])
+                    TensorType(res_type.element_type, res_type.get_shape()[-1:])
                 ],
                 properties=props,
             )
         elif isattr(res_type, base(AnyMemRefType)):
             return cls.build(
                 operands=[temp],
                 result_types=[
-                    memref.MemRefType(res_type.element_type, res_type.get_shape()[1:])
+                    memref.MemRefType(res_type.element_type, res_type.get_shape()[-1:])
                 ],
                 properties=props,
             )
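
The `[1:]` → `[-1:]` change matters for buffers that are already rank-1: dropping the leading dimension leaves an empty shape, while taking the trailing dimension preserves it. A minimal sketch of the Python slicing semantics involved (the shape values here are illustrative, not taken from the test suite):

    # get_shape() returns a tuple of dimension sizes
    shape_2d = (4, 512)
    shape_1d = (512,)

    assert shape_2d[1:] == (512,)   # fine for rank >= 2
    assert shape_1d[1:] == ()       # would build a rank-0 result type
    assert shape_2d[-1:] == (512,)  # trailing dimension kept
    assert shape_1d[-1:] == (512,)  # rank-1 input handled correctly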
17 changes: 15 additions & 2 deletions xdsl/transforms/csl_stencil_bufferize.py
@@ -249,18 +249,31 @@ def _build_extract_slice(

 @dataclass(frozen=True)
 class AccessOpBufferize(RewritePattern):
-    """Bufferizes AccessOp."""
+    """
+    Bufferizes AccessOp.
+    The type conversion can leave some `csl_stencil.access` ops with equal input and output types, for instance
+    `(memref<512xf32>) -> memref<512xf32>`. This only happens for ops accessing their own data. In that case,
+    the access op has no effect and can safely be folded away.
+    """
 
     @op_type_rewrite_pattern
     def match_and_rewrite(self, op: csl_stencil.AccessOp, rewriter: PatternRewriter, /):
         if not isa(op.result.type, TensorType[Attribute]):
             return
+        r_type = tensor_to_memref_type(op.result.type)
+
+        # accesses to own data that (after bufferization) have the same input and output type can be safely folded away
+        if op.op.type == r_type and all(o == 0 for o in op.offset):
+            rewriter.replace_matched_op(to_tensor_op(op.op))
+            return
+
         rewriter.replace_matched_op(
             [
                 access := csl_stencil.AccessOp(
                     op.op,
                     op.offset,
-                    tensor_to_memref_type(op.result.type),
+                    r_type,
                     op.offset_mapping,
                 ),
                 to_tensor_op(access.result),
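
The new early-return branch emits only the `to_tensor_op` conversion for zero-offset accesses whose operand already has the converted result type; all other accesses are rebuilt with the memref result type as before. To exercise the fold in isolation, the pass should be runnable via xdsl-opt's pass pipeline flag (assuming the standard invocation; `input.mlir` is a placeholder):

    xdsl-opt -p csl-stencil-bufferize input.mlir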
