Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: (csl-stencil-bufferize) Inject iter_arg into linalg compute #3033

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions tests/filecheck/transforms/csl_stencil_bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ builtin.module {
%5 = csl_stencil.access %1[-1, 0] : tensor<4x255xf32>
%6 = csl_stencil.access %1[0, 1] : tensor<4x255xf32>
%7 = csl_stencil.access %1[0, -1] : tensor<4x255xf32>
%8 = arith.addf %7, %6 : tensor<255xf32>
%9 = arith.addf %8, %5 : tensor<255xf32>
%10 = arith.addf %9, %4 : tensor<255xf32>
%8 = linalg.add ins(%7, %6 : tensor<255xf32>, tensor<255xf32>) outs(%7 : tensor<255xf32>) -> tensor<255xf32>
%9 = linalg.add ins(%8, %5 : tensor<255xf32>, tensor<255xf32>) outs(%8 : tensor<255xf32>) -> tensor<255xf32>
%10 = linalg.add ins(%9, %4 : tensor<255xf32>, tensor<255xf32>) outs(%9 : tensor<255xf32>) -> tensor<255xf32>
%11 = "tensor.insert_slice"(%10, %3, %2) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
csl_stencil.yield %11 : tensor<510xf32>
}, {
Expand All @@ -20,9 +20,9 @@ builtin.module {
%15 = arith.constant dense<1.666600e-01> : tensor<510xf32>
%16 = "tensor.extract_slice"(%14) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%17 = "tensor.extract_slice"(%14) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%18 = arith.addf %13, %17 : tensor<510xf32>
%19 = arith.addf %18, %16 : tensor<510xf32>
%20 = arith.mulf %19, %15 : tensor<510xf32>
%18 = linalg.add ins(%13, %17 : tensor<510xf32>, tensor<510xf32>) outs(%13 : tensor<510xf32>) -> tensor<510xf32>
%19 = linalg.add ins(%18, %16 : tensor<510xf32>, tensor<510xf32>) outs(%18 : tensor<510xf32>) -> tensor<510xf32>
%20 = linalg.mul ins(%19, %15 : tensor<510xf32>, tensor<510xf32>) outs(%19 : tensor<510xf32>) -> tensor<510xf32>
csl_stencil.yield %20 : tensor<510xf32>
}) to <[0, 0], [1, 1]>
func.return
Expand All @@ -37,18 +37,18 @@ builtin.module {
// CHECK-NEXT: csl_stencil.apply(%a : memref<512xf32>, %1 : memref<510xf32>) outs (%b : memref<512xf32>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 1>}> ({
// CHECK-NEXT: ^0(%2 : memref<4x255xf32>, %3 : index, %4 : memref<510xf32>):
// CHECK-NEXT: %5 = bufferization.to_tensor %4 restrict writable : memref<510xf32>
// CHECK-NEXT: %6 = "tensor.extract_slice"(%5, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0>}> : (tensor<510xf32>, index) -> tensor<255xf32>
// CHECK-NEXT: %7 = csl_stencil.access %2[1, 0] : memref<4x255xf32>
// CHECK-NEXT: %8 = bufferization.to_tensor %7 restrict : memref<255xf32>
// CHECK-NEXT: %9 = csl_stencil.access %2[-1, 0] : memref<4x255xf32>
// CHECK-NEXT: %10 = bufferization.to_tensor %9 restrict : memref<255xf32>
// CHECK-NEXT: %11 = csl_stencil.access %2[0, 1] : memref<4x255xf32>
// CHECK-NEXT: %12 = bufferization.to_tensor %11 restrict : memref<255xf32>
// CHECK-NEXT: %13 = csl_stencil.access %2[0, -1] : memref<4x255xf32>
// CHECK-NEXT: %14 = bufferization.to_tensor %13 restrict : memref<255xf32>
// CHECK-NEXT: %15 = arith.addf %14, %12 : tensor<255xf32>
// CHECK-NEXT: %16 = arith.addf %15, %10 : tensor<255xf32>
// CHECK-NEXT: %17 = arith.addf %16, %8 : tensor<255xf32>
// CHECK-NEXT: %6 = csl_stencil.access %2[1, 0] : memref<4x255xf32>
// CHECK-NEXT: %7 = bufferization.to_tensor %6 restrict : memref<255xf32>
// CHECK-NEXT: %8 = csl_stencil.access %2[-1, 0] : memref<4x255xf32>
// CHECK-NEXT: %9 = bufferization.to_tensor %8 restrict : memref<255xf32>
// CHECK-NEXT: %10 = csl_stencil.access %2[0, 1] : memref<4x255xf32>
// CHECK-NEXT: %11 = bufferization.to_tensor %10 restrict : memref<255xf32>
// CHECK-NEXT: %12 = csl_stencil.access %2[0, -1] : memref<4x255xf32>
// CHECK-NEXT: %13 = bufferization.to_tensor %12 restrict : memref<255xf32>
// CHECK-NEXT: %14 = "tensor.extract_slice"(%5, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0>}> : (tensor<510xf32>, index) -> tensor<255xf32>
// CHECK-NEXT: %15 = linalg.add ins(%13, %11 : tensor<255xf32>, tensor<255xf32>) outs(%14 : tensor<255xf32>) -> tensor<255xf32>
// CHECK-NEXT: %16 = linalg.add ins(%15, %9 : tensor<255xf32>, tensor<255xf32>) outs(%15 : tensor<255xf32>) -> tensor<255xf32>
// CHECK-NEXT: %17 = linalg.add ins(%16, %7 : tensor<255xf32>, tensor<255xf32>) outs(%16 : tensor<255xf32>) -> tensor<255xf32>
// CHECK-NEXT: %18 = "tensor.insert_slice"(%17, %5, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: %19 = bufferization.to_memref %18 : memref<510xf32>
// CHECK-NEXT: csl_stencil.yield %19 : memref<510xf32>
Expand All @@ -61,9 +61,9 @@ builtin.module {
// CHECK-NEXT: %26 = bufferization.to_tensor %25 restrict : memref<510xf32>
// CHECK-NEXT: %27 = "tensor.extract_slice"(%24) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %28 = "tensor.extract_slice"(%24) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %29 = arith.addf %22, %28 : tensor<510xf32>
// CHECK-NEXT: %30 = arith.addf %29, %27 : tensor<510xf32>
// CHECK-NEXT: %31 = arith.mulf %30, %26 : tensor<510xf32>
// CHECK-NEXT: %29 = linalg.add ins(%22, %28 : tensor<510xf32>, tensor<510xf32>) outs(%22 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %30 = linalg.add ins(%29, %27 : tensor<510xf32>, tensor<510xf32>) outs(%29 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %31 = linalg.mul ins(%30, %26 : tensor<510xf32>, tensor<510xf32>) outs(%30 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %32 = bufferization.to_memref %31 : memref<510xf32>
// CHECK-NEXT: csl_stencil.yield %32 : memref<510xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
Expand Down
76 changes: 68 additions & 8 deletions xdsl/transforms/csl_stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import arith, bufferization, func, memref, stencil, tensor
from xdsl.dialects import arith, bufferization, func, linalg, memref, stencil, tensor
from xdsl.dialects.builtin import (
DenseArrayBase,
DenseIntOrFPElementsAttr,
Expand Down Expand Up @@ -114,12 +114,12 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
t := to_tensor_op(arg, writable=idx == 2),
InsertPoint.at_end(buf_apply_op.chunk_reduce.block),
)
if idx == 2:
offset_arg = buf_apply_op.chunk_reduce.block.args[1]
rewriter.insert_op(
self._build_extract_slice(op, t, offset_arg),
InsertPoint.at_end(buf_apply_op.chunk_reduce.block),
)
# if idx == 2:
# offset_arg = buf_apply_op.chunk_reduce.block.args[1]
# rewriter.insert_op(
# self._build_extract_slice(op, t, offset_arg),
# InsertPoint.at_end(buf_apply_op.chunk_reduce.block),
# )
chunk_reduce_arg_mapping.append(t.tensor)
else:
chunk_reduce_arg_mapping.append(arg)
Expand All @@ -138,6 +138,9 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
else:
post_process_arg_mapping.append(arg)

assert isa(typ := op.chunk_reduce.block.args[0].type, TensorType[Attribute])
chunk_type = TensorType(typ.get_element_type(), typ.get_shape()[1:])

# inline blocks from old into new regions
rewriter.inline_block(
op.chunk_reduce.block,
Expand All @@ -151,6 +154,10 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
post_process_arg_mapping,
)

self._inject_iter_arg_into_linalg_outs(
buf_apply_op, rewriter, chunk_type, chunk_reduce_arg_mapping[2]
)

# insert new op
rewriter.replace_matched_op(new_ops=[*to_memrefs, buf_apply_op])

Expand All @@ -170,6 +177,58 @@ def _get_empty_bufferized_region(args: Sequence[BlockArgument]) -> Region:
)
)

@staticmethod
def _inject_iter_arg_into_linalg_outs(
op: csl_stencil.ApplyOp,
rewriter: PatternRewriter,
chunk_type: TensorType[Attribute],
iter_arg: SSAValue,
):
"""
Finds a linalg op with `chunk_type` shape in `outs` and injects
an extracted slice of `iter_arg`. This is a work-around for the
way bufferization works, causing it to use `iter_arg` as an accumulator
and avoiding having an extra alloc + memref.copy.
"""
linalg_op: linalg.NamedOpBase | None = None
for curr_op in op.chunk_reduce.block.ops:
if (
isinstance(curr_op, linalg.NamedOpBase)
and len(curr_op.outputs) > 0
and curr_op.outputs.types[0] == chunk_type
):
linalg_op = curr_op
break

if linalg_op is None:
return

rewriter.replace_op(
linalg_op,
[
extract_slice_op := tensor.ExtractSliceOp(
operands=[iter_arg, [op.chunk_reduce.block.args[1]], [], []],
result_types=[chunk_type],
properties={
"static_offsets": DenseArrayBase.from_list(
i64, (memref.Subview.DYNAMIC_INDEX,)
),
"static_sizes": DenseArrayBase.from_list(
i64, chunk_type.get_shape()
),
"static_strides": DenseArrayBase.from_list(i64, (1,)),
},
),
type(linalg_op).build(
operands=[linalg_op.inputs, extract_slice_op.results],
result_types=linalg_op.result_types,
properties=linalg_op.properties,
attributes=linalg_op.attributes,
regions=[linalg_op.detach_region(r) for r in linalg_op.regions],
),
],
)

@staticmethod
def _build_extract_slice(
op: csl_stencil.ApplyOp, to_tensor: bufferization.ToTensorOp, offset: SSAValue
Expand Down Expand Up @@ -303,7 +362,8 @@ class CslStencilBufferize(ModulePass):
"""
Bufferizes the csl_stencil dialect.

Creates a `tensor.extract_slice` op needed by `lift-arith-to-linalg` and should be run without `cse` in between.
Attempts to inject `csl_stencil.apply.chunk_reduce.iter_arg` into linalg compute ops `outs` within that region
for improved bufferization. Ideally be run after `--lift-arith-to-linalg`.
"""

name = "csl-stencil-bufferize"
Expand Down
Loading