Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: Introduce Source injection #95

Merged
merged 18 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir-mpi-openmp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33

- name: Test with MPI + openmp
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir-mpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33

- name: Test with MPI - no Openmp
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci-mlir-openmp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33

- name: Test no-MPI, Openmp
run: |
export DEVITO_MPI=0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33

- name: Test no-MPI, no-Openmp
run: |
Expand Down
11 changes: 5 additions & 6 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from devito.core.operator import CoreOperator, CustomOperator, ParTile
from devito.exceptions import InvalidOperator
from devito.passes.equations import collect_derivatives
from devito.tools import timed_pass

from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
factorize, fission, fuse, optimize_hyperplanes,
optimize_pows)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, hoist_prodders,
linearize, mpiize, relax_incr_dimensions)
factorize, fission, fuse, optimize_pows,
optimize_hyperplanes)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize,
hoist_prodders, relax_incr_dimensions)
from devito.tools import timed_pass


__all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator',
Expand Down
68 changes: 37 additions & 31 deletions devito/core/cpu_xdsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from devito.logger import info, perf
from devito.mpi import MPI
from devito.operator.profiling import create_profile
from devito.tools import filter_sorted, flatten, OrderedSet
from devito.tools import filter_sorted, flatten, as_tuple
from devito.types import TimeFunction
from devito.types.dense import DiscreteFunction, Function
from devito.types.mlir_types import f32, ptr_of
Expand All @@ -33,6 +33,8 @@
from devito.passes.iet import CTarget, OmpTarget
from devito.core.cpu import Cpu64OperatorMixin

from examples.seismic.source import PointSource

__all__ = ['XdslnoopOperator', 'XdslAdvOperator']


Expand All @@ -57,12 +59,12 @@ def _build(cls, expressions, **kwargs):
Callable.__init__(op, **op.args)

# Header files, etc.
op._headers = OrderedSet(*cls._default_headers)
op._headers.update(byproduct.headers)
op._globals = OrderedSet(*cls._default_globals)
op._includes = OrderedSet(*cls._default_includes)
op._includes.update(profiler._default_includes)
op._includes.update(byproduct.includes)
# op._headers = OrderedSet(*cls._default_headers)
# op._headers.update(byproduct.headers)
# op._globals = OrderedSet(*cls._default_globals)
# op._includes = OrderedSet(*cls._default_includes)
# op._includes.update(profiler._default_includes)
# op._includes.update(byproduct.includes)

# Required for the jit-compilation
op._compiler = kwargs['compiler']
Expand Down Expand Up @@ -94,7 +96,7 @@ def _build(cls, expressions, **kwargs):
op._dtype, op._dspace = irs.clusters.meta
op._profiler = profiler
kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet))
module = cls._lower_stencil(irs.expressions, **kwargs)
module = cls._lower_stencil(expressions, **kwargs)
op._module = module

return op
Expand All @@ -107,8 +109,8 @@ def _lower_stencil(cls, expressions, **kwargs):
Apply timers to the module
"""

conv = ExtractDevitoStencilConversion()
module = conv.convert(expressions, **kwargs)
conv = ExtractDevitoStencilConversion(cls)
module = conv.convert(as_tuple(expressions), **kwargs)
# print(module)
apply_timers(module, timed=True, **kwargs)

Expand Down Expand Up @@ -309,9 +311,9 @@ def cfunction(self):
if self._cfunction is None:
self._cfunction = getattr(self._lib, self.name)
# Associate a C type to each argument for runtime type check
argtypes = self._construct_cfunction_args(self._jit_kernel_constants,
get_types=True)
self._cfunction.argtypes = argtypes
# argtypes = self._construct_cfunction_args(self._jit_kernel_constants,
# get_types=True)
# self._cfunction.argtypes = argtypes

return self._cfunction

Expand Down Expand Up @@ -356,38 +358,42 @@ def setup_memref_args(self):
data = arg._data
for t in range(data.shape[0]):
args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32))
if isinstance(arg, Function):
args[f'{arg._C_name}'] = arg._data[...].ctypes.data_as(ptr_of(f32))
elif isinstance(arg, Function):
args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32))

elif isinstance(arg, PointSource):
args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32))
else:
raise NotImplementedError(f"type {type(arg)} not implemented")

self._jit_kernel_constants.update(args)

def _construct_cfunction_args(self, args, get_types=False):
"""
Either construct the args for the cfunction, or construct the
arg types for it.
"""
ps = {
p._C_name: p._C_ctype for p in self.parameters
}
def _construct_cfunction_types(self, args):
ps = {p._C_name: p._C_ctype for p in self.parameters}

objects = []
objects_types = []

for name in get_arg_names_from_module(self._module):
object = args[name]
objects.append(object)
if name in ps:
object_type = ps[name]
if object_type == DiscreteFunction._C_ctype:
object_type = dict(object_type._type_._fields_)['data']
objects_types.append(object_type)
else:
objects_types.append(type(object))
return objects_types

def _construct_cfunction_args(self, args):
"""
Either construct the args for the cfunction, or construct the
arg types for it.
"""

objects = []
for name in get_arg_names_from_module(self._module):
object = args[name]
objects.append(object)

if get_types:
return objects_types
else:
return objects
return objects


class XdslAdvOperator(XdslnoopOperator):
Expand Down
Loading
Loading