Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Implement shape-inference pass #3047

Merged
merged 12 commits into from
Aug 17, 2024
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: XDSL_ROUNDTRIP
// RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference | filecheck %s --check-prefix SHAPE
// RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference,convert-stencil-to-ll-mlir | filecheck %s --check-prefix MLIR
// RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference,stencil-bufferize | filecheck %s --check-prefix BUFF
// RUN: xdsl-opt %s -p stencil-storage-materialization,shape-inference | filecheck %s --check-prefix SHAPE
// RUN: xdsl-opt %s -p stencil-storage-materialization,shape-inference,convert-stencil-to-ll-mlir | filecheck %s --check-prefix MLIR
// RUN: xdsl-opt %s -p stencil-storage-materialization,shape-inference,stencil-bufferize | filecheck %s --check-prefix BUFF

func.func @fvtp2d_qi(%arg0: !stencil.field<?x?x?xf64>, %arg1: !stencil.field<?x?x?xf64>, %arg2: !stencil.field<?x?x?xf64>, %arg3: !stencil.field<?x?x?xf64>, %arg4: !stencil.field<?x?x?xf64>, %arg5: !stencil.field<?x?x?xf64>, %arg6: !stencil.field<?x?x?xf64>) attributes {stencil.program} {
%0 = stencil.cast %arg0 : !stencil.field<?x?x?xf64> -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/transforms/stencil-shape-inference.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: xdsl-opt -p stencil-shape-inference --verify-diagnostics --split-input-file %s | filecheck %s
// RUN: xdsl-opt -p shape-inference --verify-diagnostics --split-input-file %s | filecheck %s

builtin.module {
func.func @different_input_offsets(%out : !stencil.field<[-4,68]xf64>, %left : !stencil.field<[-4,68]xf64>, %right : !stencil.field<[-4,68]xf64>) {
Expand Down
2 changes: 1 addition & 1 deletion tests/xdsl_opt/test_xdsl_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp):

def test_print_between_passes():
filename_in = "tests/xdsl_opt/empty_program.mlir"
passes = ["stencil-shape-inference", "dce", "frontend-desymrefy"]
passes = ["shape-inference", "dce", "frontend-desymrefy"]
flags = ["--print-between-passes", "-p", ",".join(passes)]

f = StringIO("")
Expand Down
102 changes: 95 additions & 7 deletions xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
HasAncestor,
HasCanonicalizationPatternsTrait,
HasParent,
HasShapeInferencePatternsTrait,
IsolatedFromAbove,
IsTerminator,
MemoryEffect,
Expand Down Expand Up @@ -420,6 +421,16 @@ def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
)


class ApplyOpHasShapeInferencePatternsTrait(HasShapeInferencePatternsTrait):
@classmethod
def get_shape_inference_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.shape_inference_patterns.stencil import (
ApplyOpShapeInference,
)

return (ApplyOpShapeInference(),)


class ApplyMemoryEffect(RecursiveMemoryEffect):
@classmethod
def get_effects(cls, op: Operation):
Expand Down Expand Up @@ -464,6 +475,7 @@ class ApplyOp(IRDLOperation):
[
IsolatedFromAbove(),
ApplyOpHasCanonicalizationPatternsTrait(),
ApplyOpHasShapeInferencePatternsTrait(),
ApplyMemoryEffect(),
]
)
Expand Down Expand Up @@ -752,6 +764,16 @@ def verify_(self) -> None:
)


class CombineOpHasShapeInferencePatternsTrait(HasShapeInferencePatternsTrait):
@classmethod
def get_shape_inference_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.shape_inference_patterns.stencil import (
CombineOpShapeInference,
)

return (CombineOpShapeInference(),)


@irdl_op_definition
class CombineOp(IRDLOperation):
"""
Expand Down Expand Up @@ -790,11 +812,29 @@ class CombineOp(IRDLOperation):
upperext = var_operand_def(TempType)
results_ = var_result_def(TempType)

traits = frozenset([Pure()])
traits = frozenset(
[
Pure(),
CombineOpHasShapeInferencePatternsTrait(),
]
)

assembly_format = "$dim `at` $index `lower` `=` `(` $lower `:` type($lower) `)` `upper` `=` `(` $upper `:` type($upper) `)` (`lowerext` `=` $lowerext^ `:` type($lowerext))? (`upperext` `=` $upperext^ `:` type($upperext))? attr-dict-with-keyword `:` type($results_)"

irdl_options = [AttrSizedOperandSegments(), Pure()]
irdl_options = [
AttrSizedOperandSegments(),
Pure(),
]


class DynAccessOpHasShapeInferencePatternsTrait(HasShapeInferencePatternsTrait):
@classmethod
def get_shape_inference_patterns(cls):
from xdsl.transforms.shape_inference_patterns.stencil import (
DynAccessOpShapeInference,
)

return (DynAccessOpShapeInference(),)


@irdl_op_definition
Expand Down Expand Up @@ -840,7 +880,13 @@ class DynAccessOp(IRDLOperation):
"$temp `[` $offset `]` `in` $lb `:` $ub attr-dict-with-keyword `:` type($temp)"
)

traits = frozenset([HasAncestor(ApplyOp), NoMemoryEffect()])
traits = frozenset(
[
HasAncestor(ApplyOp),
NoMemoryEffect(),
DynAccessOpHasShapeInferencePatternsTrait(),
]
)

def __init__(
self,
Expand Down Expand Up @@ -938,6 +984,16 @@ def get_apply(self):
return cast(ApplyOp, ancestor)


class AccessOpHasShapeInferencePatternsTrait(HasShapeInferencePatternsTrait):
@classmethod
def get_shape_inference_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.shape_inference_patterns.stencil import (
AccessOpShapeInference,
)

return (AccessOpShapeInference(),)


@irdl_op_definition
class AccessOp(IRDLOperation):
"""
Expand Down Expand Up @@ -976,7 +1032,9 @@ class AccessOp(IRDLOperation):
)
)

traits = frozenset([HasAncestor(ApplyOp), Pure()])
traits = frozenset(
[HasAncestor(ApplyOp), Pure(), AccessOpHasShapeInferencePatternsTrait()]
)

def print(self, printer: Printer):
printer.print(" ")
Expand Down Expand Up @@ -1143,6 +1201,16 @@ def get_apply(self):
return cast(ApplyOp, ancestor)


class LoadOpHasShapeInferencePatternsTrait(HasShapeInferencePatternsTrait):
@classmethod
def get_shape_inference_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.shape_inference_patterns.stencil import (
LoadOpShapeInference,
)

return (LoadOpShapeInference(),)


class LoadOpMemoryEffect(MemoryEffect):
@classmethod
def get_effects(cls, op: Operation):
Expand Down Expand Up @@ -1187,7 +1255,7 @@ class LoadOp(IRDLOperation):

assembly_format = "$field attr-dict-with-keyword `:` type($field) `->` type($res)"

traits = frozenset([LoadOpMemoryEffect()])
traits = frozenset([LoadOpHasShapeInferencePatternsTrait(), LoadOpMemoryEffect()])

@staticmethod
def get(
Expand Down Expand Up @@ -1225,6 +1293,16 @@ def verify_(self) -> None:
)


class BufferOpHasShapeInferencePatternsTrait(HasShapeInferencePatternsTrait):
@classmethod
def get_shape_inference_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.shape_inference_patterns.stencil import (
BufferOpShapeInference,
)

return (BufferOpShapeInference(),)


@irdl_op_definition
class BufferOp(IRDLOperation):
"""
Expand Down Expand Up @@ -1271,7 +1349,7 @@ class BufferOp(IRDLOperation):

assembly_format = "$temp attr-dict-with-keyword `:` type($temp) `->` type($res)"

traits = frozenset([Pure()])
traits = frozenset([Pure(), BufferOpHasShapeInferencePatternsTrait()])

def __init__(self, temp: SSAValue | Operation):
temp = SSAValue.get(temp)
Expand Down Expand Up @@ -1311,6 +1389,16 @@ def verify(
super().verify(attr, constraint_context)


class StoreOpHasShapeInferencePatternsTrait(HasShapeInferencePatternsTrait):
@classmethod
def get_shape_inference_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.shape_inference_patterns.stencil import (
StoreOpShapeInference,
)

return (StoreOpShapeInference(),)


class StoreOpMemoryEffect(MemoryEffect):
@classmethod
def get_effects(cls, op: Operation):
Expand Down Expand Up @@ -1369,7 +1457,7 @@ class StoreOp(IRDLOperation):

assembly_format = "$temp `to` $field `` `(` $bounds `)` attr-dict-with-keyword `:` type($temp) `to` type($field)"

traits = frozenset([StoreOpMemoryEffect()])
traits = frozenset([StoreOpHasShapeInferencePatternsTrait(), StoreOpMemoryEffect()])

@staticmethod
def get(
Expand Down
8 changes: 4 additions & 4 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,10 @@ def get_lower_scf_for_to_labels():

return riscv_scf_to_asm.LowerScfForToLabels

def get_stencil_shape_inference():
from xdsl.transforms.experimental import stencil_shape_inference
def get_shape_inference():
from xdsl.transforms.shape_inference import ShapeInferencePass

return stencil_shape_inference.StencilShapeInferencePass
return ShapeInferencePass

def get_stencil_storage_materialization():
from xdsl.transforms.experimental import stencil_storage_materialization
Expand Down Expand Up @@ -440,7 +440,7 @@ def get_stencil_bufferize():
"scf-parallel-loop-tiling": get_scf_parallel_loop_tiling,
"snitch-allocate-registers": get_snitch_register_allocation,
"riscv-prologue-epilogue-insertion": get_riscv_prologue_epilogue_insertion,
"stencil-shape-inference": get_stencil_shape_inference,
"shape-inference": get_shape_inference,
"stencil-storage-materialization": get_stencil_storage_materialization,
"stencil-tensorize-z-dimension": get_stencil_tensorize_z_dimension,
"stencil-to-csl-stencil": get_stencil_to_csl_stencil,
Expand Down
17 changes: 17 additions & 0 deletions xdsl/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,23 @@ def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
raise NotImplementedError()


@dataclass(frozen=True)
class HasShapeInferencePatternsTrait(OpTrait):
"""
Provides the rewrite passes to shape infer an operation.

Each rewrite pattern must have the trait's op as root.
"""

def verify(self, op: Operation) -> None:
return

@classmethod
@abc.abstractmethod
def get_shape_inference_patterns(cls) -> tuple[RewritePattern, ...]:
raise NotImplementedError()


class MemoryEffectKind(Enum):
"""
The kind of side effect an operation can have.
Expand Down
6 changes: 2 additions & 4 deletions xdsl/transforms/experimental/dmp/stencil_global_to_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
GridSlice2d,
GridSlice3d,
)
from xdsl.transforms.experimental.stencil_shape_inference import (
StencilShapeInferencePass,
)
from xdsl.transforms.shape_inference import ShapeInferencePass
from xdsl.utils.hints import isa

_T = TypeVar("_T", bound=Attribute)
Expand Down Expand Up @@ -684,7 +682,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
).rewrite_module(op)

# run the shape inference pass
StencilShapeInferencePass().apply(ctx, op)
ShapeInferencePass().apply(ctx, op)

DmpSwapShapeInference(strategy).apply(op)

Expand Down
36 changes: 36 additions & 0 deletions xdsl/transforms/shape_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.ir import Operation
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
)
from xdsl.traits import HasShapeInferencePatternsTrait


class ShapeInferenceRewritePattern(RewritePattern):
"""Rewrite pattern that applies a shape inference pattern."""

def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
trait = op.get_trait(HasShapeInferencePatternsTrait)
if trait is None:
return
patterns = trait.get_shape_inference_patterns()
if len(patterns) == 1:
patterns[0].match_and_rewrite(op, rewriter)
return
GreedyRewritePatternApplier(list(patterns)).match_and_rewrite(op, rewriter)


class ShapeInferencePass(ModulePass):
"""
Applies all shape inference patterns.
"""

name = "shape-inference"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(ShapeInferenceRewritePattern()).rewrite_module(op)
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import reduce
from typing import TypeVar, cast

from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.dialects.stencil import (
AccessOp,
ApplyOp,
Expand All @@ -18,11 +16,8 @@
TempType,
)
from xdsl.ir import Attribute, Block, Operation, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
Expand Down Expand Up @@ -268,24 +263,3 @@ def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter):
return
op.temp.type = op.res.type
update_result_size(op.temp, res_bounds, rewriter)


ShapeInference = GreedyRewritePatternApplier(
[
AccessOpShapeInference(),
ApplyOpShapeInference(),
BufferOpShapeInference(),
CombineOpShapeInference(),
DynAccessOpShapeInference(),
LoadOpShapeInference(),
StoreOpShapeInference(),
]
)


class StencilShapeInferencePass(ModulePass):
name = "stencil-shape-inference"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
inference_walker = PatternRewriteWalker(ShapeInference)
inference_walker.rewrite_module(op)
Loading