Skip to content

Commit

Permalink
transformations: update stencil lowering to properly accomodate buffe…
Browse files Browse the repository at this point in the history
…rized stencils. (#3016)

- Lower `stencil.alloc`
- Include subviews for reference-semantics argument of `stencil.apply`
(Those subviews are `load`/`store`'s jobs on value-semantics, so just
golfing for minimal changes while keeping one pass, for now!)

---------

Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
  • Loading branch information
PapyChacal and superlopuh authored Aug 19, 2024
1 parent f0684ae commit 3679472
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 13 deletions.
24 changes: 13 additions & 11 deletions tests/filecheck/transforms/convert-stencil-to-ll-mlir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -821,20 +821,22 @@ func.func @buffered_combine(%115 : !stencil.field<?x?xf64>) {
func.return
}

// CHECK: func.func @stencil_copy_bufferized(%0 : memref<72x72x72xf64>, %1 : memref<72x72x72xf64>) {
// CHECK-NEXT: %2 = arith.constant 0 : index
// CHECK-NEXT: %3 = arith.constant 0 : index
// CHECK: func.func @stencil_copy_bufferized(%0 : memref<72x72x72xf64>, %1 : memref<72x72x72xf64>) {
// CHECK-NEXT: %2 = memref.subview %1[4, 4, 4] [72, 72, 72] [1, 1, 1] : memref<72x72x72xf64> to memref<72x72x72xf64, strided<[5184, 72, 1], offset: 21028>>
// CHECK-NEXT: %3 = memref.subview %0[4, 4, 4] [72, 72, 72] [1, 1, 1] : memref<72x72x72xf64> to memref<72x72x72xf64, strided<[5184, 72, 1], offset: 21028>>
// CHECK-NEXT: %4 = arith.constant 0 : index
// CHECK-NEXT: %5 = arith.constant 1 : index
// CHECK-NEXT: %6 = arith.constant 1 : index
// CHECK-NEXT: %5 = arith.constant 0 : index
// CHECK-NEXT: %6 = arith.constant 0 : index
// CHECK-NEXT: %7 = arith.constant 1 : index
// CHECK-NEXT: %8 = arith.constant 64 : index
// CHECK-NEXT: %9 = arith.constant 64 : index
// CHECK-NEXT: %8 = arith.constant 1 : index
// CHECK-NEXT: %9 = arith.constant 1 : index
// CHECK-NEXT: %10 = arith.constant 64 : index
// CHECK-NEXT: "scf.parallel"(%2, %3, %4, %8, %9, %10, %5, %6, %7) <{"operandSegmentSizes" = array<i32: 3, 3, 3, 0>}> ({
// CHECK-NEXT: ^0(%11 : index, %12 : index, %13 : index):
// CHECK-NEXT: %14 = memref.load %0[%11, %12, %13] : memref<72x72x72xf64>
// CHECK-NEXT: memref.store %14, %1[%11, %12, %13] : memref<72x72x72xf64>
// CHECK-NEXT: %11 = arith.constant 64 : index
// CHECK-NEXT: %12 = arith.constant 64 : index
// CHECK-NEXT: "scf.parallel"(%4, %5, %6, %10, %11, %12, %7, %8, %9) <{"operandSegmentSizes" = array<i32: 3, 3, 3, 0>}> ({
// CHECK-NEXT: ^0(%13 : index, %14 : index, %15 : index):
// CHECK-NEXT: %16 = memref.load %3[%13, %14, %15] : memref<72x72x72xf64, strided<[5184, 72, 1], offset: 21028>>
// CHECK-NEXT: memref.store %16, %2[%13, %14, %15] : memref<72x72x72xf64, strided<[5184, 72, 1], offset: 21028>>
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }) : (index, index, index, index, index, index, index, index, index) -> ()
// CHECK-NEXT: func.return
Expand Down
55 changes: 53 additions & 2 deletions xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from itertools import product
from math import prod
from typing import TypeVar
from typing import TypeVar, cast
from warnings import warn

from xdsl.context import MLContext
Expand All @@ -12,6 +12,7 @@
)
from xdsl.dialects.stencil import (
AccessOp,
AllocOp,
ApplyOp,
BufferOp,
CastOp,
Expand Down Expand Up @@ -151,6 +152,10 @@ def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /):
target = self.return_target[apply][j]
else:
target = apply.dest[j]
rewriter.insert_op(
subview := field_subview(target), InsertPoint.before(apply)
)
target = subview

unroll = op.unroll
if unroll is None:
Expand Down Expand Up @@ -314,6 +319,50 @@ def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op([], [alloc.memref])


def field_subview(field: SSAValue):
assert isa(field_type := field.type, FieldType[Attribute])
assert isinstance(bounds := field_type.bounds, StencilBoundsAttr)
offsets = [i for i in -bounds.lb]
sizes = [i for i in field_type.get_shape()]
strides = [1] * len(sizes)

return memref.Subview.from_static_parameters(
field, StencilToMemRefType(field_type), offsets, sizes, strides
)


class AllocOpToMemref(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: AllocOp, rewriter: PatternRewriter, /):
alloc = memref.Alloc(
[], [], StencilToMemRefType(cast(StencilType[Attribute], op.field.type))
)
rewriter.replace_matched_op(alloc)


@dataclass
class ApplyOpFieldSubviews(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
args = [
field_subview(arg) if isinstance(arg.type, FieldType) else arg
for arg in op.args
]
if args == list(op.args):
return

new_apply = ApplyOp.create(
operands=[SSAValue.get(arg) for arg in args] + list(op.dest),
result_types=[r.type for r in op.res],
regions=[op.detach_region(0)],
attributes=op.attributes,
properties=op.properties,
)
rewriter.replace_matched_op(
[*(arg for arg in args if isinstance(arg, Operation)), new_apply]
)


@dataclass
class ApplyOpToParallel(RewritePattern):
return_targets: dict[ApplyOp, list[SSAValue | None]]
Expand Down Expand Up @@ -392,7 +441,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
rewriter.replace_op(index, ops, res)

# Get the maybe updated results
new_results = self.return_targets[op]
new_results = self.return_targets[op] if op in self.return_targets else []
# Replace with the loop and necessary constants.
assert isa(boilerplate_ops, list[Operation])
rewriter.insert_op_before_matched_op([*boilerplate_ops, p])
Expand Down Expand Up @@ -607,6 +656,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
the_one_pass = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
ApplyOpFieldSubviews(),
ApplyOpToParallel(return_targets),
BufferOpToMemref(return_targets),
StencilStoreToSubview(return_targets),
Expand All @@ -616,6 +666,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
ReturnOpToMemref(return_targets),
TrivialExternalLoadOpCleanup(),
TrivialExternalStoreOpCleanup(),
AllocOpToMemref(),
]
),
apply_recursively=True,
Expand Down

0 comments on commit 3679472

Please sign in to comment.