diff --git a/devito/ir/xdsl_iet/cluster_to_ssa.py b/devito/ir/xdsl_iet/cluster_to_ssa.py index a13c1d521e..b0c8bd9e2f 100644 --- a/devito/ir/xdsl_iet/cluster_to_ssa.py +++ b/devito/ir/xdsl_iet/cluster_to_ssa.py @@ -17,7 +17,7 @@ from devito.tools.utils import as_tuple from devito.types.basic import Scalar from devito.types.dense import DiscreteFunction, Function, TimeFunction -from devito.types.dimension import SpaceDimension, TimeDimension +from devito.types.dimension import SpaceDimension, TimeDimension, ConditionalDimension from devito.types.equation import Eq # ------------- xdsl imports -------------# @@ -100,10 +100,11 @@ def __init__(self, operator: type[Operator]): self.temps = dict() self.operator = operator - def convert_function_eq(self, eq: LoweredEq, **kwargs): - # Read the grid containing necessary discretization information - # (size, halo width, ...) + def lower_Function(self, eq: LoweredEq, **kwargs): + + # Get the LHS of the equation, where we write write_function = eq.lhs.function + # Get Grid and stepping dimension grid: Grid = write_function.grid step_dim = grid.stepping_dim @@ -118,13 +119,12 @@ def convert_function_eq(self, eq: LoweredEq, **kwargs): dims = retrieve_dimensions(eq.lhs.indices) - if not all(isinstance(d, (SteppingDimension, SpaceDimension)) for d in dims): - self.build_generic_step(step_dim, eq) + if any(isinstance(d, (ConditionalDimension)) for d in dims): + self.build_condition(step_dim, eq) else: - # Get the function carriers of the equation self.build_stencil_step(step_dim, eq) - def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs): + def lower_Symbol(self, symbol: Symbol, rhs: LoweredEq, **kwargs): """ Convert a symbol equation to xDSL. """ @@ -135,8 +135,9 @@ def _convert_eq(self, eq: LoweredEq, **kwargs): """ # Docs here Need rewriting - This converts a Devito LoweredEq to IR implementing it. - e.g. + Convert a Devito LoweredEq to IR implementing it. + e.g.: + ```python Eq(u[t + 1, x + 1], u[t, x + 1] + 1) ``` @@ -145,8 +146,8 @@ def _convert_eq(self, eq: LoweredEq, **kwargs): Grid[extent=(1.0,), shape=(3,), dimensions=(x,)] ``` - 1. Create a stencil.apply op to implement the equation, with a classical AST - translation. + 1. Create a stencil.apply op to implement the equation, + with a classical AST translation. The example above would be translated to: ```mlir @@ -163,7 +164,7 @@ def _convert_eq(self, eq: LoweredEq, **kwargs): "stencil.store"(%4, %u_t1) {"lb" = #stencil.index<0>, "ub" = #stencil.index<3>} : (!stencil.temp, !stencil.field<[-1,4]xf32>) -> () ``` """ - # Get the left hand side, called "output function" here because it tells us + # Get the LHS, "write function" telling us where to here because it tells us # Where to write the results of each step. write_function = eq.lhs @@ -171,13 +172,14 @@ def _convert_eq(self, eq: LoweredEq, **kwargs): case Indexed(): match write_function.function: case TimeFunction() | Function(): - self.convert_function_eq(eq, **kwargs) + self.lower_Function(eq, **kwargs) case _: + type_error = type(write_function.function) raise NotImplementedError( - f"Function of type {type(write_function.function)} not supported" # noqa + f"Function of type {type_error} not supported" ) case Symbol(): - self.convert_symbol_eq(write_function, eq.rhs, **kwargs) + self.lower_Symbol(write_function, eq.rhs, **kwargs) case _: raise NotImplementedError(f"LHS of type {type(write_function)} not supported") # noqa @@ -322,15 +324,16 @@ def build_stencil_step(self, dim: SteppingDimension, eq: LoweredEq) -> None: None """ read_functions = OrderedSet() + # Collect Functions and their time offsets for f in retrieve_function_carriers(eq.rhs): - if isinstance(f.function, PointSource): - time_offset = 0 - elif isinstance(f.function, TimeFunction): + if isinstance(f.function, TimeFunction): + # Works but should think of how to improve the derivation time_offset = (f.indices[dim]-dim) % f.function.time_size elif isinstance(f.function, Function): time_offset = 0 else: raise NotImplementedError(f"reading function of type {type(f.function)} not supported") + read_functions.add((f.function, time_offset)) for f, t in read_functions: @@ -387,8 +390,10 @@ def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq): memtemp = UnrealizedConversionCastOp.get([temp], [StencilToMemRefType(temp.type)]).results[0] memtemp.name_hint = temp.name_hint + "_mem" indices = eq.lhs.indices + if isinstance(eq.lhs.function, TimeFunction): indices = indices[1:] + ssa_indices = [self._visit_math_nodes(dim, i, None) for i in indices] for i, ssa_i in enumerate(ssa_indices): if isinstance(ssa_i.type, builtin.IntegerType): @@ -405,36 +410,40 @@ def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq): properties={"kind": attr}) def build_condition(self, dim: SteppingDimension, eq: BooleanFunction): - return self._visit_math_nodes(dim, eq, None) - - def build_generic_step(self, dim: SteppingDimension, eq: LoweredEq): - if eq.conditionals: - condition = And(*eq.conditionals.values(), evaluate=False) - cond = self.build_condition(dim, condition) - if_ = scf.If(cond, (), Region(Block())) - with ImplicitBuilder(if_.true_region.block): - self.build_generic_step_expression(dim, eq) - scf.Yield() - else: - # Build the expression + + assert eq.conditionals + + # Parse condition and build the condition block + condition = And(*eq.conditionals.values(), evaluate=False) + cond = self._visit_math_nodes(dim, condition, None) + + if_ = scf.If(cond, (), Region(Block())) + with ImplicitBuilder(if_.true_region.block): + # Use the builder for the inner expression + assert eq.is_Increment self.build_generic_step_expression(dim, eq) + scf.Yield() + def build_time_loop( self, eqs: list[Any], step_dim: SteppingDimension, **kwargs ): - # Bounds and step boilerpalte + # Bounds and step boilerplate lb = iet_ssa.LoadSymbolic.get( step_dim.symbolic_min._C_name, IndexType() ) ub = iet_ssa.LoadSymbolic.get( step_dim.symbolic_max._C_name, IndexType() ) + one = arith.Constant.from_int_and_width(1, IndexType()) + # Devito iterates from time_m to time_M *inclusive*, MLIR only takes # exclusive upper bounds, so we increment here. ub = arith.Addi(ub, one) # Take the exact time_step from Devito + try: step = arith.Constant.from_int_and_width( int(step_dim.symbolic_incr), IndexType() @@ -461,9 +470,10 @@ def build_time_loop( ) # Name the 'time' step iterator - loop.body.block.args[0].name_hint = "time" + assert step_dim.root.name is 'time' + loop.body.block.args[0].name_hint = step_dim.root.name # Store for later reference - self.symbol_values["time"] = loop.body.block.args[0] + self.symbol_values[step_dim.root.name] = loop.body.block.args[0] # Store a mapping from time_buffers to their corresponding block # arguments for easier access later. @@ -472,12 +482,13 @@ def build_time_loop( for i, (f, t) in enumerate(self.time_buffers) } self.function_values |= self.block_args - # Name the block argument for debugging. + + # Name the block argument for debugging for (f, t), arg in self.block_args.items(): arg.name_hint = f"{f.name}_t{t}" with ImplicitBuilder(loop.body.block): - self.generate_equations(eqs, **kwargs) + self.lower_devito_Eqs(eqs, **kwargs) # Swap buffers through scf.yield yield_args = [ self.block_args[(f, (t + 1) % f.time_size)] @@ -485,16 +496,16 @@ def build_time_loop( ] scf.Yield(*yield_args) - def generate_equations(self, eqs: list[Any], **kwargs): - # Lower equations to their xDSL equivalent + def lower_devito_Eqs(self, eqs: list[Any], **kwargs): + # Lower devito Equations to xDSL + for eq in eqs: + lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs) if isinstance(eq, Eq): # Nested lowering? TO re-think approach - lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs) for lo in lowered: self._convert_eq(lo) elif isinstance(eq, Injection): - lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs) self._lower_injection(lowered) else: raise NotImplementedError(f"Expression {eq} of type {type(eq)} not supported") @@ -683,7 +694,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp: if step_dim is not None: self.build_time_loop(eqs, step_dim, **kwargs) else: - self.generate_equations(eqs, **kwargs) + self.lower_devito_Eqs(eqs, **kwargs) # func wants a return func.Return() diff --git a/devito/ir/xdsl_iet/profiling.py b/devito/ir/xdsl_iet/profiling.py index 1b434d78e0..2dae8ec3df 100644 --- a/devito/ir/xdsl_iet/profiling.py +++ b/devito/ir/xdsl_iet/profiling.py @@ -29,10 +29,9 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): if op.sym_name.data != self.func_name or op in self.seen_ops: return - # only apply once + # Add the op to the set of seen operations self.seen_ops.add(op) - # Insert timer start and end calls # Insert timer start and end calls rewriter.insert_op([ t0 := func.Call('timer_start', [], [builtin.f64]) @@ -41,11 +40,11 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): ret = op.get_return_op() assert ret is not None - rewriter.insert_op_before([ + rewriter.insert_op([ timers := iet_ssa.LoadSymbolic.get('timers', llvm.LLVMPointerType.opaque()), t1 := func.Call('timer_end', [t0], [builtin.f64]), llvm.StoreOp(t1, timers), - ], ret) + ], InsertPoint.before(ret)) rewriter.insert_op([ func.FuncOp.external('timer_start', [], [builtin.f64]), @@ -57,8 +56,7 @@ def apply_timers(module, **kwargs): """ Apply timers to a module """ - if kwargs['xdsl_num_sections'] < 1: - return + name = kwargs.get("name", "Kernel") grpa = GreedyRewritePatternApplier([MakeFunctionTimed(name)]) PatternRewriteWalker(grpa, walk_regions_first=True).rewrite_module(module) diff --git a/devito/xdsl_core/xdsl_cpu.py b/devito/xdsl_core/xdsl_cpu.py index 11aa1b037a..6240631e2b 100644 --- a/devito/xdsl_core/xdsl_cpu.py +++ b/devito/xdsl_core/xdsl_cpu.py @@ -54,14 +54,6 @@ def _build(cls, expressions, **kwargs): op = Callable.__new__(cls, **irs.iet.args) Callable.__init__(op, **op.args) - # Header files, etc. - # op._headers = OrderedSet(*cls._default_headers) - # op._headers.update(byproduct.headers) - # op._globals = OrderedSet(*cls._default_globals) - # op._includes = OrderedSet(*cls._default_includes) - # op._includes.update(profiler._default_includes) - # op._includes.update(byproduct.includes) - # Required for the jit-compilation op._compiler = kwargs['compiler'] op._language = kwargs['language'] @@ -80,10 +72,6 @@ def _build(cls, expressions, **kwargs): for i in profiler._ext_calls])) op._func_table.update(OrderedDict([(i.root.name, i) for i in byproduct.funcs])) - # Internal mutable state to store information about previous runs, autotuning - # reports, etc - op._state = cls._initialize_state(**kwargs) - # Produced by the various compilation passes op._reads = filter_sorted(flatten(e.reads for e in irs.expressions)) @@ -91,8 +79,15 @@ def _build(cls, expressions, **kwargs): op._dimensions = set().union(*[e.dimensions for e in irs.expressions]) op._dtype, op._dspace = irs.clusters.meta op._profiler = profiler - kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet)) + + # This has to be moved outside and drop this _build from here + module = cls._lower_stencil(expressions, **kwargs) + + num_sections = len(FindNodes(Section).visit(irs.iet)) + if num_sections: + apply_timers(module, **kwargs) + op._module = module return op @@ -104,11 +99,8 @@ def _lower_stencil(cls, expressions, **kwargs): [Eq] -> [xdsl] Apply timers to the module """ - conv = ExtractDevitoStencilConversion(cls) module = conv.convert(as_tuple(expressions), **kwargs) - # print(module) - apply_timers(module, timed=True, **kwargs) return module