Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(minor): allow csl_stencil.access to operate on own data #2777

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/filecheck/dialects/csl/csl-stencil-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ builtin.module {
%4 = arith.constant 1.666600e-01 : f32
%5 = csl_stencil.access %3[1, 0] : memref<4xtensor<510xf32>>
%6 = csl_stencil.access %3[-1, 0] : memref<4xtensor<510xf32>>
%7 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%8 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%7 = csl_stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%8 = csl_stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%9 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%10 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%11 = csl_stencil.access %3[0, 1] : memref<4xtensor<510xf32>>
Expand Down Expand Up @@ -38,8 +38,8 @@ builtin.module {
// CHECK-NEXT: %4 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %5 = csl_stencil.access %3[1, 0] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %6 = csl_stencil.access %3[-1, 0] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %7 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %8 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %7 = csl_stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %8 = csl_stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %9 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %10 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %11 = csl_stencil.access %3[0, 1] : memref<4xtensor<510xf32>>
Expand Down Expand Up @@ -70,8 +70,8 @@ builtin.module {
// CHECK-GENERIC-NEXT: %4 = "arith.constant"() <{"value" = 1.666600e-01 : f32}> : () -> f32
// CHECK-GENERIC-NEXT: %5 = "csl_stencil.access"(%3) <{"offset" = #stencil.index[1, 0], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<510xf32>>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %6 = "csl_stencil.access"(%3) <{"offset" = #stencil.index[-1, 0], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<510xf32>>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %7 = "stencil.access"(%2) {"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
// CHECK-GENERIC-NEXT: %8 = "stencil.access"(%2) {"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
// CHECK-GENERIC-NEXT: %7 = "csl_stencil.access"(%2) <{"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
// CHECK-GENERIC-NEXT: %8 = "csl_stencil.access"(%2) <{"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
// CHECK-GENERIC-NEXT: %9 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %10 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %11 = "csl_stencil.access"(%3) <{"offset" = #stencil.index[0, 1], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<510xf32>>) -> tensor<510xf32>
Expand Down
4 changes: 2 additions & 2 deletions tests/filecheck/transforms/stencil-to-csl-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ builtin.module {
// CHECK-NEXT: %5 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %6 = csl_stencil.access %4[1, 0] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %7 = csl_stencil.access %4[-1, 0] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %8 = stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %8 = csl_stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %9 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %10 = stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %10 = csl_stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %12 = csl_stencil.access %4[0, 1] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %13 = csl_stencil.access %4[0, -1] : memref<4xtensor<510xf32>>
Expand Down
29 changes: 23 additions & 6 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,17 @@ def __init__(
@irdl_op_definition
class AccessOp(IRDLOperation):
"""
A CSL stencil access that operates on data prefetched by `csl_stencil.prefetch`
A CSL stencil access that operates on own data or data prefetched from neighbors via `csl_stencil.prefetch`

The source of data determines the type `op` is required to have:

${type(op) == memref.MemRefType} - for accesses to data prefetched from neighbors
${type(op) == stencil.Temp} - for accesses to own data
Comment on lines +131 to +132
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit concerned about the reckless mixing of value and reference semantics here. Not sure what to think of it...

Copy link
Collaborator Author

@n-io n-io Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a valid concern. I suppose it stems from the fact that accesses to neighbour data operate on prefetched data, which has effectively gone through one more round of lowering than accesses to own data.

The plan was to not touch accesses to own data and leave them as stencil.access (without lowering to csl_stencil.access). What this means in the code is that there already is a mixing of value and reference semantics, just on different ops.

The question then partly becomes, how to do this cleanly. I see three ways:

  • allow stencil.access to have parents other than stencil.apply
  • allow mixing of value semantics (own data) and reference semantics (neighbour data), as proposed here
  • separate the ops into something along the lines of csl_stencil.access_own_data and csl_stencil.access_neighbor_data - this would work, but it is not my preference because of potential confusion around the get_accesses function. The ops would basically be the same, and there may not be any additional gain from this split.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PapyChacal what would you suggest?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, if in doubt, press merge. Having slept on this for a night, I think it's a symptom of us not having quite figured out how and when bufferisation kicks in. Maybe this isn't a problem we need to solve, maybe it is.

Let's see where this PR takes us instead of sweating too much with the semantics right now.


"""

name = "csl_stencil.access"
op = operand_def(memref.MemRefType)
op = operand_def(memref.MemRefType | stencil.TempType)
offset = prop_def(stencil.IndexAttr)
offset_mapping = opt_prop_def(stencil.IndexAttr)
result = result_def(TensorType)
Expand Down Expand Up @@ -206,15 +212,26 @@ def parse(cls, parser: Parser):
props["offset_mapping"] = stencil.IndexAttr.get(*offset_mapping)
parser.parse_punctuation(":")
res_type = parser.parse_attribute()
if not isa(res_type, memref.MemRefType[Attribute]):
parser.raise_error("Expected return type to be a memref")
if not isa(
res_type, memref.MemRefType[Attribute] | stencil.TempType[Attribute]
):
parser.raise_error("Expected return type to be a memref or stencil.temp")
return cls.build(
operands=[temp], result_types=[res_type.element_type], properties=props
)

def verify_(self) -> None:
assert isa(self.op.type, memref.MemRefType[Attribute])
assert self.result.type == self.op.type.get_element_type()
if tuple(self.offset) == (0, 0):
if not isa(self.op.type, stencil.TempType[Attribute]):
raise VerifyException(
f"{type(self)} access to own data requires type stencil.TempType but found {self.op.type}"
)
assert self.result.type == self.op.type.get_element_type()
else:
if not isa(self.op.type, memref.MemRefType[Attribute]):
raise VerifyException(
f"{type(self)} access to neighbor data requires type memref.MemRefType but found {self.op.type}"
)

# As promised by HasAncestor(ApplyOp)
trait = cast(HasAncestor, AccessOp.get_trait(HasAncestor, (stencil.ApplyOp,)))
Expand Down
19 changes: 15 additions & 4 deletions xdsl/transforms/stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,22 @@ class ConvertAccessOpFromPrefetchPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.AccessOp, rewriter: PatternRewriter, /):
assert len(op.offset) == 2
if (
tuple(op.offset) == (0, 0)
or op.temp != op.get_apply().region.block.args[self.arg_index]
):
if op.temp != op.get_apply().region.block.args[self.arg_index]:
return

# translate access to own data, which operates on stencil.TempType
if tuple(op.offset) == (0, 0):
assert isa(op.res.type, stencil.TensorType)
rewriter.replace_matched_op(
csl_stencil.AccessOp(
op=op.temp,
offset=op.offset,
offset_mapping=op.offset_mapping,
result_type=op.res.type,
)
)
return

prefetched_arg = op.get_apply().region.block.args[-1]
assert isa(m_type := prefetched_arg.type, memref.MemRefType[Attribute])
assert isa(t_type := m_type.get_element_type(), TensorType[Attribute])
Expand Down
Loading