From 06293fee2a0630ffac8b5f1f6c6ee82786f37028 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Wed, 21 Aug 2024 17:47:52 +0100 Subject: [PATCH] Implement return type conversion. --- devito/ir/xdsl_iet/cluster_to_ssa.py | 37 +++++++++++++-------- tests/test_xdsl_base.py | 49 +++++++++++++--------------- 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/devito/ir/xdsl_iet/cluster_to_ssa.py b/devito/ir/xdsl_iet/cluster_to_ssa.py index 37713b5322..66a9741c35 100644 --- a/devito/ir/xdsl_iet/cluster_to_ssa.py +++ b/devito/ir/xdsl_iet/cluster_to_ssa.py @@ -287,7 +287,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args) return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs) - + # Trigonometric functions elif isinstance(node, sin): assert len(node.args) == 1, "Expected single argument for sin." @@ -298,13 +298,13 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, assert len(node.args) == 1, "Expected single argument for cos." return math.CosOp(self._visit_math_nodes(dim, node.args[0], output_indexed)).result - + elif isinstance(node, tan): assert len(node.args) == 1, "Expected single argument for TanOp." - + return math.TanOp(self._visit_math_nodes(dim, node.args[0], output_indexed)).result - + elif isinstance(node, Relational): if isinstance(node, GreaterThan): mnemonic = "sge" @@ -382,7 +382,20 @@ def build_stencil_step(self, dim: SteppingDimension, eq: LoweredEq) -> None: self.function_values |= self.apply_temps with ImplicitBuilder(apply.region.block): - stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, eq.lhs)]) + result = self._visit_math_nodes(dim, eq.rhs, eq.lhs) + expected_type = apply.res[0].type.get_element_type() + match expected_type: + case result.type: + pass + case builtin.f32: + if result.type == IndexType(): + result = arith.IndexCastOp(result, builtin.i64).result + result = arith.SIToFPOp(result, builtin.f32).result + case builtin.IndexType: + result = arith.IndexCastOp(result, IndexType()).result + case _: + raise Exception(f"Unexpected result type {type(result)}") + stencil.ReturnOp.get([result]) lb = stencil.IndexAttr.get(*([0] * len(shape))) ub = stencil.IndexAttr.get(*shape) @@ -439,7 +452,6 @@ def build_condition(self, dim: SteppingDimension, eq: BooleanFunction): self.build_generic_step_expression(dim, eq) scf.Yield() - def build_time_loop( self, eqs: list[Any], step_dim: SteppingDimension, **kwargs ): @@ -450,7 +462,7 @@ def build_time_loop( ub = iet_ssa.LoadSymbolic.get( step_dim.symbolic_max._C_name, IndexType() ) - + one = arith.Constant.from_int_and_width(1, IndexType()) # Devito iterates from time_m to time_M *inclusive*, MLIR only takes @@ -497,7 +509,7 @@ def build_time_loop( for i, (f, t) in enumerate(self.time_buffers) } self.function_values |= self.block_args - + # Name the block argument for debugging for (f, t), arg in self.block_args.items(): arg.name_hint = f"{f.name}_t{t}" @@ -513,8 +525,7 @@ def build_time_loop( def lower_devito_Eqs(self, eqs: list[Any], **kwargs): # Lower devito Equations to xDSL - - + for eq in eqs: lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs) if isinstance(eq, Eq): @@ -546,7 +557,7 @@ def _lower_injection(self, eqs: list[LoweredEq]): lb = arith.Constant.from_int_and_width(int(lower), IndexType()) else: raise NotImplementedError(f"Lower bound of type {type(lower)} not supported") - + try: name = interval.dim.symbolic_min.name except: @@ -633,7 +644,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp: # Instantiate the module. self.function_values: dict[tuple[Function, int], SSAValue] = {} self.symbol_values: dict[str, SSAValue] = {} - + module = ModuleOp(Region([block := Block([])])) with ImplicitBuilder(block): # Get all functions used in the equations @@ -647,7 +658,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp: functions.add(f.function) elif isinstance(eq, Injection): - + functions.add(eq.field.function) for f in retrieve_functions(eq.expr): if isinstance(f, PointSource): diff --git a/tests/test_xdsl_base.py b/tests/test_xdsl_base.py index 39aa97828b..23eeaf97ce 100644 --- a/tests/test_xdsl_base.py +++ b/tests/test_xdsl_base.py @@ -972,6 +972,20 @@ def test_function_IV(): assert np.isclose(norm(u), devito_norm_u) +def test_function_V(): + grid = Grid(shape=(5, 5)) + x, y = grid.dimensions + + f = Function(name="f", grid=grid) + + eqns = [Eq(f, 2)] + + op = Operator(eqns, opt="xdsl") + op.apply() + + assert np.all(f.data == 2) + + class TestTrigonometric(object): @pytest.mark.parametrize('deg, exp', ([90.0, 3.5759869], [30.0, 3.9521265], @@ -1028,37 +1042,20 @@ def test_tan(self, deg, exp): assert np.isclose(norm(u), exp, rtol=1e-4) -class TestOperatorUnsupported(object): +def test_forward_assignment(): + # simple forward assignment - @pytest.mark.xfail(reason="stencil.return operation does not verify for i64") - def test_forward_assignment(self): - # simple forward assignment - - grid = Grid(shape=(4, 4)) - u = TimeFunction(name="u", grid=grid, space_order=2) - u.data[:, :, :] = 0 - - eq0 = Eq(u.forward, 1) - - op = Operator([eq0], opt='xdsl') - - op.apply(time_M=1) - - assert np.isclose(norm(u), 5.6584, rtol=0.001) - - @pytest.mark.xfail(reason="stencil.return operation does not verify for i64") - def test_function(self): - grid = Grid(shape=(5, 5)) - x, y = grid.dimensions + grid = Grid(shape=(4, 4)) + u = TimeFunction(name="u", grid=grid, space_order=2) + u.data[:, :, :] = 0 - f = Function(name="f", grid=grid) + eq0 = Eq(u.forward, 1) - eqns = [Eq(f, 2)] + op = Operator([eq0], opt='xdsl') - op = Operator(eqns, opt='xdsl') - op.apply() + op.apply(time_M=1) - assert np.all(f.data == 4) + assert np.isclose(norm(u), 5.6584, rtol=0.001) class TestElastic():