Skip to content

Commit

Permalink
Merge pull request #106 from xdslproject/compiler_edits
Browse files Browse the repository at this point in the history
compiler: Edits pass
  • Loading branch information
georgebisbas authored Jun 26, 2024
2 parents dcf8f0b + 9dd9079 commit 65127f4
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 63 deletions.
93 changes: 52 additions & 41 deletions devito/ir/xdsl_iet/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from devito.tools.utils import as_tuple
from devito.types.basic import Scalar
from devito.types.dense import DiscreteFunction, Function, TimeFunction
from devito.types.dimension import SpaceDimension, TimeDimension
from devito.types.dimension import SpaceDimension, TimeDimension, ConditionalDimension
from devito.types.equation import Eq

# ------------- xdsl imports -------------#
Expand Down Expand Up @@ -100,10 +100,11 @@ def __init__(self, operator: type[Operator]):
self.temps = dict()
self.operator = operator

def convert_function_eq(self, eq: LoweredEq, **kwargs):
# Read the grid containing necessary discretization information
# (size, halo width, ...)
def lower_Function(self, eq: LoweredEq, **kwargs):

# Get the LHS of the equation, where we write
write_function = eq.lhs.function
# Get Grid and stepping dimension
grid: Grid = write_function.grid
step_dim = grid.stepping_dim

Expand All @@ -118,13 +119,12 @@ def convert_function_eq(self, eq: LoweredEq, **kwargs):

dims = retrieve_dimensions(eq.lhs.indices)

if not all(isinstance(d, (SteppingDimension, SpaceDimension)) for d in dims):
self.build_generic_step(step_dim, eq)
if any(isinstance(d, (ConditionalDimension)) for d in dims):
self.build_condition(step_dim, eq)
else:
# Get the function carriers of the equation
self.build_stencil_step(step_dim, eq)

def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs):
def lower_Symbol(self, symbol: Symbol, rhs: LoweredEq, **kwargs):
"""
Convert a symbol equation to xDSL.
"""
Expand All @@ -135,8 +135,9 @@ def _convert_eq(self, eq: LoweredEq, **kwargs):
"""
# Docs here Need rewriting
This converts a Devito LoweredEq to IR implementing it.
e.g.
Convert a Devito LoweredEq to IR implementing it.
e.g.:
```python
Eq(u[t + 1, x + 1], u[t, x + 1] + 1)
```
Expand All @@ -145,8 +146,8 @@ def _convert_eq(self, eq: LoweredEq, **kwargs):
Grid[extent=(1.0,), shape=(3,), dimensions=(x,)]
```
1. Create a stencil.apply op to implement the equation, with a classical AST
translation.
1. Create a stencil.apply op to implement the equation,
with a classical AST translation.
The example above would be translated to:
```mlir
Expand All @@ -163,21 +164,22 @@ def _convert_eq(self, eq: LoweredEq, **kwargs):
"stencil.store"(%4, %u_t1) {"lb" = #stencil.index<0>, "ub" = #stencil.index<3>} : (!stencil.temp<?xf32>, !stencil.field<[-1,4]xf32>) -> ()
```
"""
# Get the left hand side, called "output function" here because it tells us
# Get the LHS, called the "write function" here because it tells us
# where to write the results of each step.
write_function = eq.lhs

match write_function:
case Indexed():
match write_function.function:
case TimeFunction() | Function():
self.convert_function_eq(eq, **kwargs)
self.lower_Function(eq, **kwargs)
case _:
type_error = type(write_function.function)
raise NotImplementedError(
f"Function of type {type(write_function.function)} not supported" # noqa
f"Function of type {type_error} not supported"
)
case Symbol():
self.convert_symbol_eq(write_function, eq.rhs, **kwargs)
self.lower_Symbol(write_function, eq.rhs, **kwargs)
case _:
raise NotImplementedError(f"LHS of type {type(write_function)} not supported") # noqa

Expand Down Expand Up @@ -322,15 +324,16 @@ def build_stencil_step(self, dim: SteppingDimension, eq: LoweredEq) -> None:
None
"""
read_functions = OrderedSet()
# Collect Functions and their time offsets
for f in retrieve_function_carriers(eq.rhs):
if isinstance(f.function, PointSource):
time_offset = 0
elif isinstance(f.function, TimeFunction):
if isinstance(f.function, TimeFunction):
# Works but should think of how to improve the derivation
time_offset = (f.indices[dim]-dim) % f.function.time_size
elif isinstance(f.function, Function):
time_offset = 0
else:
raise NotImplementedError(f"reading function of type {type(f.function)} not supported")

read_functions.add((f.function, time_offset))

for f, t in read_functions:
Expand Down Expand Up @@ -387,8 +390,10 @@ def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq):
memtemp = UnrealizedConversionCastOp.get([temp], [StencilToMemRefType(temp.type)]).results[0]
memtemp.name_hint = temp.name_hint + "_mem"
indices = eq.lhs.indices

if isinstance(eq.lhs.function, TimeFunction):
indices = indices[1:]

ssa_indices = [self._visit_math_nodes(dim, i, None) for i in indices]
for i, ssa_i in enumerate(ssa_indices):
if isinstance(ssa_i.type, builtin.IntegerType):
Expand All @@ -405,36 +410,40 @@ def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq):
properties={"kind": attr})

def build_condition(self, dim: SteppingDimension, eq: BooleanFunction):
return self._visit_math_nodes(dim, eq, None)

def build_generic_step(self, dim: SteppingDimension, eq: LoweredEq):
if eq.conditionals:
condition = And(*eq.conditionals.values(), evaluate=False)
cond = self.build_condition(dim, condition)
if_ = scf.If(cond, (), Region(Block()))
with ImplicitBuilder(if_.true_region.block):
self.build_generic_step_expression(dim, eq)
scf.Yield()
else:
# Build the expression

assert eq.conditionals

# Parse condition and build the condition block
condition = And(*eq.conditionals.values(), evaluate=False)
cond = self._visit_math_nodes(dim, condition, None)

if_ = scf.If(cond, (), Region(Block()))
with ImplicitBuilder(if_.true_region.block):
# Use the builder for the inner expression
assert eq.is_Increment
self.build_generic_step_expression(dim, eq)
scf.Yield()


def build_time_loop(
self, eqs: list[Any], step_dim: SteppingDimension, **kwargs
):
# Bounds and step boilerpalte
# Bounds and step boilerplate
lb = iet_ssa.LoadSymbolic.get(
step_dim.symbolic_min._C_name, IndexType()
)
ub = iet_ssa.LoadSymbolic.get(
step_dim.symbolic_max._C_name, IndexType()
)

one = arith.Constant.from_int_and_width(1, IndexType())

# Devito iterates from time_m to time_M *inclusive*, MLIR only takes
# exclusive upper bounds, so we increment here.
ub = arith.Addi(ub, one)

# Take the exact time_step from Devito

try:
step = arith.Constant.from_int_and_width(
int(step_dim.symbolic_incr), IndexType()
Expand All @@ -461,9 +470,10 @@ def build_time_loop(
)

# Name the 'time' step iterator
loop.body.block.args[0].name_hint = "time"
assert step_dim.root.name is 'time'
loop.body.block.args[0].name_hint = step_dim.root.name
# Store for later reference
self.symbol_values["time"] = loop.body.block.args[0]
self.symbol_values[step_dim.root.name] = loop.body.block.args[0]

# Store a mapping from time_buffers to their corresponding block
# arguments for easier access later.
Expand All @@ -472,29 +482,30 @@ def build_time_loop(
for i, (f, t) in enumerate(self.time_buffers)
}
self.function_values |= self.block_args
# Name the block argument for debugging.

# Name the block argument for debugging
for (f, t), arg in self.block_args.items():
arg.name_hint = f"{f.name}_t{t}"

with ImplicitBuilder(loop.body.block):
self.generate_equations(eqs, **kwargs)
self.lower_devito_Eqs(eqs, **kwargs)
# Swap buffers through scf.yield
yield_args = [
self.block_args[(f, (t + 1) % f.time_size)]
for (f, t) in self.block_args.keys()
]
scf.Yield(*yield_args)

def generate_equations(self, eqs: list[Any], **kwargs):
# Lower equations to their xDSL equivalent
def lower_devito_Eqs(self, eqs: list[Any], **kwargs):
# Lower devito Equations to xDSL

for eq in eqs:
lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs)
if isinstance(eq, Eq):
# Nested lowering? TODO: re-think this approach
lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs)
for lo in lowered:
self._convert_eq(lo)
elif isinstance(eq, Injection):
lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs)
self._lower_injection(lowered)
else:
raise NotImplementedError(f"Expression {eq} of type {type(eq)} not supported")
Expand Down Expand Up @@ -683,7 +694,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp:
if step_dim is not None:
self.build_time_loop(eqs, step_dim, **kwargs)
else:
self.generate_equations(eqs, **kwargs)
self.lower_devito_Eqs(eqs, **kwargs)

# func wants a return
func.Return()
Expand Down
10 changes: 4 additions & 6 deletions devito/ir/xdsl_iet/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
if op.sym_name.data != self.func_name or op in self.seen_ops:
return

# only apply once
# Add the op to the set of seen operations
self.seen_ops.add(op)

# Insert timer start and end calls
# Insert timer start and end calls
rewriter.insert_op([
t0 := func.Call('timer_start', [], [builtin.f64])
Expand All @@ -41,11 +40,11 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
ret = op.get_return_op()
assert ret is not None

rewriter.insert_op_before([
rewriter.insert_op([
timers := iet_ssa.LoadSymbolic.get('timers', llvm.LLVMPointerType.opaque()),
t1 := func.Call('timer_end', [t0], [builtin.f64]),
llvm.StoreOp(t1, timers),
], ret)
], InsertPoint.before(ret))

rewriter.insert_op([
func.FuncOp.external('timer_start', [], [builtin.f64]),
Expand All @@ -57,8 +56,7 @@ def apply_timers(module, **kwargs):
"""
Apply timers to a module
"""
if kwargs['xdsl_num_sections'] < 1:
return

name = kwargs.get("name", "Kernel")
grpa = GreedyRewritePatternApplier([MakeFunctionTimed(name)])
PatternRewriteWalker(grpa, walk_regions_first=True).rewrite_module(module)
24 changes: 8 additions & 16 deletions devito/xdsl_core/xdsl_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,6 @@ def _build(cls, expressions, **kwargs):
op = Callable.__new__(cls, **irs.iet.args)
Callable.__init__(op, **op.args)

# Header files, etc.
# op._headers = OrderedSet(*cls._default_headers)
# op._headers.update(byproduct.headers)
# op._globals = OrderedSet(*cls._default_globals)
# op._includes = OrderedSet(*cls._default_includes)
# op._includes.update(profiler._default_includes)
# op._includes.update(byproduct.includes)

# Required for the jit-compilation
op._compiler = kwargs['compiler']
op._language = kwargs['language']
Expand All @@ -80,19 +72,22 @@ def _build(cls, expressions, **kwargs):
for i in profiler._ext_calls]))
op._func_table.update(OrderedDict([(i.root.name, i) for i in byproduct.funcs]))

# Internal mutable state to store information about previous runs, autotuning
# reports, etc
op._state = cls._initialize_state(**kwargs)

# Produced by the various compilation passes

op._reads = filter_sorted(flatten(e.reads for e in irs.expressions))
op._writes = filter_sorted(flatten(e.writes for e in irs.expressions))
op._dimensions = set().union(*[e.dimensions for e in irs.expressions])
op._dtype, op._dspace = irs.clusters.meta
op._profiler = profiler
kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet))

# This has to be moved outside and drop this _build from here

module = cls._lower_stencil(expressions, **kwargs)

num_sections = len(FindNodes(Section).visit(irs.iet))
if num_sections:
apply_timers(module, **kwargs)

op._module = module

return op
Expand All @@ -104,11 +99,8 @@ def _lower_stencil(cls, expressions, **kwargs):
[Eq] -> [xdsl]
Timers are no longer applied here; the caller (`_build`) applies them to the returned module
"""

conv = ExtractDevitoStencilConversion(cls)
module = conv.convert(as_tuple(expressions), **kwargs)
# print(module)
apply_timers(module, timed=True, **kwargs)

return module

Expand Down

0 comments on commit 65127f4

Please sign in to comment.