diff --git a/devito/petsc/iet/passes.py b/devito/petsc/iet/passes.py index 7d7c156dcb..561f9fd0f4 100644 --- a/devito/petsc/iet/passes.py +++ b/devito/petsc/iet/passes.py @@ -1,16 +1,13 @@ import cgen as c from devito.passes.iet.engine import iet_pass -from devito.ir.iet import (Transformer, MapNodes, Iteration, List, BlankLine, - DummyExpr, FindNodes, retrieve_iteration_tree, - filter_iterations) -from devito.symbolics import Byref, Macro, FieldFromComposite -from devito.petsc.types import (PetscMPIInt, DM, Mat, LocalVec, GlobalVec, - KSP, PC, SNES, PetscErrorCode, DummyArg, PetscInt, - StartPtr) -from devito.petsc.iet.nodes import InjectSolveDummy, PETScCall -from devito.petsc.utils import solver_mapper, core_metadata -from devito.petsc.iet.routines import PETScCallbackBuilder +from devito.ir.iet import Transformer, MapNodes, Iteration, BlankLine +from devito.symbolics import Byref, Macro +from devito.petsc.types import (PetscMPIInt, PetscErrorCode) +from devito.petsc.iet.nodes import InjectSolveDummy +from devito.petsc.utils import core_metadata +from devito.petsc.iet.routines import (CallbackBuilder, BaseObjectBuilder, BaseSetup, + Solver, TimeDependent, NonTimeDependent) from devito.petsc.iet.utils import petsc_call, petsc_call_mpi @@ -34,54 +31,36 @@ def lower_petsc(iet, **kwargs): setup = [] subs = {} - - # Create a different DMDA for each target with a unique space order - unique_dmdas = create_dmda_objs(targets) - objs.update(unique_dmdas) - for dmda in unique_dmdas.values(): - setup.extend(create_dmda_calls(dmda, objs)) - - builder = PETScCallbackBuilder(**kwargs) + efuncs = {} for iters, (injectsolve,) in injectsolve_mapper.items(): - solver_objs = build_solver_objs(injectsolve, iters, **kwargs) - # Generate the solver setup for each InjectSolveDummy - solver_setup = generate_solver_setup(solver_objs, objs, injectsolve) - setup.extend(solver_setup) + builder = Builder(injectsolve, objs, iters, **kwargs) - # Generate all PETSc callback functions for the target via recursive compilation - matvec_op, formfunc_op, runsolve = builder.make(injectsolve, - objs, solver_objs) - setup.extend([matvec_op, formfunc_op, BlankLine]) - # Only Transform the spatial iteration loop - space_iter, = spatial_injectsolve_iter(iters, injectsolve) - subs.update({space_iter: List(body=runsolve)}) + setup.extend(builder.solversetup.calls) - # Generate callback to populate main struct object - struct, struct_calls = builder.make_main_struct(unique_dmdas, objs) - setup.extend(struct_calls) + # Transform the spatial iteration loop with the calls to execute the solver + subs.update(builder.solve.mapper) - iet = Transformer(subs).visit(iet) + efuncs.update(builder.cbbuilder.efuncs) - iet = assign_time_iters(iet, struct) + iet = Transformer(subs).visit(iet) body = core + tuple(setup) + (BlankLine,) + iet.body.body body = iet.body._rebuild( init=init, body=body, - frees=(c.Line("PetscCall(PetscFinalize());"),) + frees=(petsc_call('PetscFinalize', []),) ) iet = iet._rebuild(body=body) metadata = core_metadata() - efuncs = tuple(builder.efuncs.values()) - metadata.update({'efuncs': efuncs}) + metadata.update({'efuncs': tuple(efuncs.values())}) return iet, metadata def init_petsc(**kwargs): # Initialize PETSc -> for now, assuming all solver options have to be - # specifed via the parameters dict in PETScSolve + # specified via the parameters dict in PETScSolve # TODO: Are users going to be able to use PETSc command line arguments? # In firedrake, they have an options_prefix for each solver, enabling the use # of command line options @@ -110,201 +89,48 @@ def build_core_objects(target, **kwargs): } -def create_dmda_objs(unique_targets): - unique_dmdas = {} - for target in unique_targets: - name = 'da_so_%s' % target.space_order - unique_dmdas[name] = DM(name=name, liveness='eager', - stencil_width=target.space_order) - return unique_dmdas - - -def create_dmda_calls(dmda, objs): - dmda_create = create_dmda(dmda, objs) - dm_setup = petsc_call('DMSetUp', [dmda]) - dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL']) - dm_get_local_info = petsc_call('DMDAGetLocalInfo', [dmda, Byref(dmda.info)]) - return dmda_create, dm_setup, dm_mat_type, dm_get_local_info, BlankLine - - -def create_dmda(dmda, objs): - no_of_space_dims = len(objs['grid'].dimensions) - - # MPI communicator - args = [objs['comm']] - - # Type of ghost nodes - args.extend(['DM_BOUNDARY_GHOSTED' for _ in range(no_of_space_dims)]) - - # Stencil type - if no_of_space_dims > 1: - args.append('DMDA_STENCIL_BOX') - - # Global dimensions - args.extend(list(objs['grid'].shape)[::-1]) - # No.of processors in each dimension - if no_of_space_dims > 1: - args.extend(list(objs['grid'].distributor.topology)[::-1]) - - # Number of degrees of freedom per node - args.append(1) - # "Stencil width" -> size of overlap - args.append(dmda.stencil_width) - args.extend([Null for _ in range(no_of_space_dims)]) - - # The distributed array object - args.append(Byref(dmda)) - - # The PETSc call used to create the DMDA - dmda = petsc_call('DMDACreate%sd' % no_of_space_dims, args) - - return dmda - - -def build_solver_objs(injectsolve, iters, **kwargs): - target = injectsolve.expr.rhs.target - sreg = kwargs['sregistry'] - return { - 'Jac': Mat(sreg.make_name(prefix='J_')), - 'x_global': GlobalVec(sreg.make_name(prefix='x_global_')), - 'x_local': LocalVec(sreg.make_name(prefix='x_local_'), liveness='eager'), - 'b_global': GlobalVec(sreg.make_name(prefix='b_global_')), - 'b_local': LocalVec(sreg.make_name(prefix='b_local_')), - 'ksp': KSP(sreg.make_name(prefix='ksp_')), - 'pc': PC(sreg.make_name(prefix='pc_')), - 'snes': SNES(sreg.make_name(prefix='snes_')), - 'X_global': GlobalVec(sreg.make_name(prefix='X_global_')), - 'Y_global': GlobalVec(sreg.make_name(prefix='Y_global_')), - 'X_local': LocalVec(sreg.make_name(prefix='X_local_'), liveness='eager'), - 'Y_local': LocalVec(sreg.make_name(prefix='Y_local_'), liveness='eager'), - 'dummy': DummyArg(sreg.make_name(prefix='dummy_')), - 'localsize': PetscInt(sreg.make_name(prefix='localsize_')), - 'start_ptr': StartPtr(sreg.make_name(prefix='start_ptr_'), target.dtype), - 'true_dims': retrieve_time_dims(iters), - 'target': target, - 'time_mapper': injectsolve.expr.rhs.time_mapper, - } - - -def generate_solver_setup(solver_objs, objs, injectsolve): - target = solver_objs['target'] - - dmda = objs['da_so_%s' % target.space_order] - - solver_params = injectsolve.expr.rhs.solver_parameters - - snes_create = petsc_call('SNESCreate', [objs['comm'], Byref(solver_objs['snes'])]) - - snes_set_dm = petsc_call('SNESSetDM', [solver_objs['snes'], dmda]) - - create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(solver_objs['Jac'])]) - - # NOTE: Assumming all solves are linear for now. - snes_set_type = petsc_call('SNESSetType', [solver_objs['snes'], 'SNESKSPONLY']) - - snes_set_jac = petsc_call( - 'SNESSetJacobian', [solver_objs['snes'], solver_objs['Jac'], - solver_objs['Jac'], 'MatMFFDComputeJacobian', Null] - ) - - global_x = petsc_call('DMCreateGlobalVector', - [dmda, Byref(solver_objs['x_global'])]) - - global_b = petsc_call('DMCreateGlobalVector', - [dmda, Byref(solver_objs['b_global'])]) - - local_b = petsc_call('DMCreateLocalVector', - [dmda, Byref(solver_objs['b_local'])]) +class Builder: + """ + This class is designed to support future extensions, enabling + different combinations of solver types, preconditioning methods, + and other functionalities as needed. - snes_get_ksp = petsc_call('SNESGetKSP', - [solver_objs['snes'], Byref(solver_objs['ksp'])]) + The class will be extended to accommodate different solver types by + returning subclasses of the objects initialised in __init__, + depending on the properties of `injectsolve`. + """ + def __init__(self, injectsolve, objs, iters, **kwargs): - ksp_set_tols = petsc_call( - 'KSPSetTolerances', [solver_objs['ksp'], solver_params['ksp_rtol'], - solver_params['ksp_atol'], solver_params['ksp_divtol'], - solver_params['ksp_max_it']] - ) + # Determine the time dependency class + time_mapper = injectsolve.expr.rhs.time_mapper + timedep = TimeDependent if time_mapper else NonTimeDependent + self.timedep = timedep(injectsolve, iters, **kwargs) - ksp_set_type = petsc_call( - 'KSPSetType', [solver_objs['ksp'], solver_mapper[solver_params['ksp_type']]] - ) + # Objects + self.objbuilder = BaseObjectBuilder(injectsolve, **kwargs) + self.solver_objs = self.objbuilder.solver_objs - ksp_get_pc = petsc_call('KSPGetPC', [solver_objs['ksp'], Byref(solver_objs['pc'])]) - - # Even though the default will be jacobi, set to PCNONE for now - pc_set_type = petsc_call('PCSetType', [solver_objs['pc'], 'PCNONE']) - - ksp_set_from_ops = petsc_call('KSPSetFromOptions', [solver_objs['ksp']]) - - return ( - snes_create, - snes_set_dm, - create_matrix, - snes_set_jac, - snes_set_type, - global_x, - global_b, - local_b, - snes_get_ksp, - ksp_set_tols, - ksp_set_type, - ksp_get_pc, - pc_set_type, - ksp_set_from_ops - ) + # Callbacks + self.cbbuilder = CallbackBuilder( + injectsolve, objs, self.solver_objs, timedep=self.timedep, + **kwargs + ) + # Solver setup + self.solversetup = BaseSetup( + self.solver_objs, objs, injectsolve, self.cbbuilder + ) -def assign_time_iters(iet, struct): - """ - Assign time iterators to the struct within loops containing PETScCalls. - Ensure that assignment occurs only once per time loop, if necessary. - Assign only the iterators that are common between the struct fields - and the actual Iteration. - """ - time_iters = [ - i for i in FindNodes(Iteration).visit(iet) - if i.dim.is_Time and FindNodes(PETScCall).visit(i) - ] - - if not time_iters: - return iet - - mapper = {} - for iter in time_iters: - common_dims = [d for d in iter.dimensions if d in struct.fields] - common_dims = [ - DummyExpr(FieldFromComposite(d, struct), d) for d in common_dims - ] - iter_new = iter._rebuild(nodes=List(body=tuple(common_dims)+iter.nodes)) - mapper.update({iter: iter_new}) - - return Transformer(mapper).visit(iet) - - -def retrieve_time_dims(iters): - time_iter = [i for i in iters if any(d.is_Time for d in i.dimensions)] - mapper = {} - if not time_iter: - return mapper - for d in time_iter[0].dimensions: - if d.is_Modulo: - mapper[d.origin] = d - elif d.is_Time: - mapper[d] = d - return mapper - - -def spatial_injectsolve_iter(iter, injectsolve): - spatial_body = [] - for tree in retrieve_iteration_tree(iter[0]): - root = filter_iterations(tree, key=lambda i: i.dim.is_Space)[0] - if injectsolve in FindNodes(InjectSolveDummy).visit(root): - spatial_body.append(root) - return spatial_body + # Execute the solver + self.solve = Solver( + self.solver_objs, objs, injectsolve, iters, + self.cbbuilder, timedep=self.timedep + ) Null = Macro('NULL') void = 'void' + # TODO: Don't use c.Line here? petsc_func_begin_user = c.Line('PetscFunctionBeginUser;') diff --git a/devito/petsc/iet/routines.py b/devito/petsc/iet/routines.py index a516b1bc81..81547ac64b 100644 --- a/devito/petsc/iet/routines.py +++ b/devito/petsc/iet/routines.py @@ -3,31 +3,49 @@ import cgen as c from devito.ir.iet import (Call, FindSymbols, List, Uxreplace, CallableBody, - Dereference, DummyExpr, BlankLine, Callable) -from devito.symbolics import Byref, FieldFromPointer, Macro, cast_mapper + Dereference, DummyExpr, BlankLine, Callable, FindNodes, + retrieve_iteration_tree, filter_iterations) +from devito.symbolics import (Byref, FieldFromPointer, Macro, cast_mapper, + FieldFromComposite) from devito.symbolics.unevaluation import Mul from devito.types.basic import AbstractFunction -from devito.types import ModuloDimension, TimeDimension, Temp +from devito.types import Temp, Symbol from devito.tools import filter_ordered + from devito.petsc.types import PETScArray from devito.petsc.iet.nodes import (PETScCallable, FormFunctionCallback, - MatVecCallback) + MatVecCallback, InjectSolveDummy) from devito.petsc.iet.utils import petsc_call, petsc_struct -from devito.ir.support import SymbolRegistry +from devito.petsc.utils import solver_mapper +from devito.petsc.types import (DM, CallbackDM, Mat, LocalVec, GlobalVec, KSP, PC, + SNES, DummyArg, PetscInt, StartPtr) -class PETScCallbackBuilder: +class CallbackBuilder: """ Build IET routines to generate PETSc callback functions. """ - def __new__(cls, rcompile=None, sregistry=None, **kwargs): - obj = object.__new__(cls) - obj.rcompile = rcompile - obj.sregistry = sregistry - obj._efuncs = OrderedDict() - obj._struct_params = [] + def __init__(self, injectsolve, objs, solver_objs, + rcompile=None, sregistry=None, timedep=None, **kwargs): + + self.rcompile = rcompile + self.sregistry = sregistry + self.timedep = timedep + self.solver_objs = solver_objs + + self._efuncs = OrderedDict() + self._struct_params = [] + + self._matvec_callback = None + self._formfunc_callback = None + self._formrhs_callback = None + self._struct_callback = None - return obj + self._make_core(injectsolve, objs, solver_objs) + self._main_struct(solver_objs) + self._make_struct_callback(solver_objs, objs) + self._local_struct(solver_objs) + self._efuncs = self._uxreplace_efuncs() @property def efuncs(self): @@ -37,42 +55,38 @@ def efuncs(self): def struct_params(self): return self._struct_params - def make(self, injectsolve, objs, solver_objs): - matvec_callback, formfunc_callback, formrhs_callback = self.make_all( - injectsolve, objs, solver_objs - ) + @property + def filtered_struct_params(self): + return filter_ordered(self.struct_params) - matvec_operation = petsc_call( - 'MatShellSetOperation', [solver_objs['Jac'], 'MATOP_MULT', - MatVecCallback(matvec_callback.name, void, void)] - ) - formfunc_operation = petsc_call( - 'SNESSetFunction', - [solver_objs['snes'], Null, - FormFunctionCallback(formfunc_callback.name, void, void), Null] - ) - runsolve = self.runsolve(solver_objs, objs, formrhs_callback, injectsolve) + @property + def matvec_callback(self): + return self._matvec_callback - return matvec_operation, formfunc_operation, runsolve + @property + def formfunc_callback(self): + return self._formfunc_callback - def make_all(self, injectsolve, objs, solver_objs): - matvec_callback = self.make_matvec(injectsolve, objs, solver_objs) - formfunc_callback = self.make_formfunc(injectsolve, objs, solver_objs) - formrhs_callback = self.make_formrhs(injectsolve, objs, solver_objs) + @property + def formrhs_callback(self): + return self._formrhs_callback - self._efuncs[matvec_callback.name] = matvec_callback - self._efuncs[formfunc_callback.name] = formfunc_callback - self._efuncs[formrhs_callback.name] = formrhs_callback + @property + def struct_callback(self): + return self._struct_callback - return matvec_callback, formfunc_callback, formrhs_callback + def _make_core(self, injectsolve, objs, solver_objs): + self._make_matvec(injectsolve, objs, solver_objs) + self._make_formfunc(injectsolve, objs, solver_objs) + self._make_formrhs(injectsolve, objs, solver_objs) - def make_matvec(self, injectsolve, objs, solver_objs): + def _make_matvec(self, injectsolve, objs, solver_objs): # Compile matvec `eqns` into an IET via recursive compilation irs_matvec, _ = self.rcompile(injectsolve.expr.rhs.matvecs, - options={'mpi': False}, sregistry=SymbolRegistry()) - body_matvec = self.create_matvec_body(injectsolve, - List(body=irs_matvec.uiet.body), - solver_objs, objs) + options={'mpi': False}, sregistry=self.sregistry) + body_matvec = self._create_matvec_body(injectsolve, + List(body=irs_matvec.uiet.body), + solver_objs, objs) matvec_callback = PETScCallable( self.sregistry.make_name(prefix='MyMatShellMult_'), body_matvec, @@ -81,24 +95,25 @@ def make_matvec(self, injectsolve, objs, solver_objs): solver_objs['Jac'], solver_objs['X_global'], solver_objs['Y_global'] ) ) - return matvec_callback + self._matvec_callback = matvec_callback + self._efuncs[matvec_callback.name] = matvec_callback - def create_matvec_body(self, injectsolve, body, solver_objs, objs): - linsolveexpr = injectsolve.expr.rhs + def _create_matvec_body(self, injectsolve, body, solver_objs, objs): + linsolve_expr = injectsolve.expr.rhs - dmda = objs['da_so_%s' % linsolveexpr.target.space_order] + dmda = solver_objs['callbackdm'] - body = uxreplace_time(body, solver_objs) + body = self.timedep.uxreplace_time(body) - struct = build_local_struct(body, 'matvec', liveness='eager') + fields = self._dummy_fields(body, solver_objs) - y_matvec = linsolveexpr.arrays['y_matvec'] - x_matvec = linsolveexpr.arrays['x_matvec'] + y_matvec = linsolve_expr.arrays['y_matvec'] + x_matvec = linsolve_expr.arrays['x_matvec'] mat_get_dm = petsc_call('MatGetDM', [solver_objs['Jac'], Byref(dmda)]) dm_get_app_context = petsc_call( - 'DMGetApplicationContext', [dmda, Byref(struct._C_symbol)] + 'DMGetApplicationContext', [dmda, Byref(dummyctx._C_symbol)] ) dm_get_local_xvec = petsc_call( @@ -127,7 +142,7 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs): ) dm_get_local_info = petsc_call( - 'DMDAGetLocalInfo', [dmda, Byref(dmda.info)] + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] ) vec_restore_array_y = petsc_call( @@ -146,6 +161,14 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs): dmda, solver_objs['Y_local'], 'INSERT_VALUES', solver_objs['Y_global'] ]) + dm_restore_local_xvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(solver_objs['X_local'])] + ) + + dm_restore_local_yvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(solver_objs['Y_local'])] + ) + # TODO: Some of the calls are placed in the `stacks` argument of the # `CallableBody` to ensure that they precede the `cast` statements. The # 'casts' depend on the calls, so this order is necessary. By doing this, @@ -158,7 +181,9 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs): (vec_restore_array_y, vec_restore_array_x, dm_local_to_global_begin, - dm_local_to_global_end) + dm_local_to_global_end, + dm_restore_local_xvec, + dm_restore_local_yvec) ) stacks = ( @@ -174,8 +199,8 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs): ) # Dereference function data in struct - dereference_funcs = [Dereference(i, struct) for i in - struct.fields if isinstance(i.function, AbstractFunction)] + dereference_funcs = [Dereference(i, dummyctx) for i in + fields if isinstance(i.function, AbstractFunction)] matvec_body = CallableBody( List(body=body), @@ -185,47 +210,48 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs): ) # Replace non-function data with pointer to data in struct - subs = {i._C_symbol: FieldFromPointer(i._C_symbol, struct) for i in struct.fields} + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, dummyctx) for i in fields} matvec_body = Uxreplace(subs).visit(matvec_body) - self._struct_params.extend(struct.fields) + self._struct_params.extend(fields) return matvec_body - def make_formfunc(self, injectsolve, objs, solver_objs): + def _make_formfunc(self, injectsolve, objs, solver_objs): # Compile formfunc `eqns` into an IET via recursive compilation irs_formfunc, _ = self.rcompile( injectsolve.expr.rhs.formfuncs, - options={'mpi': False}, sregistry=SymbolRegistry() + options={'mpi': False}, sregistry=self.sregistry ) - body_formfunc = self.create_formfunc_body(injectsolve, - List(body=irs_formfunc.uiet.body), - solver_objs, objs) + body_formfunc = self._create_formfunc_body(injectsolve, + List(body=irs_formfunc.uiet.body), + solver_objs, objs) formfunc_callback = PETScCallable( self.sregistry.make_name(prefix='FormFunction_'), body_formfunc, retval=objs['err'], parameters=(solver_objs['snes'], solver_objs['X_global'], - solver_objs['Y_global'], solver_objs['dummy']) + solver_objs['F_global'], dummyptr) ) - return formfunc_callback + self._formfunc_callback = formfunc_callback + self._efuncs[formfunc_callback.name] = formfunc_callback - def create_formfunc_body(self, injectsolve, body, solver_objs, objs): - linsolveexpr = injectsolve.expr.rhs + def _create_formfunc_body(self, injectsolve, body, solver_objs, objs): + linsolve_expr = injectsolve.expr.rhs - dmda = objs['da_so_%s' % linsolveexpr.target.space_order] + dmda = solver_objs['callbackdm'] - body = uxreplace_time(body, solver_objs) + body = self.timedep.uxreplace_time(body) - struct = build_local_struct(body, 'formfunc', liveness='eager') + fields = self._dummy_fields(body, solver_objs) - y_formfunc = linsolveexpr.arrays['y_formfunc'] - x_formfunc = linsolveexpr.arrays['x_formfunc'] + f_formfunc = linsolve_expr.arrays['f_formfunc'] + x_formfunc = linsolve_expr.arrays['x_formfunc'] snes_get_dm = petsc_call('SNESGetDM', [solver_objs['snes'], Byref(dmda)]) dm_get_app_context = petsc_call( - 'DMGetApplicationContext', [dmda, Byref(struct._C_symbol)] + 'DMGetApplicationContext', [dmda, Byref(dummyctx._C_symbol)] ) dm_get_local_xvec = petsc_call( @@ -242,11 +268,11 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs): ]) dm_get_local_yvec = petsc_call( - 'DMGetLocalVector', [dmda, Byref(solver_objs['Y_local'])] + 'DMGetLocalVector', [dmda, Byref(solver_objs['F_local'])] ) vec_get_array_y = petsc_call( - 'VecGetArray', [solver_objs['Y_local'], Byref(y_formfunc._C_symbol)] + 'VecGetArray', [solver_objs['F_local'], Byref(f_formfunc._C_symbol)] ) vec_get_array_x = petsc_call( @@ -254,11 +280,11 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs): ) dm_get_local_info = petsc_call( - 'DMDAGetLocalInfo', [dmda, Byref(dmda.info)] + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] ) vec_restore_array_y = petsc_call( - 'VecRestoreArray', [solver_objs['Y_local'], Byref(y_formfunc._C_symbol)] + 'VecRestoreArray', [solver_objs['F_local'], Byref(f_formfunc._C_symbol)] ) vec_restore_array_x = petsc_call( @@ -266,19 +292,29 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs): ) dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ - dmda, solver_objs['Y_local'], 'INSERT_VALUES', solver_objs['Y_global'] + dmda, solver_objs['F_local'], 'INSERT_VALUES', solver_objs['F_global'] ]) dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ - dmda, solver_objs['Y_local'], 'INSERT_VALUES', solver_objs['Y_global'] + dmda, solver_objs['F_local'], 'INSERT_VALUES', solver_objs['F_global'] ]) + dm_restore_local_xvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(solver_objs['X_local'])] + ) + + dm_restore_local_yvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(solver_objs['F_local'])] + ) + body = body._rebuild( body=body.body + (vec_restore_array_y, vec_restore_array_x, dm_local_to_global_begin, - dm_local_to_global_end) + dm_local_to_global_end, + dm_restore_local_xvec, + dm_restore_local_yvec) ) stacks = ( @@ -294,8 +330,8 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs): ) # Dereference function data in struct - dereference_funcs = [Dereference(i, struct) for i in - struct.fields if isinstance(i.function, AbstractFunction)] + dereference_funcs = [Dereference(i, dummyctx) for i in + fields if isinstance(i.function, AbstractFunction)] formfunc_body = CallableBody( List(body=body), @@ -304,20 +340,20 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs): retstmt=(Call('PetscFunctionReturn', arguments=[0]),)) # Replace non-function data with pointer to data in struct - subs = {i._C_symbol: FieldFromPointer(i._C_symbol, struct) for i in struct.fields} + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, dummyctx) for i in fields} formfunc_body = Uxreplace(subs).visit(formfunc_body) - self._struct_params.extend(struct.fields) + self._struct_params.extend(fields) return formfunc_body - def make_formrhs(self, injectsolve, objs, solver_objs): + def _make_formrhs(self, injectsolve, objs, solver_objs): # Compile formrhs `eqns` into an IET via recursive compilation irs_formrhs, _ = self.rcompile(injectsolve.expr.rhs.formrhs, - options={'mpi': False}, sregistry=SymbolRegistry()) - body_formrhs = self.create_formrhs_body(injectsolve, - List(body=irs_formrhs.uiet.body), - solver_objs, objs) + options={'mpi': False}, sregistry=self.sregistry) + body_formrhs = self._create_form_rhs_body(injectsolve, + List(body=irs_formrhs.uiet.body), + solver_objs, objs) formrhs_callback = PETScCallable( self.sregistry.make_name(prefix='FormRHS_'), body_formrhs, retval=objs['err'], @@ -325,32 +361,32 @@ def make_formrhs(self, injectsolve, objs, solver_objs): solver_objs['snes'], solver_objs['b_local'] ) ) + self._formrhs_callback = formrhs_callback + self._efuncs[formrhs_callback.name] = formrhs_callback - return formrhs_callback - - def create_formrhs_body(self, injectsolve, body, solver_objs, objs): - linsolveexpr = injectsolve.expr.rhs + def _create_form_rhs_body(self, injectsolve, body, solver_objs, objs): + linsolve_expr = injectsolve.expr.rhs - dmda = objs['da_so_%s' % linsolveexpr.target.space_order] + dmda = solver_objs['callbackdm'] snes_get_dm = petsc_call('SNESGetDM', [solver_objs['snes'], Byref(dmda)]) - b_arr = linsolveexpr.arrays['b_tmp'] + b_arr = linsolve_expr.arrays['b_tmp'] vec_get_array = petsc_call( 'VecGetArray', [solver_objs['b_local'], Byref(b_arr._C_symbol)] ) dm_get_local_info = petsc_call( - 'DMDAGetLocalInfo', [dmda, Byref(dmda.info)] + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] ) - body = uxreplace_time(body, solver_objs) + body = self.timedep.uxreplace_time(body) - struct = build_local_struct(body, 'formrhs', liveness='eager') + fields = self._dummy_fields(body, solver_objs) dm_get_app_context = petsc_call( - 'DMGetApplicationContext', [dmda, Byref(struct._C_symbol)] + 'DMGetApplicationContext', [dmda, Byref(dummyctx._C_symbol)] ) vec_restore_array = petsc_call( @@ -367,8 +403,8 @@ def create_formrhs_body(self, injectsolve, body, solver_objs, objs): ) # Dereference function data in struct - dereference_funcs = [Dereference(i, struct) for i in - struct.fields if isinstance(i.function, AbstractFunction)] + dereference_funcs = [Dereference(i, dummyctx) for i in + fields if isinstance(i.function, AbstractFunction)] formrhs_body = CallableBody( List(body=[body]), @@ -378,34 +414,325 @@ def create_formrhs_body(self, injectsolve, body, solver_objs, objs): ) # Replace non-function data with pointer to data in struct - subs = {i._C_symbol: FieldFromPointer(i._C_symbol, struct) for - i in struct.fields if not isinstance(i.function, AbstractFunction)} - + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, dummyctx) for + i in fields if not isinstance(i.function, AbstractFunction)} formrhs_body = Uxreplace(subs).visit(formrhs_body) - self._struct_params.extend(struct.fields) + self._struct_params.extend(fields) return formrhs_body - def runsolve(self, solver_objs, objs, rhs_callback, injectsolve): + def _local_struct(self, solver_objs): + """ + This is the struct used within callback functions, + usually accessed via DMGetApplicationContext. + """ + solver_objs['localctx'] = petsc_struct( + dummyctx.name, + self.filtered_struct_params, + solver_objs['Jac'].name+'_ctx', + liveness='eager' + ) + + def _main_struct(self, solver_objs): + """ + This is the struct initialised inside the main kernel and + attached to the DM via DMSetApplicationContext. + """ + solver_objs['mainctx'] = petsc_struct( + self.sregistry.make_name(prefix='ctx'), + self.filtered_struct_params, + solver_objs['Jac'].name+'_ctx' + ) + + def _make_struct_callback(self, solver_objs, objs): + mainctx = solver_objs['mainctx'] + body = [ + DummyExpr(FieldFromPointer(i._C_symbol, mainctx), i._C_symbol) + for i in mainctx.callback_fields + ] + struct_callback_body = CallableBody( + List(body=body), init=(petsc_func_begin_user,), + retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])]) + ) + struct_callback = Callable( + self.sregistry.make_name(prefix='PopulateMatContext_'), + struct_callback_body, objs['err'], + parameters=[mainctx] + ) + self._efuncs[struct_callback.name] = struct_callback + self._struct_callback = struct_callback + + def _dummy_fields(self, iet, solver_objs): + # Place all context data required by the shell routines into a struct + fields = [f.function for f in FindSymbols('basics').visit(iet)] + fields = [f for f in fields if not isinstance(f.function, (PETScArray, Temp))] + fields = [ + f for f in fields if not (f.is_Dimension and not (f.is_Time or f.is_Modulo)) + ] + return fields + + def _uxreplace_efuncs(self): + mapper = {} + visitor = Uxreplace({dummyctx: self.solver_objs['localctx']}) + for k, v in self._efuncs.items(): + mapper.update({k: visitor.visit(v)}) + return mapper + + +class BaseObjectBuilder: + """ + A base class for constructing objects needed for a PETSc solver. + Designed to be extended by subclasses, which can override the `_extend_build` + method to support specific use cases. + """ + + def __init__(self, injectsolve, sregistry=None, **kwargs): + self.sregistry = sregistry + self.solver_objs = self._build(injectsolve) + + def _build(self, injectsolve): + """ + Constructs the core dictionary of solver objects and allows + subclasses to extend or modify it via `_extend_build`. + + Returns: + dict: A dictionary containing the following objects: + - 'Jac' (Mat): A matrix representing the jacobian. + - 'x_global' (GlobalVec): The global solution vector. + - 'x_local' (LocalVec): The local solution vector. + - 'b_global': (GlobalVec) Global RHS vector `b`, where `F(x) = b`. + - 'b_local': (LocalVec) Local RHS vector `b`, where `F(x) = b`. + - 'ksp': (KSP) Krylov solver object that manages the linear solver. + - 'pc': (PC) Preconditioner object. + - 'snes': (SNES) Nonlinear solver object. + - 'F_global': (GlobalVec) Global residual vector `F`, where `F(x) = b`. + - 'F_local': (LocalVec) Local residual vector `F`, where `F(x) = b`. + - 'Y_global': (GlobalVector) The output vector populated by the + matrix-free `MyMatShellMult` callback function. + - 'Y_local': (LocalVector) The output vector populated by the matrix-free + `MyMatShellMult` callback function. + - 'X_global': (GlobalVec) Current guess for the solution, + required by the FormFunction callback. + - 'X_local': (LocalVec) Current guess for the solution, + required by the FormFunction callback. + - 'localsize' (PetscInt): The local length of the solution vector. + - 'start_ptr' (StartPtr): A pointer to the beginning of the solution array + that will be updated at each time step. + - 'dmda' (DM): The DMDA object associated with this solve, linked to + the SNES object via `SNESSetDM`. + - 'callbackdm' (CallbackDM): The DM object accessed within callback + functions via `SNESGetDM`. + """ target = injectsolve.expr.rhs.target + sreg = self.sregistry + base_dict = { + 'Jac': Mat(sreg.make_name(prefix='J_')), + 'x_global': GlobalVec(sreg.make_name(prefix='x_global_')), + 'x_local': LocalVec(sreg.make_name(prefix='x_local_'), liveness='eager'), + 'b_global': GlobalVec(sreg.make_name(prefix='b_global_')), + 'b_local': LocalVec(sreg.make_name(prefix='b_local_')), + 'ksp': KSP(sreg.make_name(prefix='ksp_')), + 'pc': PC(sreg.make_name(prefix='pc_')), + 'snes': SNES(sreg.make_name(prefix='snes_')), + 'F_global': GlobalVec(sreg.make_name(prefix='F_global_')), + 'F_local': LocalVec(sreg.make_name(prefix='F_local_'), liveness='eager'), + 'Y_global': GlobalVec(sreg.make_name(prefix='Y_global_')), + 'Y_local': LocalVec(sreg.make_name(prefix='Y_local_'), liveness='eager'), + 'X_global': GlobalVec(sreg.make_name(prefix='X_global_')), + 'X_local': LocalVec(sreg.make_name(prefix='X_local_'), liveness='eager'), + 'localsize': PetscInt(sreg.make_name(prefix='localsize_')), + 'start_ptr': StartPtr(sreg.make_name(prefix='start_ptr_'), target.dtype), + 'dmda': DM(sreg.make_name(prefix='da_'), liveness='eager', + stencil_width=target.space_order), + 'callbackdm': CallbackDM(sreg.make_name(prefix='dm_'), + liveness='eager', stencil_width=target.space_order), + } + return self._extend_build(base_dict, injectsolve) + + def _extend_build(self, base_dict, injectsolve): + """ + Subclasses can override this method to extend or modify the + base dictionary of solver objects. + """ + return base_dict + - dmda = objs['da_so_%s' % target.space_order] +class BaseSetup: + def __init__(self, solver_objs, objs, injectsolve, cbbuilder): + self.calls = self._setup(solver_objs, objs, injectsolve, cbbuilder) + + def _setup(self, solver_objs, objs, injectsolve, cbbuilder): + dmda = solver_objs['dmda'] + + solver_params = injectsolve.expr.rhs.solver_parameters + + snes_create = petsc_call('SNESCreate', [objs['comm'], Byref(solver_objs['snes'])]) + + snes_set_dm = petsc_call('SNESSetDM', [solver_objs['snes'], dmda]) + + create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(solver_objs['Jac'])]) + + # NOTE: Assuming all solves are linear for now. + snes_set_type = petsc_call('SNESSetType', [solver_objs['snes'], 'SNESKSPONLY']) + + snes_set_jac = petsc_call( + 'SNESSetJacobian', [solver_objs['snes'], solver_objs['Jac'], + solver_objs['Jac'], 'MatMFFDComputeJacobian', Null] + ) + + global_x = petsc_call('DMCreateGlobalVector', + [dmda, Byref(solver_objs['x_global'])]) + + global_b = petsc_call('DMCreateGlobalVector', + [dmda, Byref(solver_objs['b_global'])]) + + local_b = petsc_call('DMCreateLocalVector', + [dmda, Byref(solver_objs['b_local'])]) + + snes_get_ksp = petsc_call('SNESGetKSP', + [solver_objs['snes'], Byref(solver_objs['ksp'])]) + + ksp_set_tols = petsc_call( + 'KSPSetTolerances', [solver_objs['ksp'], solver_params['ksp_rtol'], + solver_params['ksp_atol'], solver_params['ksp_divtol'], + solver_params['ksp_max_it']] + ) + + ksp_set_type = petsc_call( + 'KSPSetType', [solver_objs['ksp'], solver_mapper[solver_params['ksp_type']]] + ) + + ksp_get_pc = petsc_call( + 'KSPGetPC', [solver_objs['ksp'], Byref(solver_objs['pc'])] + ) + + # Even though the default will be jacobi, set to PCNONE for now + pc_set_type = petsc_call('PCSetType', [solver_objs['pc'], 'PCNONE']) + + ksp_set_from_ops = petsc_call('KSPSetFromOptions', [solver_objs['ksp']]) + + matvec_operation = petsc_call( + 'MatShellSetOperation', + [solver_objs['Jac'], 'MATOP_MULT', + MatVecCallback(cbbuilder.matvec_callback.name, void, void)] + ) + + formfunc_operation = petsc_call( + 'SNESSetFunction', + [solver_objs['snes'], Null, + FormFunctionCallback(cbbuilder.formfunc_callback.name, void, void), Null] + ) + + dmda_calls = self._create_dmda_calls(dmda, objs) + + mainctx = solver_objs['mainctx'] + + call_struct_callback = petsc_call( + cbbuilder.struct_callback.name, [Byref(mainctx)] + ) + calls_set_app_ctx = [ + petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)]) + ] + calls = [call_struct_callback] + calls_set_app_ctx + [BlankLine] + + base_setup = dmda_calls + ( + snes_create, + snes_set_dm, + create_matrix, + snes_set_jac, + snes_set_type, + global_x, + global_b, + local_b, + snes_get_ksp, + ksp_set_tols, + ksp_set_type, + ksp_get_pc, + pc_set_type, + ksp_set_from_ops, + matvec_operation, + formfunc_operation, + ) + tuple(calls) + + extended_setup = self._extend_setup(solver_objs, objs, injectsolve, cbbuilder) + return base_setup + tuple(extended_setup) + + def _extend_setup(self, solver_objs, objs, injectsolve, cbbuilder): + """ + Hook for subclasses to add additional setup calls. + """ + return [] + + def _create_dmda_calls(self, dmda, objs): + dmda_create = self._create_dmda(dmda, objs) + dm_setup = petsc_call('DMSetUp', [dmda]) + dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL']) + return dmda_create, dm_setup, dm_mat_type + + def _create_dmda(self, dmda, objs): + grid = objs['grid'] + + nspace_dims = len(grid.dimensions) + + # MPI communicator + args = [objs['comm']] + + # Type of ghost nodes + args.extend(['DM_BOUNDARY_GHOSTED' for _ in range(nspace_dims)]) + + # Stencil type + if nspace_dims > 1: + args.append('DMDA_STENCIL_BOX') + + # Global dimensions + args.extend(list(grid.shape)[::-1]) + # No.of processors in each dimension + if nspace_dims > 1: + args.extend(list(grid.distributor.topology)[::-1]) + + # Number of degrees of freedom per node + args.append(1) + # "Stencil width" -> size of overlap + args.append(dmda.stencil_width) + args.extend([Null]*nspace_dims) + + # The distributed array object + args.append(Byref(dmda)) + + # The PETSc call used to create the DMDA + dmda = petsc_call('DMDACreate%sd' % nspace_dims, args) + + return dmda + + +class Solver: + def __init__(self, solver_objs, objs, injectsolve, iters, cbbuilder, + timedep=None, **kwargs): + self.timedep = timedep + self.calls = self._execute_solve(solver_objs, objs, injectsolve, iters, cbbuilder) + self.spatial_body = self._spatial_loop_nest(iters, injectsolve) + + space_iter, = self.spatial_body + self.mapper = {space_iter: self.calls} + + def _execute_solve(self, solver_objs, objs, injectsolve, iters, cbbuilder): + """ + Assigns the required time iterators to the struct and executes + the necessary calls to execute the SNES solver. + """ + struct_assignment = self.timedep.assign_time_iters(solver_objs['mainctx']) + + rhs_callback = cbbuilder.formrhs_callback + + dmda = solver_objs['dmda'] rhs_call = petsc_call(rhs_callback.name, list(rhs_callback.parameters)) local_x = petsc_call('DMCreateLocalVector', [dmda, Byref(solver_objs['x_local'])]) - if any(i.is_Time for i in target.dimensions): - vec_replace_array = time_dep_replace( - injectsolve, solver_objs, objs, self.sregistry - ) - else: - field_from_ptr = FieldFromPointer(target._C_field_data, target._C_symbol) - vec_replace_array = (petsc_call( - 'VecReplaceArray', [solver_objs['x_local'], field_from_ptr] - ),) + vec_replace_array = self.timedep.replace_array(solver_objs) dm_local_to_global_x = petsc_call( 'DMLocalToGlobal', [dmda, solver_objs['x_local'], 'INSERT_VALUES', @@ -425,7 +752,7 @@ def runsolve(self, solver_objs, objs, rhs_callback, injectsolve): dmda, solver_objs['x_global'], 'INSERT_VALUES', solver_objs['x_local']] ) - return ( + run_solver_calls = (struct_assignment,) + ( rhs_call, local_x ) + vec_replace_array + ( @@ -435,90 +762,241 @@ def runsolve(self, solver_objs, objs, rhs_callback, injectsolve): dm_global_to_local_x, BlankLine, ) + return List(body=run_solver_calls) - def make_main_struct(self, unique_dmdas, objs): - struct_main = petsc_struct('ctx', filter_ordered(self.struct_params)) - struct_callback = self.generate_struct_callback(struct_main, objs) - call_struct_callback = petsc_call(struct_callback.name, [Byref(struct_main)]) - calls_set_app_ctx = [ - petsc_call('DMSetApplicationContext', [i, Byref(struct_main)]) - for i in unique_dmdas - ] - calls = [call_struct_callback] + calls_set_app_ctx + def _spatial_loop_nest(self, iters, injectsolve): + spatial_body = [] + for tree in retrieve_iteration_tree(iters[0]): + root = filter_iterations(tree, key=lambda i: i.dim.is_Space)[0] + if injectsolve in FindNodes(InjectSolveDummy).visit(root): + spatial_body.append(root) + return spatial_body - self._efuncs[struct_callback.name] = struct_callback - return struct_main, calls - def generate_struct_callback(self, struct, objs): - body = [ - DummyExpr(FieldFromPointer(i._C_symbol, struct), i._C_symbol) - for i in struct.fields if i not in struct.time_dim_fields - ] - struct_callback_body = CallableBody( - List(body=body), init=tuple([petsc_func_begin_user]), - retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])]) - ) - struct_callback = Callable( - 'PopulateMatContext', struct_callback_body, objs['err'], - parameters=[struct] - ) - return struct_callback +class NonTimeDependent: + def __init__(self, injectsolve, iters, **kwargs): + self.injectsolve = injectsolve + self.iters = iters + self.kwargs = kwargs + self.origin_to_moddim = self._origin_to_moddim_mapper(iters) + self.time_idx_to_symb = injectsolve.expr.rhs.time_mapper + + @property + def is_target_time(self): + return False + + @property + def target(self): + return self.injectsolve.expr.rhs.target + + def _origin_to_moddim_mapper(self, iters): + return {} + + def uxreplace_time(self, body): + return body + def replace_array(self, solver_objs): + """ + VecReplaceArray() is a PETSc function that allows replacing the array + of a `Vec` with a user provided array. + https://petsc.org/release/manualpages/Vec/VecReplaceArray/ -def build_local_struct(iet, name, liveness): - # Place all context data required by the shell routines into a struct - fields = [ - i.function for i in FindSymbols('basics').visit(iet) - if not isinstance(i.function, (PETScArray, Temp)) - and not (i.is_Dimension and not isinstance(i, (TimeDimension, ModuloDimension))) - ] - return petsc_struct(name, fields, liveness) + This function is used to replace the array of the PETSc solution `Vec` + with the array from the `Function` object representing the target. + Examples + -------- + >>> self.target + f1(x, y) + >>> call = replace_array(solver_objs) + >>> print(call) + PetscCall(VecReplaceArray(x_local_0,f1_vec->data)); + """ + field_from_ptr = FieldFromPointer( + self.target.function._C_field_data, self.target.function._C_symbol + ) + vec_replace_array = (petsc_call( + 'VecReplaceArray', [solver_objs['x_local'], field_from_ptr] + ),) + return vec_replace_array -def time_dep_replace(injectsolve, solver_objs, objs, sregistry): - target = injectsolve.expr.lhs - target_time = [ - i for i, d in zip(target.indices, target.dimensions) if d.is_Time - ] - assert len(target_time) == 1 - target_time = target_time.pop() + def assign_time_iters(self, struct): + return [] - start_ptr = solver_objs['start_ptr'] - vec_get_size = petsc_call( - 'VecGetSize', [solver_objs['x_local'], Byref(solver_objs['localsize'])] - ) +class TimeDependent(NonTimeDependent): + """ + A class for managing time-dependent solvers. + + This includes scenarios where the target is not directly a `TimeFunction`, + but depends on other functions that are. + + Outline of time loop abstraction with PETSc: + + - At PETScSolve, time indices are replaced with temporary `Symbol` objects + via a mapper (e.g., {t: tau0, t + dt: tau1}) to prevent the time loop + from being generated in the callback functions. These callbacks, needed + for each `SNESSolve` at every time step, don't require the time loop, but + may still need access to data from other time steps. + - All `Function` objects are passed through the initial lowering via the + `LinearSolveExpr` object, ensuring the correct time loop is generated + in the main kernel. + - Another mapper is created based on the modulo dimensions + generated by the `LinearSolveExpr` object in the main kernel + (e.g., {time: time, t: t0, t + 1: t1}). + - These two mappers are used to generate a final mapper `symb_to_moddim` + (e.g. {tau0: t0, tau1: t1}) which is used at the IET level to + replace the temporary `Symbol` objects in the callback functions with + the correct modulo dimensions. + - Modulo dimensions are updated in the matrix context struct at each time + step and can be accessed in the callback functions where needed. + """ + @property + def is_target_time(self): + return any(i.is_Time for i in self.target.dimensions) - field_from_ptr = FieldFromPointer( - target.function._C_field_data, target.function._C_symbol - ) + @property + def time_spacing(self): + return self.target.grid.stepping_dim.spacing - expr = DummyExpr( - start_ptr, cast_mapper[(target.dtype, '*')](field_from_ptr) + - Mul(target_time, solver_objs['localsize']), init=True - ) + @property + def target_time(self): + target_time = [ + i for i, d in zip(self.target.indices, self.target.dimensions) + if d.is_Time + ] + assert len(target_time) == 1 + target_time = target_time.pop() + return target_time - vec_replace_array = petsc_call('VecReplaceArray', [solver_objs['x_local'], start_ptr]) - return (vec_get_size, expr, vec_replace_array) + @property + def symb_to_moddim(self): + """ + Maps temporary `Symbol` objects created during `PETScSolve` to their + corresponding modulo dimensions (e.g. creates {tau0: t0, tau1: t1}). + """ + mapper = { + v: k.xreplace({self.time_spacing: 1, -self.time_spacing: -1}) + for k, v in self.time_idx_to_symb.items() + } + return {symb: self.origin_to_moddim[mapper[symb]] for symb in mapper} + + def uxreplace_time(self, body): + return Uxreplace(self.symb_to_moddim).visit(body) + + def _origin_to_moddim_mapper(self, iters): + """ + Creates a mapper of the origin of the time dimensions to their corresponding + modulo dimensions from a list of `Iteration` objects. + + Examples + -------- + >>> iters + (, + ) + >>> _origin_to_moddim_mapper(iters) + {time: time, t: t0, t + 1: t1} + """ + time_iter = [i for i in iters if any(d.is_Time for d in i.dimensions)] + mapper = {} + + if not time_iter: + return mapper + + for i in time_iter: + for d in i.dimensions: + if d.is_Modulo: + mapper[d.origin] = d + elif d.is_Time: + mapper[d] = d + return mapper + + def replace_array(self, solver_objs): + """ + In the case that the actual target is time-dependent e.g a `TimeFunction`, + a pointer to the first element in the array that will be updated during + the time step is passed to VecReplaceArray(). + + Examples + -------- + >>> self.target + f1(time + dt, x, y) + >>> calls = replace_array(solver_objs) + >>> print(List(body=calls)) + PetscCall(VecGetSize(x_local_0,&(localsize_0))); + float * start_ptr_0 = (time + 1)*localsize_0 + (float*)(f1_vec->data); + PetscCall(VecReplaceArray(x_local_0,start_ptr_0)); + + >>> self.target + f1(t + dt, x, y) + >>> calls = replace_array(solver_objs) + >>> print(List(body=calls)) + PetscCall(VecGetSize(x_local_0,&(localsize_0))); + float * start_ptr_0 = t1*localsize_0 + (float*)(f1_vec->data); + """ + if self.is_target_time: + mapper = {self.time_spacing: 1, -self.time_spacing: -1} + target_time = self.target_time.xreplace(mapper) + + try: + target_time = self.origin_to_moddim[target_time] + except KeyError: + pass + + start_ptr = solver_objs['start_ptr'] + + vec_get_size = petsc_call( + 'VecGetSize', [solver_objs['x_local'], Byref(solver_objs['localsize'])] + ) + field_from_ptr = FieldFromPointer( + self.target.function._C_field_data, self.target.function._C_symbol + ) -def uxreplace_time(body, solver_objs): - # TODO: Potentially introduce a TimeIteration abstraction to simplify - # all the time processing that is done (searches, replacements, ...) - # "manually" via free functions - time_spacing = solver_objs['target'].grid.stepping_dim.spacing - true_dims = solver_objs['true_dims'] + expr = DummyExpr( + start_ptr, cast_mapper[(self.target.dtype, '*')](field_from_ptr) + + Mul(target_time, solver_objs['localsize']), init=True + ) - time_mapper = { - v: k.xreplace({time_spacing: 1, -time_spacing: -1}) - for k, v in solver_objs['time_mapper'].items() - } - subs = {symb: true_dims[time_mapper[symb]] for symb in time_mapper} - return Uxreplace(subs).visit(body) + vec_replace_array = petsc_call( + 'VecReplaceArray', [solver_objs['x_local'], start_ptr] + ) + return (vec_get_size, expr, vec_replace_array) + else: + return super().replace_array(solver_objs) + + def assign_time_iters(self, struct): + """ + Assign required time iterators to the struct. + These iterators are updated at each timestep in the main kernel + for use in callback functions. + + Examples + -------- + >>> struct + ctx + >>> struct.fields + [h_x, x_M, x_m, f1(t, x), t0, t1] + >>> assigned = assign_time_iters(struct) + >>> print(assigned[0]) + ctx.t0 = t0; + >>> print(assigned[1]) + ctx.t1 = t1; + """ + to_assign = [ + f for f in struct.fields if (f.is_Dimension and (f.is_Time or f.is_Modulo)) + ] + time_iter_assignments = [ + DummyExpr(FieldFromComposite(field, struct), field) + for field in to_assign + ] + return time_iter_assignments Null = Macro('NULL') void = 'void' +dummyctx = Symbol('lctx') +dummyptr = DummyArg('dummy') # TODO: Don't use c.Line here? diff --git a/devito/petsc/iet/utils.py b/devito/petsc/iet/utils.py index a7855fbb36..adcf709eab 100644 --- a/devito/petsc/iet/utils.py +++ b/devito/petsc/iet/utils.py @@ -10,10 +10,10 @@ def petsc_call_mpi(specific_call, call_args): return PETScCall('PetscCallMPI', [PETScCall(specific_call, arguments=call_args)]) -def petsc_struct(name, fields, liveness='lazy'): +def petsc_struct(name, fields, pname, liveness='lazy'): # TODO: Fix this circular import from devito.petsc.types.object import PETScStruct - return PETScStruct(name=name, pname='MatContext', + return PETScStruct(name=name, pname=pname, fields=fields, liveness=liveness) diff --git a/devito/petsc/solve.py b/devito/petsc/solve.py index c5df6b859f..4f16ded1f3 100644 --- a/devito/petsc/solve.py +++ b/devito/petsc/solve.py @@ -9,23 +9,22 @@ from devito.operations.solve import eval_time_derivatives from devito.symbolics import retrieve_functions from devito.tools import as_tuple -from devito.petsc.types import LinearSolveExpr, PETScArray +from devito.petsc.types import LinearSolveExpr, PETScArray, DMDALocalInfo __all__ = ['PETScSolve'] def PETScSolve(eqns, target, solver_parameters=None, **kwargs): - prefixes = ['y_matvec', 'x_matvec', 'y_formfunc', 'x_formfunc', 'b_tmp'] + prefixes = ['y_matvec', 'x_matvec', 'f_formfunc', 'x_formfunc', 'b_tmp'] + + localinfo = DMDALocalInfo(name='info', liveness='eager') arrays = { p: PETScArray(name='%s_%s' % (p, target.name), - dtype=target.dtype, - dimensions=target.space_dimensions, - shape=target.grid.shape, - liveness='eager', - halo=[target.halo[d] for d in target.space_dimensions], - space_order=target.space_order) + target=target, + liveness='eager', + localinfo=localinfo) for p in prefixes } @@ -47,7 +46,7 @@ def PETScSolve(eqns, target, solver_parameters=None, **kwargs): )) formfuncs.append(Eq( - arrays['y_formfunc'], + arrays['f_formfunc'], F_target.subs(targets_to_arrays(arrays['x_formfunc'], targets)), subdomain=eq.subdomain )) @@ -60,8 +59,9 @@ def PETScSolve(eqns, target, solver_parameters=None, **kwargs): funcs = retrieve_functions(eqns) time_mapper = generate_time_mapper(funcs) + matvecs, formfuncs, formrhs = ( - [eq.subs(time_mapper) for eq in lst] for lst in (matvecs, formfuncs, formrhs) + [eq.xreplace(time_mapper) for eq in lst] for lst in (matvecs, formfuncs, formrhs) ) # Placeholder equation for inserting calls to the solver and generating # correct time loop etc @@ -74,7 +74,8 @@ def PETScSolve(eqns, target, solver_parameters=None, **kwargs): formrhs=formrhs, arrays=arrays, time_mapper=time_mapper, - ), subdomain=eq.subdomain) + localinfo=localinfo + )) return [inject_solve] @@ -211,12 +212,24 @@ def generate_time_mapper(funcs): Replace time indices with `Symbols` in equations used within PETSc callback functions. These symbols are Uxreplaced at the IET level to align with the `TimeDimension` and `ModuloDimension` objects - present in the inital lowering. + present in the initial lowering. NOTE: All functions used in PETSc callback functions are attached to the `LinearSolveExpr` object, which is passed through the initial lowering (and subsequently dropped and replaced with calls to run the solver). Therefore, the appropriate time loop will always be correctly generated inside the main kernel. + + Examples + -------- + >>> funcs = [ + >>> f1(t + dt, x, y), + >>> g1(t + dt, x, y), + >>> g2(t, x, y), + >>> f1(t, x, y) + >>> ] + >>> generate_time_mapper(funcs) + {t + dt: tau0, t: tau1} + """ time_indices = list({ i if isinstance(d, SteppingDimension) else d diff --git a/devito/petsc/types/array.py b/devito/petsc/types/array.py index a150ea3247..38ac3bb9f3 100644 --- a/devito/petsc/types/array.py +++ b/devito/petsc/types/array.py @@ -1,5 +1,4 @@ from functools import cached_property -import numpy as np from ctypes import POINTER from devito.types.utils import DimensionTuple @@ -7,18 +6,21 @@ from devito.finite_differences import Differentiable from devito.types.basic import AbstractFunction from devito.finite_differences.tools import fd_weights_registry -from devito.tools import dtype_to_ctype +from devito.tools import dtype_to_ctype, as_tuple from devito.symbolics import FieldFromComposite -from .object import DM - class PETScArray(ArrayBasic, Differentiable): """ PETScArrays are generated by the compiler only and represent a customised variant of ArrayBasic. - Differentiable enables compatability with standard Function objects, + Differentiable enables compatibility with standard Function objects, allowing for the use of the `subs` method. + + PETScArray objects represent vector objects within PETSc. + They correspond to the spatial domain of a Function-like object + provided by the user, which is passed to PETScSolve as the target. + TODO: Potentially re-evaluate and separate into PETScFunction(Differentiable) and then PETScArray(ArrayBasic). """ @@ -29,11 +31,14 @@ class PETScArray(ArrayBasic, Differentiable): _default_fd = 'taylor' __rkwargs__ = (AbstractFunction.__rkwargs__ + - ('dimensions', 'shape', 'liveness', 'coefficients', - 'space_order')) + ('target', 'liveness', 'coefficients', 'localinfo')) def __init_finalize__(self, *args, **kwargs): + self._target = kwargs.get('target') + self._ndim = kwargs['ndim'] = len(self._target.space_dimensions) + self._dimensions = kwargs['dimensions'] = self._target.space_dimensions + super().__init_finalize__(*args, **kwargs) # Symbolic (finite difference) coefficients @@ -41,12 +46,38 @@ def __init_finalize__(self, *args, **kwargs): if self._coefficients not in fd_weights_registry: raise ValueError("coefficients must be one of %s" " not %s" % (str(fd_weights_registry), self._coefficients)) - self._shape = kwargs.get('shape') - self._space_order = kwargs.get('space_order', 1) + + self._localinfo = kwargs.get('localinfo', None) + + @property + def ndim(self): + return self._ndim @classmethod def __dtype_setup__(cls, **kwargs): - return kwargs.get('dtype', np.float32) + return kwargs['target'].dtype + + @classmethod + def __indices_setup__(cls, *args, **kwargs): + dimensions = kwargs['target'].space_dimensions + if args: + indices = args + else: + indices = dimensions + return as_tuple(dimensions), as_tuple(indices) + + def __halo_setup__(self, **kwargs): + target = kwargs['target'] + halo = [target.halo[d] for d in target.space_dimensions] + return DimensionTuple(*halo, getters=target.space_dimensions) + + @property + def dimensions(self): + return self._dimensions + + @property + def target(self): + return self._target @property def coefficients(self): @@ -55,50 +86,27 @@ def coefficients(self): @property def shape(self): - return self._shape + return self.target.grid.shape @property def space_order(self): - return self._space_order + return self.target.space_order + + @property + def localinfo(self): + return self._localinfo @cached_property def _shape_with_inhalo(self): - """ - Shape of the domain+inhalo region. The inhalo region comprises the - outhalo as well as any additional "ghost" layers for MPI halo - exchanges. Data in the inhalo region are exchanged when running - Operators to maintain consistent values as in sequential runs. - - Notes - ----- - Typically, this property won't be used in user code, but it may come - in handy for testing or debugging - """ - return tuple(j + i + k for i, (j, k) in zip(self.shape, self._halo)) + return self.target.shape_with_inhalo @cached_property def shape_allocated(self): - """ - Shape of the allocated data of the Function type object from which - this PETScArray was derived. It includes the domain and inhalo regions, - as well as any additional padding surrounding the halo. - - Notes - ----- - In an MPI context, this is the *local* with_halo region shape. - """ - return DimensionTuple(*[j + i + k for i, (j, k) in zip(self._shape_with_inhalo, - self._padding)], - getters=self.dimensions) + return self.target.shape_allocated @cached_property def _C_ctype(self): - # NOTE: Reverting to using float/double instead of PetscScalar for - # simplicity when opt='advanced'. Otherwise, Temp objects must also - # be converted to PetscScalar. Additional tests are needed to - # ensure this approach is fine. Previously, issues arose from - # mismatches between precision of Function objects in Devito and the - # precision of the PETSc configuration. + # TODO: Switch to using PetscScalar instead of float/double # TODO: Use cat $PETSC_DIR/$PETSC_ARCH/lib/petsc/conf/petscvariables # | grep -E "PETSC_(SCALAR|PRECISION)" to determine the precision of # the user's PETSc configuration. @@ -107,11 +115,6 @@ def _C_ctype(self): @property def symbolic_shape(self): field_from_composites = [ - FieldFromComposite('g%sm' % d.name, self.dmda.info) for d in self.dimensions] + FieldFromComposite('g%sm' % d.name, self.localinfo) for d in self.dimensions] # Reverse it since DMDA is setup backwards to Devito dimensions. return DimensionTuple(*field_from_composites[::-1], getters=self.dimensions) - - @cached_property - def dmda(self): - name = 'da_so_%s' % self.space_order - return DM(name=name, liveness='eager', stencil_width=self.space_order) diff --git a/devito/petsc/types/object.py b/devito/petsc/types/object.py index 9f8cbe4cbb..1bcfb3a6cf 100644 --- a/devito/petsc/types/object.py +++ b/devito/petsc/types/object.py @@ -9,7 +9,9 @@ class DM(LocalObject): """ - PETSc Data Management object (DM). + PETSc Data Management object (DM). This is the primary DM instance + created within the main kernel and linked to the SNES + solver using `SNESSetDM`. """ dtype = CustomDtype('DM') @@ -21,10 +23,6 @@ def __init__(self, *args, stencil_width=None, **kwargs): def stencil_width(self): return self._stencil_width - @property - def info(self): - return DMDALocalInfo(name='%s_info' % self.name, liveness='eager') - @property def _C_free(self): return petsc_call('DMDestroy', [Byref(self.function)]) @@ -34,6 +32,22 @@ def _C_free_priority(self): return 3 +class CallbackDM(LocalObject): + """ + PETSc Data Management object (DM). This is the DM instance + accessed within the callback functions via `SNESGetDM`. + """ + dtype = CustomDtype('DM') + + def __init__(self, *args, stencil_width=None, **kwargs): + super().__init__(*args, **kwargs) + self._stencil_width = stencil_width + + @property + def stencil_width(self): + return self._stencil_width + + class Mat(LocalObject): """ PETSc Matrix object (Mat). @@ -51,14 +65,18 @@ def _C_free_priority(self): class LocalVec(LocalObject): """ - PETSc Vector object (Vec). + PETSc local vector object (Vec). + A local vector has ghost locations that contain values that are + owned by other MPI ranks. """ dtype = CustomDtype('Vec') class GlobalVec(LocalObject): """ - PETSc Vector object (Vec). + PETSc global vector object (Vec). + A global vector is a parallel vector that has no duplicate values + between MPI ranks. A global vector has no ghost locations. """ dtype = CustomDtype('Vec') @@ -142,6 +160,10 @@ class PetscErrorCode(LocalObject): class DummyArg(LocalObject): + """ + A void pointer used to satisfy the function + signature of the `FormFunction` callback. + """ dtype = CustomDtype('void', modifier='*') @@ -160,9 +182,21 @@ def fields(self): @property def time_dim_fields(self): + """ + Fields within the struct that are updated during the time loop. + These are not set in the `PopulateMatContext` callback. + """ return [f for f in self.fields if isinstance(f, (ModuloDimension, TimeDimension))] + @property + def callback_fields(self): + """ + Fields within the struct that are initialized in the `PopulateMatContext` + callback. These fields are not updated in the time loop. + """ + return [f for f in self.fields if f not in self.time_dim_fields] + @property def _C_ctype(self): return POINTER(self.dtype) if self.liveness == \ diff --git a/devito/petsc/types/types.py b/devito/petsc/types/types.py index eda2fa40d4..1a4a778c9e 100644 --- a/devito/petsc/types/types.py +++ b/devito/petsc/types/types.py @@ -4,23 +4,53 @@ class LinearSolveExpr(sympy.Function, Reconstructable): + """ + A symbolic expression passed through the Operator, containing the metadata + needed to execute a linear solver. Linear problems are handled with + `SNESSetType(snes, KSPONLY)`, enabling a unified interface for both + linear and nonlinear solvers. + + # TODO: extend this + defaults: + - 'ksp_type': String with the name of the PETSc Krylov method. + Default is 'gmres' (Generalized Minimal Residual Method). + https://petsc.org/main/manualpages/KSP/KSPType/ + + - 'pc_type': String with the name of the PETSc preconditioner. + Default is 'jacobi' (i.e diagonal scaling preconditioning). + https://petsc.org/main/manualpages/PC/PCType/ + + KSP tolerances: + https://petsc.org/release/manualpages/KSP/KSPSetTolerances/ + + - 'ksp_rtol': Relative convergence tolerance. Default + is 1e-5. + - 'ksp_atol': Absolute convergence for tolerance. Default + is 1e-50. + - 'ksp_divtol': Divergence tolerance, amount residual norm can + increase before `KSPConvergedDefault()` concludes + that the method is diverging. Default is 1e5. + - 'ksp_max_it': Maximum number of iterations to use. Default + is 1e4. + """ __rargs__ = ('expr',) __rkwargs__ = ('target', 'solver_parameters', 'matvecs', - 'formfuncs', 'formrhs', 'arrays', 'time_mapper') + 'formfuncs', 'formrhs', 'arrays', 'time_mapper', + 'localinfo') defaults = { 'ksp_type': 'gmres', 'pc_type': 'jacobi', - 'ksp_rtol': 1e-7, # Relative tolerance + 'ksp_rtol': 1e-5, # Relative tolerance 'ksp_atol': 1e-50, # Absolute tolerance - 'ksp_divtol': 1e4, # Divergence tolerance - 'ksp_max_it': 10000 # Maximum iterations + 'ksp_divtol': 1e5, # Divergence tolerance + 'ksp_max_it': 1e4 # Maximum iterations } def __new__(cls, expr, target=None, solver_parameters=None, matvecs=None, formfuncs=None, formrhs=None, - arrays=None, time_mapper=None, **kwargs): + arrays=None, time_mapper=None, localinfo=None, **kwargs): if solver_parameters is None: solver_parameters = cls.defaults @@ -39,6 +69,7 @@ def __new__(cls, expr, target=None, solver_parameters=None, obj._formrhs = formrhs obj._arrays = arrays obj._time_mapper = time_mapper + obj._localinfo = localinfo return obj def __repr__(self): @@ -89,6 +120,10 @@ def arrays(self): def time_mapper(self): return self._time_mapper + @property + def localinfo(self): + return self._localinfo + @classmethod def eval(cls, *args): return None diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 5b13262ded..9db65d8bb2 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -252,6 +252,10 @@ def __str__(self): def field(self): return self.call + @property + def dtype(self): + return self.field.dtype + __repr__ = __str__ diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 18e2623764..123a8c46e4 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -295,10 +295,10 @@ def sympy_dtype(expr, base=None): """ Infer the dtype of the expression. """ + # TODO: Edit/fix/update according to PR #2513 dtypes = {base} - {None} - for i in expr.free_symbols: - try: - dtypes.add(i.dtype) - except AttributeError: - pass + for i in expr.args: + dtype = getattr(i, 'dtype', None) + if dtype: + dtypes.add(dtype) return infer_dtype(dtypes) diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 9e51eb9814..04da0dcc8d 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -44,7 +44,6 @@ RUN cd /tmp && mkdir openmpi && \ cd openmpi && ./autogen.pl && \ mkdir build && cd build && \ ../configure --prefix=/opt/openmpi/ \ - --disable-mpi-fortran \ --enable-mca-no-build=btl-uct --enable-mpi1-compatibility && \ make -j ${nproc} && \ make install && \ diff --git a/tests/test_petsc.py b/tests/test_petsc.py index 8261f69e43..13f0d064e5 100644 --- a/tests/test_petsc.py +++ b/tests/test_petsc.py @@ -5,7 +5,7 @@ from conftest import skipif from devito import Grid, Function, TimeFunction, Eq, Operator, switchconfig from devito.ir.iet import (Call, ElementalFunction, Definition, DummyExpr, - FindNodes, PointerCast, retrieve_iteration_tree) + FindNodes, retrieve_iteration_tree) from devito.types import Constant, CCompositeObject from devito.passes.iet.languages.C import CDataManager from devito.petsc.types import (DM, Mat, LocalVec, PetscMPIInt, KSP, @@ -51,28 +51,22 @@ def test_petsc_functions(): grid = Grid((2, 2)) x, y = grid.dimensions - ptr0 = PETScArray(name='ptr0', dimensions=grid.dimensions, dtype=np.float32) - ptr1 = PETScArray(name='ptr1', dimensions=grid.dimensions, dtype=np.float32, - is_const=True) - ptr2 = PETScArray(name='ptr2', dimensions=grid.dimensions, dtype=np.float64, - is_const=True) - ptr3 = PETScArray(name='ptr3', dimensions=grid.dimensions, dtype=np.int32) - ptr4 = PETScArray(name='ptr4', dimensions=grid.dimensions, dtype=np.int64, - is_const=True) + f0 = Function(name='f', grid=grid, space_order=2, dtype=np.float32) + f1 = Function(name='f', grid=grid, space_order=2, dtype=np.float64) + + ptr0 = PETScArray(name='ptr0', target=f0) + ptr1 = PETScArray(name='ptr1', target=f0, is_const=True) + ptr2 = PETScArray(name='ptr2', target=f1, is_const=True) defn0 = Definition(ptr0) defn1 = Definition(ptr1) defn2 = Definition(ptr2) - defn3 = Definition(ptr3) - defn4 = Definition(ptr4) expr = DummyExpr(ptr0.indexed[x, y], ptr1.indexed[x, y] + 1) assert str(defn0) == 'float *restrict ptr0_vec;' assert str(defn1) == 'const float *restrict ptr1_vec;' assert str(defn2) == 'const double *restrict ptr2_vec;' - assert str(defn3) == 'int *restrict ptr3_vec;' - assert str(defn4) == 'const long *restrict ptr4_vec;' assert str(expr) == 'ptr0[x][y] = ptr1[x][y] + 1;' @@ -86,7 +80,7 @@ def test_petsc_subs(): f1 = Function(name='f1', grid=grid, space_order=2) f2 = Function(name='f2', grid=grid, space_order=2) - arr = PETScArray(name='arr', dimensions=f2.dimensions, dtype=f2.dtype) + arr = PETScArray(name='arr', target=f2) eqn = Eq(f1, f2.laplace) eqn_subs = eqn.subs(f2, arr) @@ -129,12 +123,12 @@ def test_petsc_solve(): rhs_expr = FindNodes(Expression).visit(formrhs_callback[0]) assert str(action_expr[-1].expr.rhs) == \ - 'matvec->h_x**(-2)*x_matvec_f[x + 1, y + 2]' + \ - ' - 2.0*matvec->h_x**(-2)*x_matvec_f[x + 2, y + 2]' + \ - ' + matvec->h_x**(-2)*x_matvec_f[x + 3, y + 2]' + \ - ' + matvec->h_y**(-2)*x_matvec_f[x + 2, y + 1]' + \ - ' - 2.0*matvec->h_y**(-2)*x_matvec_f[x + 2, y + 2]' + \ - ' + matvec->h_y**(-2)*x_matvec_f[x + 2, y + 3]' + 'x_matvec_f[x + 1, y + 2]/lctx->h_x**2' + \ + ' - 2.0*x_matvec_f[x + 2, y + 2]/lctx->h_x**2' + \ + ' + x_matvec_f[x + 3, y + 2]/lctx->h_x**2' + \ + ' + x_matvec_f[x + 2, y + 1]/lctx->h_y**2' + \ + ' - 2.0*x_matvec_f[x + 2, y + 2]/lctx->h_y**2' + \ + ' + x_matvec_f[x + 2, y + 3]/lctx->h_y**2' assert str(rhs_expr[-1].expr.rhs) == 'g[x + 2, y + 2]' @@ -174,9 +168,8 @@ def test_multiple_petsc_solves(): callable_roots = [meta_call.root for meta_call in op._func_table.values()] - # One FormRHS, one MatShellMult and one FormFunction per solve - # One PopulateMatContext for all solves - assert len(callable_roots) == 7 + # One FormRHS, MatShellMult, FormFunction, PopulateMatContext per solve + assert len(callable_roots) == 8 @skipif('petsc') @@ -184,33 +177,37 @@ def test_petsc_cast(): """ Test casting of PETScArray. """ - g0 = Grid((2)) - g1 = Grid((2, 2)) - g2 = Grid((2, 2, 2)) - - arr0 = PETScArray(name='arr0', dimensions=g0.dimensions, shape=g0.shape) - arr1 = PETScArray(name='arr1', dimensions=g1.dimensions, shape=g1.shape) - arr2 = PETScArray(name='arr2', dimensions=g2.dimensions, shape=g2.shape) - - arr3 = PETScArray(name='arr3', dimensions=g1.dimensions, - shape=g1.shape, space_order=4) - - cast0 = PointerCast(arr0) - cast1 = PointerCast(arr1) - cast2 = PointerCast(arr2) - cast3 = PointerCast(arr3) - - assert str(cast0) == \ - 'float (*restrict arr0) = (float (*)) arr0_vec;' - assert str(cast1) == \ - 'float (*restrict arr1)[da_so_1_info.gxm] = ' + \ - '(float (*)[da_so_1_info.gxm]) arr1_vec;' - assert str(cast2) == \ - 'float (*restrict arr2)[da_so_1_info.gym][da_so_1_info.gxm] = ' + \ - '(float (*)[da_so_1_info.gym][da_so_1_info.gxm]) arr2_vec;' - assert str(cast3) == \ - 'float (*restrict arr3)[da_so_4_info.gxm] = ' + \ - '(float (*)[da_so_4_info.gxm]) arr3_vec;' + grid1 = Grid((2)) + grid2 = Grid((2, 2)) + grid3 = Grid((4, 5, 6)) + + f1 = Function(name='f1', grid=grid1, space_order=2) + f2 = Function(name='f2', grid=grid2, space_order=4) + f3 = Function(name='f3', grid=grid3, space_order=6) + + eqn1 = Eq(f1.laplace, 10) + eqn2 = Eq(f2.laplace, 10) + eqn3 = Eq(f3.laplace, 10) + + petsc1 = PETScSolve(eqn1, f1) + petsc2 = PETScSolve(eqn2, f2) + petsc3 = PETScSolve(eqn3, f3) + + with switchconfig(openmp=False): + op1 = Operator(petsc1, opt='noop') + op2 = Operator(petsc2, opt='noop') + op3 = Operator(petsc3, opt='noop') + + cb1 = [meta_call.root for meta_call in op1._func_table.values()] + cb2 = [meta_call.root for meta_call in op2._func_table.values()] + cb3 = [meta_call.root for meta_call in op3._func_table.values()] + + assert 'float (*restrict x_matvec_f1) = ' + \ + '(float (*)) x_matvec_f1_vec;' in str(cb1[0]) + assert 'float (*restrict x_matvec_f2)[info.gxm] = ' + \ + '(float (*)[info.gxm]) x_matvec_f2_vec;' in str(cb2[0]) + assert 'float (*restrict x_matvec_f3)[info.gym][info.gxm] = ' + \ + '(float (*)[info.gym][info.gxm]) x_matvec_f3_vec;' in str(cb3[0]) @skipif('petsc') @@ -229,8 +226,8 @@ def test_LinearSolveExpr(): assert linsolveexpr.target == f # Check the solver parameters assert linsolveexpr.solver_parameters == \ - {'ksp_type': 'gmres', 'pc_type': 'jacobi', 'ksp_rtol': 1e-07, - 'ksp_atol': 1e-50, 'ksp_divtol': 10000.0, 'ksp_max_it': 10000} + {'ksp_type': 'gmres', 'pc_type': 'jacobi', 'ksp_rtol': 1e-05, + 'ksp_atol': 1e-50, 'ksp_divtol': 100000.0, 'ksp_max_it': 10000} @skipif('petsc') @@ -258,23 +255,15 @@ def test_dmda_create(): op3 = Operator(petsc3, opt='noop') assert 'PetscCall(DMDACreate1d(PETSC_COMM_SELF,DM_BOUNDARY_GHOSTED,' + \ - '2,1,2,NULL,&(da_so_2)));' in str(op1) + '2,1,2,NULL,&(da_0)));' in str(op1) assert 'PetscCall(DMDACreate2d(PETSC_COMM_SELF,DM_BOUNDARY_GHOSTED,' + \ - 'DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,2,2,1,1,1,4,NULL,NULL,&(da_so_4)));' \ + 'DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,2,2,1,1,1,4,NULL,NULL,&(da_0)));' \ in str(op2) assert 'PetscCall(DMDACreate3d(PETSC_COMM_SELF,DM_BOUNDARY_GHOSTED,' + \ 'DM_BOUNDARY_GHOSTED,DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,6,5,4' + \ - ',1,1,1,1,6,NULL,NULL,NULL,&(da_so_6)));' in str(op3) - - # Check unique DMDA is created per grid, per space_order - f4 = Function(name='f4', grid=grid2, space_order=6) - eqn4 = Eq(f4.laplace, 10) - petsc4 = PETScSolve(eqn4, f4) - with switchconfig(openmp=False): - op4 = Operator(petsc2+petsc2+petsc4, opt='noop') - assert str(op4).count('DMDACreate2d') == 2 + ',1,1,1,1,6,NULL,NULL,NULL,&(da_0)));' in str(op3) @skipif('petsc') @@ -302,8 +291,8 @@ def test_cinterface_petsc_struct(): assert 'include "%s.h"' % name in ccode # The public `struct MatContext` only appears in the header file - assert 'struct MatContext\n{' not in ccode - assert 'struct MatContext\n{' in hcode + assert 'struct J_0_ctx\n{' not in ccode + assert 'struct J_0_ctx\n{' in hcode @skipif('petsc') @@ -578,7 +567,7 @@ def test_callback_arguments(): assert len(ff.parameters) == 4 assert str(mv.parameters) == '(J_0, X_global_0, Y_global_0)' - assert str(ff.parameters) == '(snes_0, X_global_0, Y_global_0, dummy_0)' + assert str(ff.parameters) == '(snes_0, X_global_0, F_global_0, dummy)' @skipif('petsc') @@ -662,7 +651,7 @@ def test_petsc_frees(): assert str(frees[1]) == 'PetscCall(VecDestroy(&(x_global_0)));' assert str(frees[2]) == 'PetscCall(MatDestroy(&(J_0)));' assert str(frees[3]) == 'PetscCall(SNESDestroy(&(snes_0)));' - assert str(frees[4]) == 'PetscCall(DMDestroy(&(da_so_2)));' + assert str(frees[4]) == 'PetscCall(DMDestroy(&(da_0)));' @skipif('petsc') @@ -739,10 +728,10 @@ def test_time_loop(): body1 = str(op1.body) rhs1 = str(op1._func_table['FormRHS_0'].root.ccode) - assert 'ctx.t0 = t0' in body1 - assert 'ctx.t1 = t1' not in body1 - assert 'formrhs->t0' in rhs1 - assert 'formrhs->t1' not in rhs1 + assert 'ctx0.t0 = t0' in body1 + assert 'ctx0.t1 = t1' not in body1 + assert 'lctx->t0' in rhs1 + assert 'lctx->t1' not in rhs1 # Non-modulo time stepping u2 = TimeFunction(name='u2', grid=grid, space_order=2, save=5) @@ -754,8 +743,8 @@ def test_time_loop(): body2 = str(op2.body) rhs2 = str(op2._func_table['FormRHS_0'].root.ccode) - assert 'ctx.time = time' in body2 - assert 'formrhs->time' in rhs2 + assert 'ctx0.time = time' in body2 + assert 'lctx->time' in rhs2 # Modulo time stepping with more than one time step # used in one of the callback functions @@ -766,10 +755,10 @@ def test_time_loop(): body3 = str(op3.body) rhs3 = str(op3._func_table['FormRHS_0'].root.ccode) - assert 'ctx.t0 = t0' in body3 - assert 'ctx.t1 = t1' in body3 - assert 'formrhs->t0' in rhs3 - assert 'formrhs->t1' in rhs3 + assert 'ctx0.t0 = t0' in body3 + assert 'ctx0.t1 = t1' in body3 + assert 'lctx->t0' in rhs3 + assert 'lctx->t1' in rhs3 # Multiple petsc solves within the same time loop v2 = Function(name='v2', grid=grid, space_order=2) @@ -781,5 +770,5 @@ def test_time_loop(): op4 = Operator(petsc4 + petsc5) body4 = str(op4.body) - assert 'ctx.t0 = t0' in body4 - assert body4.count('ctx.t0 = t0') == 1 + assert 'ctx0.t0 = t0' in body4 + assert body4.count('ctx0.t0 = t0') == 1 diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 2bae5679c8..febec6a25e 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -14,10 +14,11 @@ CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, Rvalue, ReservedWord, ListInitializer, ccode, uxreplace, - retrieve_derivatives) + retrieve_derivatives, sympy_dtype) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, - ComponentAccess, StencilDimension, Symbol as dSymbol) + ComponentAccess, StencilDimension, Symbol as dSymbol, + CompositeObject) from devito.types.basic import AbstractSymbol @@ -249,6 +250,17 @@ def test_field_from_pointer(): # Free symbols assert ffp1.free_symbols == {s} + # Test dtype + f = dSymbol('f') + pfields = [(f._C_name, f._C_ctype)] + struct = CompositeObject('s1', 'myStruct', pfields) + ffp4 = FieldFromPointer(f, struct) + assert str(ffp4) == 's1->f' + assert ffp4.dtype == f.dtype + expr = 1/ffp4 + dtype = sympy_dtype(expr) + assert dtype == f.dtype + def test_field_from_composite(): s = Symbol('s') @@ -293,7 +305,8 @@ def test_extended_sympy_arithmetic(): # noncommutative o = Object(name='o', dtype=c_void_p) bar = FieldFromPointer('bar', o) - assert ccode(-1 + bar) == '-1 + o->bar' + # TODO: Edit/fix/update according to PR #2513 + assert ccode(-1 + bar) == 'o->bar - 1' def test_integer_abs():