From 486c6ada68d3c7ee665a3dd89b79d07de488b259 Mon Sep 17 00:00:00 2001 From: Nicolai Stawinoga <36768051+n-io@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:53:42 +0200 Subject: [PATCH] transforms: (experimental) minor updates to stencil-tensorize-z-dim (#2741) More efficient use of xdsl functionality. Co-authored-by: n-io --- .../stencil_tensorize_z_dimension.py | 81 ++++++------------- 1 file changed, 25 insertions(+), 56 deletions(-) diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 7c364250d2..61ce08eab7 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -1,9 +1,16 @@ -from collections.abc import Callable, Sequence +from collections.abc import Sequence from typing import TypeGuard, cast from attr import dataclass -from xdsl.dialects.arith import Addf, Divf, FloatingPointLikeBinaryOp, Mulf, Subf +from xdsl.dialects.arith import ( + Addf, + BinaryOperation, + Divf, + FloatingPointLikeBinaryOp, + Mulf, + Subf, +) from xdsl.dialects.builtin import ( AnyFloat, ContainerType, @@ -146,11 +153,11 @@ def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /): def arithBinaryOpTensorize( - type_constructor: Callable[..., FloatingPointLikeBinaryOp], op: FloatingPointLikeBinaryOp, rewriter: PatternRewriter, /, ): + type_constructor = type(op) if is_tensor(op.result.type): return if is_tensor(op.lhs.type) and is_tensor(op.rhs.type): @@ -175,28 +182,12 @@ def arithBinaryOpTensorize( ) -class ArithAddfOpTensorize(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: Addf, rewriter: PatternRewriter, /): - arithBinaryOpTensorize(Addf, op, rewriter) - - -class ArithSubfOpTensorize(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: Subf, rewriter: PatternRewriter, /): - arithBinaryOpTensorize(Subf, op, rewriter) - - -class ArithMulfOpTensorize(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: Mulf, rewriter: PatternRewriter, /): - arithBinaryOpTensorize(Mulf, op, rewriter) - - -class ArithDivfOpTensorize(RewritePattern): +class ArithOpTensorize(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: Divf, rewriter: PatternRewriter, /): - arithBinaryOpTensorize(Divf, op, rewriter) + def match_and_rewrite( + self, op: Addf | Subf | Mulf | Divf, rewriter: PatternRewriter, / + ): + arithBinaryOpTensorize(op, rewriter) @dataclass(frozen=True) @@ -324,40 +315,24 @@ def match_and_rewrite(self, op: ExtractSliceOp, rewriter: PatternRewriter, /): def arithBinaryOpUpdateShape( - type_constructor: Callable[..., FloatingPointLikeBinaryOp], - op: FloatingPointLikeBinaryOp, + op: BinaryOperation[Attribute], rewriter: PatternRewriter, /, ): + type_constructor = type(op) if typ := get_required_result_type(op): if needs_update_shape(op.result.type, typ): rewriter.replace_matched_op( - type_constructor(op.lhs, op.rhs, flags=None, result_type=typ) + type_constructor(op.lhs, op.rhs, result_type=typ) ) -class ArithAddfOpUpdateShape(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: Addf, rewriter: PatternRewriter, /): - arithBinaryOpUpdateShape(Addf, op, rewriter) - - -class ArithSubfOpUpdateShape(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: Subf, rewriter: PatternRewriter, /): - arithBinaryOpUpdateShape(Subf, op, rewriter) - - -class ArithMulfOpUpdateShape(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: Mulf, rewriter: PatternRewriter, /): - arithBinaryOpUpdateShape(Mulf, op, rewriter) - - -class ArithDivfOpUpdateShape(RewritePattern): +class ArithOpUpdateShape(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: Divf, rewriter: PatternRewriter, /): - arithBinaryOpUpdateShape(Divf, op, rewriter) + def match_and_rewrite( + self, op: Addf | Subf | Mulf | Divf, rewriter: PatternRewriter, / + ): + arithBinaryOpUpdateShape(op, rewriter) class EmptyOpUpdateShape(RewritePattern): @@ -402,10 +377,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: GreedyRewritePatternApplier( [ AccessOpTensorize(), - ArithAddfOpTensorize(), - ArithMulfOpTensorize(), - ArithSubfOpTensorize(), - ArithDivfOpTensorize(), + ArithOpTensorize(), ] ), walk_reverse=False, @@ -419,10 +391,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: ExtractSliceOpUpdateShape(), EmptyOpUpdateShape(), FillOpUpdateShape(), - ArithAddfOpUpdateShape(), - ArithSubfOpUpdateShape(), - ArithMulfOpUpdateShape(), - ArithDivfOpUpdateShape(), + ArithOpUpdateShape(), ] ), walk_reverse=True,