-
Notifications
You must be signed in to change notification settings - Fork 0
Coupled #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Coupled #41
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,7 +24,7 @@ | |
| c_restrict_void_p, sorted_priority) | ||
| from devito.types.basic import AbstractFunction, Basic | ||
| from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer, | ||
| IndexedData, DeviceMap) | ||
| IndexedData, DeviceMap, LocalCompositeObject) | ||
|
|
||
|
|
||
| __all__ = ['FindApplications', 'FindNodes', 'FindSections', 'FindSymbols', | ||
|
|
@@ -190,7 +190,7 @@ def __init__(self, *args, compiler=None, **kwargs): | |
|
|
||
| def _gen_struct_decl(self, obj, masked=()): | ||
| """ | ||
| Convert ctypes.Struct -> cgen.Structure. | ||
| Convert ctypes.Struct and LocalCompositeObject -> cgen.Structure. | ||
| """ | ||
| ctype = obj._C_ctype | ||
| try: | ||
|
|
@@ -201,7 +201,16 @@ def _gen_struct_decl(self, obj, masked=()): | |
| return None | ||
| except TypeError: | ||
| # E.g., `ctype` is of type `dtypes_lowering.CustomDtype` | ||
| return None | ||
| if isinstance(obj, LocalCompositeObject): | ||
| # TODO: Potentially re-evaluate: Setting ctype to obj allows | ||
| # _gen_struct_decl to generate a cgen.Structure from a | ||
| # LocalCompositeObject, where obj._C_ctype is a CustomDtype. | ||
| # LocalCompositeObject has a __fields__ property, | ||
| # which allows the subsequent code in this function to function | ||
| # correctly. | ||
| ctype = obj | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how can the ctype being a LocalCompositeObject? I struggle to follow here |
||
| else: | ||
| return None | ||
|
|
||
| try: | ||
| return obj._C_typedecl | ||
|
|
@@ -678,8 +687,11 @@ def _operator_typedecls(self, o, mode='all'): | |
| for i in o._func_table.values(): | ||
| if not i.local: | ||
| continue | ||
| typedecls.extend([self._gen_struct_decl(j) for j in i.root.parameters | ||
| if xfilter(j)]) | ||
| typedecls.extend([ | ||
| self._gen_struct_decl(j) | ||
| for j in FindSymbols().visit(i.root) | ||
| if xfilter(j) | ||
| ]) | ||
| typedecls = filter_sorted(typedecls, key=lambda i: i.tpname) | ||
|
|
||
| return typedecls | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,18 +1,27 @@ | ||
| import cgen as c | ||
| import numpy as np | ||
| from functools import cached_property | ||
|
|
||
| from devito.passes.iet.engine import iet_pass | ||
| from devito.ir.iet import (Transformer, MapNodes, Iteration, BlankLine, | ||
| FindNodes, Call, CallableBody) | ||
| from devito.symbolics import Byref, Macro | ||
| DummyExpr, CallableBody, List, Call, Callable, | ||
| FindNodes) | ||
| from devito.symbolics import Byref, Macro, FieldFromPointer | ||
| from devito.types import Symbol, Scalar | ||
| from devito.types.basic import DataSymbol | ||
| from devito.petsc.types import (PetscMPIInt, PetscErrorCode, Initialize, | ||
| Finalize, ArgvSymbol) | ||
| from devito.tools import frozendict | ||
| from devito.petsc.types import (PetscMPIInt, PetscErrorCode, MultipleFieldData, | ||
| PointerIS, Mat, LocalVec, GlobalVec, CallbackMat, SNES, | ||
| DummyArg, PetscInt, PointerDM, PointerMat, MatReuse, | ||
| CallbackPointerIS, CallbackPointerDM, JacobianStruct, | ||
| SubMatrixStruct, Initialize, Finalize, ArgvSymbol) | ||
| from devito.petsc.types.macros import petsc_func_begin_user | ||
| from devito.petsc.iet.nodes import PetscMetaData | ||
| from devito.petsc.utils import core_metadata | ||
| from devito.petsc.iet.routines import (CallbackBuilder, BaseObjectBuilder, BaseSetup, | ||
| Solver, TimeDependent, NonTimeDependent) | ||
| from devito.petsc.iet.routines import (CBBuilder, CCBBuilder, BaseObjectBuilder, | ||
| CoupledObjectBuilder, BaseSetup, CoupledSetup, | ||
| Solver, CoupledSolver, TimeDependent, | ||
| NonTimeDependent) | ||
| from devito.petsc.iet.utils import petsc_call, petsc_call_mpi | ||
|
|
||
|
|
||
|
|
@@ -26,7 +35,6 @@ def lower_petsc(iet, **kwargs): | |
| return iet, {} | ||
|
|
||
| metadata = core_metadata() | ||
|
|
||
| data = FindNodes(PetscMetaData).visit(iet) | ||
|
|
||
| if any(filter(lambda i: isinstance(i.expr.rhs, Initialize), data)): | ||
|
|
@@ -35,10 +43,10 @@ def lower_petsc(iet, **kwargs): | |
| if any(filter(lambda i: isinstance(i.expr.rhs, Finalize), data)): | ||
| return finalize(iet), metadata | ||
|
|
||
| targets = [i.expr.rhs.target for (i,) in injectsolve_mapper.values()] | ||
|
|
||
| # Assumption is that all targets have the same grid so can use any target here | ||
| objs = build_core_objects(targets[-1], **kwargs) | ||
| unique_grids = {i.expr.rhs.grid for (i,) in injectsolve_mapper.values()} | ||
| # Assumption is that all solves are on the same grid | ||
| if len(unique_grids) > 1: | ||
| raise ValueError("All PETScSolves must use the same Grid, but multiple found.") | ||
|
|
||
| # Create core PETSc calls (not specific to each PETScSolve) | ||
| core = make_core_petsc_calls(objs, **kwargs) | ||
|
|
@@ -54,17 +62,18 @@ def lower_petsc(iet, **kwargs): | |
| setup.extend(builder.solversetup.calls) | ||
|
|
||
| # Transform the spatial iteration loop with the calls to execute the solver | ||
| subs.update(builder.solve.mapper) | ||
| subs.update({builder.solve.spatial_body: builder.solve.calls}) | ||
|
|
||
| efuncs.update(builder.cbbuilder.efuncs) | ||
|
|
||
| populate_matrix_context(efuncs, objs) | ||
|
|
||
| iet = Transformer(subs).visit(iet) | ||
|
|
||
| body = core + tuple(setup) + (BlankLine,) + iet.body.body | ||
| body = iet.body._rebuild(body=body) | ||
| iet = iet._rebuild(body=body) | ||
| metadata.update({'efuncs': tuple(efuncs.values())}) | ||
|
|
||
| return iet, metadata | ||
|
|
||
|
|
||
|
|
@@ -100,56 +109,140 @@ def make_core_petsc_calls(objs, **kwargs): | |
| return call_mpi, BlankLine | ||
|
|
||
|
|
||
| def build_core_objects(target, **kwargs): | ||
| communicator = 'PETSC_COMM_WORLD' | ||
|
|
||
| return { | ||
| 'size': PetscMPIInt(name='size'), | ||
| 'comm': communicator, | ||
| 'err': PetscErrorCode(name='err'), | ||
| 'grid': target.grid | ||
| } | ||
|
|
||
|
|
||
| class Builder: | ||
| """ | ||
| This class is designed to support future extensions, enabling | ||
| different combinations of solver types, preconditioning methods, | ||
| and other functionalities as needed. | ||
|
|
||
| The class will be extended to accommodate different solver types by | ||
| returning subclasses of the objects initialised in __init__, | ||
| depending on the properties of `injectsolve`. | ||
| """ | ||
| def __init__(self, injectsolve, objs, iters, **kwargs): | ||
| self.injectsolve = injectsolve | ||
| self.objs = objs | ||
| self.iters = iters | ||
| self.kwargs = kwargs | ||
| self.coupled = isinstance(injectsolve.expr.rhs.fielddata, MultipleFieldData) | ||
| self.args = { | ||
| 'injectsolve': self.injectsolve, | ||
| 'objs': self.objs, | ||
| 'iters': self.iters, | ||
| **self.kwargs | ||
| } | ||
| self.args['solver_objs'] = self.objbuilder.solver_objs | ||
| self.args['timedep'] = self.timedep | ||
| self.args['cbbuilder'] = self.cbbuilder | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are any of these functions sensitive to the contents of
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? |
||
|
|
||
| @cached_property | ||
| def objbuilder(self): | ||
| return ( | ||
| CoupledObjectBuilder(**self.args) | ||
| if self.coupled else | ||
| BaseObjectBuilder(**self.args) | ||
| ) | ||
|
|
||
| # Determine the time dependency class | ||
| time_mapper = injectsolve.expr.rhs.time_mapper | ||
| timedep = TimeDependent if time_mapper else NonTimeDependent | ||
| self.timedep = timedep(injectsolve, iters, **kwargs) | ||
| @cached_property | ||
| def timedep(self): | ||
| time_mapper = self.injectsolve.expr.rhs.time_mapper | ||
| timedep_class = TimeDependent if time_mapper else NonTimeDependent | ||
| return timedep_class(**self.args) | ||
|
|
||
| # Objects | ||
| self.objbuilder = BaseObjectBuilder(injectsolve, **kwargs) | ||
| self.solver_objs = self.objbuilder.solver_objs | ||
| @cached_property | ||
| def cbbuilder(self): | ||
| return CCBBuilder(**self.args) if self.coupled else CBBuilder(**self.args) | ||
|
|
||
| # Callbacks | ||
| self.cbbuilder = CallbackBuilder( | ||
| injectsolve, objs, self.solver_objs, timedep=self.timedep, | ||
| **kwargs | ||
| ) | ||
| @cached_property | ||
| def solversetup(self): | ||
| return CoupledSetup(**self.args) if self.coupled else BaseSetup(**self.args) | ||
|
|
||
| # Solver setup | ||
| self.solversetup = BaseSetup( | ||
| self.solver_objs, objs, injectsolve, self.cbbuilder | ||
| ) | ||
| @cached_property | ||
| def solve(self): | ||
| return CoupledSolver(**self.args) if self.coupled else Solver(**self.args) | ||
|
|
||
| # Execute the solver | ||
| self.solve = Solver( | ||
| self.solver_objs, objs, injectsolve, iters, | ||
| self.cbbuilder, timedep=self.timedep | ||
| ) | ||
|
|
||
| def populate_matrix_context(efuncs, objs): | ||
| if not objs['dummyefunc'] in efuncs.values(): | ||
| return | ||
|
|
||
| subdms_expr = DummyExpr( | ||
| FieldFromPointer(objs['Subdms']._C_symbol, objs['ljacctx']), | ||
| objs['Subdms']._C_symbol | ||
| ) | ||
| fields_expr = DummyExpr( | ||
| FieldFromPointer(objs['Fields']._C_symbol, objs['ljacctx']), | ||
| objs['Fields']._C_symbol | ||
| ) | ||
| body = CallableBody( | ||
| List(body=[subdms_expr, fields_expr]), | ||
| init=(objs['begin_user'],), | ||
| retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])]) | ||
| ) | ||
| name = 'PopulateMatContext' | ||
| efuncs[name] = Callable( | ||
| name, body, objs['err'], | ||
| parameters=[objs['ljacctx'], objs['Subdms'], objs['Fields']] | ||
| ) | ||
|
|
||
|
|
||
| # Move these to types folder | ||
| # TODO: Devito MPI + PETSc testing | ||
| # if kwargs['options']['mpi'] -> communicator = grid.distributor._obj_comm | ||
| communicator = 'PETSC_COMM_WORLD' | ||
| subdms = PointerDM(name='subdms') | ||
| fields = PointerIS(name='fields') | ||
| submats = PointerMat(name='submats') | ||
| rows = PointerIS(name='rows') | ||
| cols = PointerIS(name='cols') | ||
|
|
||
|
|
||
| # A static dict containing shared symbols and objects that are not | ||
| # unique to each PETScSolve. | ||
| # Many of these objects are used as arguments in callback functions to make | ||
| # the C code cleaner and more modular. This is also a step toward leveraging | ||
| # Devito's `reuse_efuncs` functionality, allowing reuse of efuncs when | ||
| # they are semantically identical. | ||
| objs = frozendict({ | ||
| 'size': PetscMPIInt(name='size'), | ||
| 'comm': communicator, | ||
| 'err': PetscErrorCode(name='err'), | ||
| 'block': CallbackMat('block'), | ||
| 'submat_arr': PointerMat(name='submat_arr'), | ||
| 'subblockrows': PetscInt('subblockrows'), | ||
| 'subblockcols': PetscInt('subblockcols'), | ||
| 'rowidx': PetscInt('rowidx'), | ||
| 'colidx': PetscInt('colidx'), | ||
| 'J': Mat('J'), | ||
| 'X': GlobalVec('X'), | ||
| 'xloc': LocalVec('xloc'), | ||
| 'Y': GlobalVec('Y'), | ||
| 'yloc': LocalVec('yloc'), | ||
| 'F': GlobalVec('F'), | ||
| 'floc': LocalVec('floc'), | ||
| 'B': GlobalVec('B'), | ||
| 'nfields': PetscInt('nfields'), | ||
| 'irow': PointerIS(name='irow'), | ||
| 'icol': PointerIS(name='icol'), | ||
| 'nsubmats': Scalar('nsubmats', dtype=np.int32), | ||
| 'matreuse': MatReuse('scall'), | ||
| 'snes': SNES('snes'), | ||
| 'rows': rows, | ||
| 'cols': cols, | ||
| 'Subdms': subdms, | ||
| 'LocalSubdms': CallbackPointerDM(name='subdms'), | ||
| 'Fields': fields, | ||
| 'LocalFields': CallbackPointerIS(name='fields'), | ||
| 'Submats': submats, | ||
| 'ljacctx': JacobianStruct( | ||
| fields=[subdms, fields, submats], modifier=' *' | ||
| ), | ||
| 'subctx': SubMatrixStruct(fields=[rows, cols]), | ||
| 'Null': Macro('NULL'), | ||
| 'dummyctx': Symbol('lctx'), | ||
| 'dummyptr': DummyArg('dummy'), | ||
| 'dummyefunc': Symbol('dummyefunc'), | ||
| 'dof': PetscInt('dof'), | ||
| 'begin_user': c.Line('PetscFunctionBeginUser;'), | ||
| }) | ||
|
|
||
| # Move to macros file? | ||
| Null = Macro('NULL') | ||
| void = 'void' | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this comment to be moved inside the else?