Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stencil.return operand type promotion #119

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions devito/ir/xdsl_iet/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
SSAargs = (self._visit_math_nodes(dim, arg, output_indexed)
for arg in node.args)
return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs)

# Trigonometric functions
elif isinstance(node, sin):
assert len(node.args) == 1, "Expected single argument for sin."
Expand All @@ -298,13 +298,13 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
assert len(node.args) == 1, "Expected single argument for cos."
return math.CosOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, tan):
assert len(node.args) == 1, "Expected single argument for TanOp."

return math.TanOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, Relational):
if isinstance(node, GreaterThan):
mnemonic = "sge"
Expand Down Expand Up @@ -382,7 +382,20 @@ def build_stencil_step(self, dim: SteppingDimension, eq: LoweredEq) -> None:
self.function_values |= self.apply_temps

with ImplicitBuilder(apply.region.block):
stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, eq.lhs)])
result = self._visit_math_nodes(dim, eq.rhs, eq.lhs)
expected_type = apply.res[0].type.get_element_type()
match expected_type:
case result.type:
pass
case builtin.f32:
if result.type == IndexType():
result = arith.IndexCastOp(result, builtin.i64).result
result = arith.SIToFPOp(result, builtin.f32).result
case builtin.IndexType:
result = arith.IndexCastOp(result, IndexType()).result
case _:
raise Exception(f"Unexpected result type {type(result)}")
stencil.ReturnOp.get([result])

lb = stencil.IndexAttr.get(*([0] * len(shape)))
ub = stencil.IndexAttr.get(*shape)
Expand Down Expand Up @@ -439,7 +452,6 @@ def build_condition(self, dim: SteppingDimension, eq: BooleanFunction):
self.build_generic_step_expression(dim, eq)
scf.Yield()


def build_time_loop(
self, eqs: list[Any], step_dim: SteppingDimension, **kwargs
):
Expand All @@ -450,7 +462,7 @@ def build_time_loop(
ub = iet_ssa.LoadSymbolic.get(
step_dim.symbolic_max._C_name, IndexType()
)

one = arith.Constant.from_int_and_width(1, IndexType())

# Devito iterates from time_m to time_M *inclusive*, MLIR only takes
Expand Down Expand Up @@ -497,7 +509,7 @@ def build_time_loop(
for i, (f, t) in enumerate(self.time_buffers)
}
self.function_values |= self.block_args

# Name the block argument for debugging
for (f, t), arg in self.block_args.items():
arg.name_hint = f"{f.name}_t{t}"
Expand All @@ -513,8 +525,7 @@ def build_time_loop(

def lower_devito_Eqs(self, eqs: list[Any], **kwargs):
# Lower devito Equations to xDSL



for eq in eqs:
lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs)
if isinstance(eq, Eq):
Expand Down Expand Up @@ -546,7 +557,7 @@ def _lower_injection(self, eqs: list[LoweredEq]):
lb = arith.Constant.from_int_and_width(int(lower), IndexType())
else:
raise NotImplementedError(f"Lower bound of type {type(lower)} not supported")

try:
name = interval.dim.symbolic_min.name
except:
Expand Down Expand Up @@ -633,7 +644,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp:
# Instantiate the module.
self.function_values: dict[tuple[Function, int], SSAValue] = {}
self.symbol_values: dict[str, SSAValue] = {}

module = ModuleOp(Region([block := Block([])]))
with ImplicitBuilder(block):
# Get all functions used in the equations
Expand All @@ -647,7 +658,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp:
functions.add(f.function)

elif isinstance(eq, Injection):

functions.add(eq.field.function)
for f in retrieve_functions(eq.expr):
if isinstance(f, PointSource):
Expand Down
49 changes: 23 additions & 26 deletions tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,20 @@ def test_function_IV():
assert np.isclose(norm(u), devito_norm_u)


def test_function_V():
grid = Grid(shape=(5, 5))
x, y = grid.dimensions

f = Function(name="f", grid=grid)

eqns = [Eq(f, 2)]

op = Operator(eqns, opt="xdsl")
op.apply()

assert np.all(f.data == 2)


class TestTrigonometric(object):

@pytest.mark.parametrize('deg, exp', ([90.0, 3.5759869], [30.0, 3.9521265],
Expand Down Expand Up @@ -1028,37 +1042,20 @@ def test_tan(self, deg, exp):
assert np.isclose(norm(u), exp, rtol=1e-4)


class TestOperatorUnsupported(object):
def test_forward_assignment():
# simple forward assignment

@pytest.mark.xfail(reason="stencil.return operation does not verify for i64")
def test_forward_assignment(self):
# simple forward assignment

grid = Grid(shape=(4, 4))
u = TimeFunction(name="u", grid=grid, space_order=2)
u.data[:, :, :] = 0

eq0 = Eq(u.forward, 1)

op = Operator([eq0], opt='xdsl')

op.apply(time_M=1)

assert np.isclose(norm(u), 5.6584, rtol=0.001)

@pytest.mark.xfail(reason="stencil.return operation does not verify for i64")
def test_function(self):
grid = Grid(shape=(5, 5))
x, y = grid.dimensions
grid = Grid(shape=(4, 4))
u = TimeFunction(name="u", grid=grid, space_order=2)
u.data[:, :, :] = 0

f = Function(name="f", grid=grid)
eq0 = Eq(u.forward, 1)

eqns = [Eq(f, 2)]
op = Operator([eq0], opt='xdsl')

op = Operator(eqns, opt='xdsl')
op.apply()
op.apply(time_M=1)

assert np.all(f.data == 4)
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
assert np.isclose(norm(u), 5.6584, rtol=0.001)


class TestElastic():
Expand Down