Skip to content

Commit

Permalink
transforms: (experimental) minor updates to stencil-tensorize-z-dim (#…
Browse files Browse the repository at this point in the history
…2741)

More efficient use of xdsl functionality.

Co-authored-by: n-io <n-io@users.noreply.github.com>
  • Loading branch information
n-io and n-io authored Jun 17, 2024
1 parent 4579ea7 commit 486c6ad
Showing 1 changed file with 25 additions and 56 deletions.
81 changes: 25 additions & 56 deletions xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from typing import TypeGuard, cast

from attr import dataclass

from xdsl.dialects.arith import Addf, Divf, FloatingPointLikeBinaryOp, Mulf, Subf
from xdsl.dialects.arith import (
Addf,
BinaryOperation,
Divf,
FloatingPointLikeBinaryOp,
Mulf,
Subf,
)
from xdsl.dialects.builtin import (
AnyFloat,
ContainerType,
Expand Down Expand Up @@ -146,11 +153,11 @@ def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):


def arithBinaryOpTensorize(
type_constructor: Callable[..., FloatingPointLikeBinaryOp],
op: FloatingPointLikeBinaryOp,
rewriter: PatternRewriter,
/,
):
type_constructor = type(op)
if is_tensor(op.result.type):
return
if is_tensor(op.lhs.type) and is_tensor(op.rhs.type):
Expand All @@ -175,28 +182,12 @@ def arithBinaryOpTensorize(
)


class ArithAddfOpTensorize(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Addf, rewriter: PatternRewriter, /):
arithBinaryOpTensorize(Addf, op, rewriter)


class ArithSubfOpTensorize(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Subf, rewriter: PatternRewriter, /):
arithBinaryOpTensorize(Subf, op, rewriter)


class ArithMulfOpTensorize(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Mulf, rewriter: PatternRewriter, /):
arithBinaryOpTensorize(Mulf, op, rewriter)


class ArithDivfOpTensorize(RewritePattern):
class ArithOpTensorize(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Divf, rewriter: PatternRewriter, /):
arithBinaryOpTensorize(Divf, op, rewriter)
def match_and_rewrite(
self, op: Addf | Subf | Mulf | Divf, rewriter: PatternRewriter, /
):
arithBinaryOpTensorize(op, rewriter)


@dataclass(frozen=True)
Expand Down Expand Up @@ -324,40 +315,24 @@ def match_and_rewrite(self, op: ExtractSliceOp, rewriter: PatternRewriter, /):


def arithBinaryOpUpdateShape(
type_constructor: Callable[..., FloatingPointLikeBinaryOp],
op: FloatingPointLikeBinaryOp,
op: BinaryOperation[Attribute],
rewriter: PatternRewriter,
/,
):
type_constructor = type(op)
if typ := get_required_result_type(op):
if needs_update_shape(op.result.type, typ):
rewriter.replace_matched_op(
type_constructor(op.lhs, op.rhs, flags=None, result_type=typ)
type_constructor(op.lhs, op.rhs, result_type=typ)
)


class ArithAddfOpUpdateShape(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Addf, rewriter: PatternRewriter, /):
arithBinaryOpUpdateShape(Addf, op, rewriter)


class ArithSubfOpUpdateShape(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Subf, rewriter: PatternRewriter, /):
arithBinaryOpUpdateShape(Subf, op, rewriter)


class ArithMulfOpUpdateShape(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Mulf, rewriter: PatternRewriter, /):
arithBinaryOpUpdateShape(Mulf, op, rewriter)


class ArithDivfOpUpdateShape(RewritePattern):
class ArithOpUpdateShape(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Divf, rewriter: PatternRewriter, /):
arithBinaryOpUpdateShape(Divf, op, rewriter)
def match_and_rewrite(
self, op: Addf | Subf | Mulf | Divf, rewriter: PatternRewriter, /
):
arithBinaryOpUpdateShape(op, rewriter)


class EmptyOpUpdateShape(RewritePattern):
Expand Down Expand Up @@ -402,10 +377,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
GreedyRewritePatternApplier(
[
AccessOpTensorize(),
ArithAddfOpTensorize(),
ArithMulfOpTensorize(),
ArithSubfOpTensorize(),
ArithDivfOpTensorize(),
ArithOpTensorize(),
]
),
walk_reverse=False,
Expand All @@ -419,10 +391,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
ExtractSliceOpUpdateShape(),
EmptyOpUpdateShape(),
FillOpUpdateShape(),
ArithAddfOpUpdateShape(),
ArithSubfOpUpdateShape(),
ArithMulfOpUpdateShape(),
ArithDivfOpUpdateShape(),
ArithOpUpdateShape(),
]
),
walk_reverse=True,
Expand Down

0 comments on commit 486c6ad

Please sign in to comment.