Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ def _gen_struct_decl(self, obj, masked=()):
except TypeError:
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
if isinstance(obj, LocalCompositeObject):
# TODO: Potentially re-evaluate: Setting ctype to obj allows
# TODO: re-evaluate: Setting ctype to obj allows
# _gen_struct_decl to generate a cgen.Structure from a
# LocalCompositeObject, where obj._C_ctype is a CustomDtype.
# LocalCompositeObject has a __fields__ property,
# which allows the subsequent code in this function to function
# which allows the subsequent code in this function to work
# correctly.
ctype = obj
else:
Expand Down
161 changes: 126 additions & 35 deletions devito/petsc/iet/routines.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import OrderedDict
from functools import cached_property
import math

from devito.ir.iet import (Call, FindSymbols, List, Uxreplace, CallableBody,
Dereference, DummyExpr, BlankLine, Callable, FindNodes,
Expand Down Expand Up @@ -1111,8 +1112,19 @@ def _setup(self):
global_x = petsc_call('DMCreateGlobalVector',
[dmda, Byref(sobjs['xglobal'])])

local_x = petsc_call('DMCreateLocalVector',
[dmda, Byref(sobjs['xlocal'])])
target = self.fielddata.target
field_from_ptr = FieldFromPointer(
target.function._C_field_data, target.function._C_symbol
)

local_size = math.prod(
v for v, dim in zip(target.shape_allocated, target.dimensions) if dim.is_Space
)
local_x = petsc_call('VecCreateMPIWithArray',
['PETSC_COMM_WORLD', 1, local_size, 'PETSC_DECIDE',
field_from_ptr, Byref(sobjs['xlocal'])])

# TODO: potentially also need to set the DM and local/global map to xlocal

get_local_size = petsc_call('VecGetSize',
[sobjs['xlocal'], Byref(sobjs['localsize'])])
Expand Down Expand Up @@ -1247,11 +1259,87 @@ class CoupledSetup(BaseSetup):
def snes_ctx(self):
return Byref(self.solver_objs['jacctx'])

def _extend_setup(self):
def _setup(self):
# TODO: minimise code duplication with superclass
objs = self.objs
sobjs = self.solver_objs

dmda = sobjs['dmda']

solver_params = self.injectsolve.expr.rhs.solver_parameters

snes_create = petsc_call('SNESCreate', [objs['comm'], Byref(sobjs['snes'])])

snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda])

create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])])

# NOTE: Assuming all solves are linear for now
snes_set_type = petsc_call('SNESSetType', [sobjs['snes'], 'SNESKSPONLY'])

snes_set_jac = petsc_call(
'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'],
sobjs['Jac'], 'MatMFFDComputeJacobian', objs['Null']]
)

global_x = petsc_call('DMCreateGlobalVector',
[dmda, Byref(sobjs['xglobal'])])

local_x = petsc_call('DMCreateLocalVector', [dmda, Byref(sobjs['xlocal'])])

get_local_size = petsc_call('VecGetSize',
[sobjs['xlocal'], Byref(sobjs['localsize'])])

global_b = petsc_call('DMCreateGlobalVector',
[dmda, Byref(sobjs['bglobal'])])

snes_get_ksp = petsc_call('SNESGetKSP',
[sobjs['snes'], Byref(sobjs['ksp'])])

ksp_set_tols = petsc_call(
'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'],
solver_params['ksp_atol'], solver_params['ksp_divtol'],
solver_params['ksp_max_it']]
)

ksp_set_type = petsc_call(
'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
)

ksp_get_pc = petsc_call(
'KSPGetPC', [sobjs['ksp'], Byref(sobjs['pc'])]
)

# Even though the default will be jacobi, set to PCNONE for now
pc_set_type = petsc_call('PCSetType', [sobjs['pc'], 'PCNONE'])

ksp_set_from_ops = petsc_call('KSPSetFromOptions', [sobjs['ksp']])

matvec = self.cbbuilder.main_matvec_callback
matvec_operation = petsc_call(
'MatShellSetOperation',
[sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)]
)
formfunc = self.cbbuilder.main_formfunc_callback
formfunc_operation = petsc_call(
'SNESSetFunction',
[sobjs['snes'], objs['Null'], FormFunctionCallback(formfunc.name, void, void),
self.snes_ctx]
)

dmda_calls = self._create_dmda_calls(dmda)

mainctx = sobjs['userctx']

call_struct_callback = petsc_call(
self.cbbuilder.user_struct_callback.name, [Byref(mainctx)]
)

# TODO: maybe don't need to explictly set this
mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda])

calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)])

create_field_decomp = petsc_call(
'DMCreateFieldDecomposition',
[dmda, Byref(sobjs['nfields']), objs['Null'], Byref(sobjs['fields']),
Expand Down Expand Up @@ -1297,13 +1385,34 @@ def _extend_setup(self):
[sobjs[f'da{t.name}'], Byref(sobjs[f'bglobal{t.name}'])]
) for t in targets]

return (
coupled_setup = dmda_calls + (
snes_create,
snes_set_dm,
create_matrix,
snes_set_jac,
snes_set_type,
global_x,
local_x,
get_local_size,
global_b,
snes_get_ksp,
ksp_set_tols,
ksp_set_type,
ksp_get_pc,
pc_set_type,
ksp_set_from_ops,
matvec_operation,
formfunc_operation,
call_struct_callback,
mat_set_dm,
calls_set_app_ctx,
create_field_decomp,
matop_create_submats_op,
call_coupled_struct_callback,
shell_set_ctx,
create_submats
) + tuple(deref_dms) + tuple(xglobals) + tuple(bglobals)
create_submats) + \
tuple(deref_dms) + tuple(xglobals) + tuple(bglobals)
return coupled_setup


class Solver:
Expand Down Expand Up @@ -1333,7 +1442,7 @@ def _execute_solve(self):

rhs_call = petsc_call(rhs_callback.name, [sobjs['dmda'], sobjs['bglobal']])

vec_replace_array = self.timedep.replace_array(target)
vec_place_array = self.timedep.place_array(target)

if self.cbbuilder.initialguesses:
initguess = self.cbbuilder.initialguesses[0]
Expand All @@ -1358,7 +1467,7 @@ def _execute_solve(self):

run_solver_calls = (struct_assignment,) + (
rhs_call,
) + vec_replace_array + (
) + vec_place_array + (
initguess_call,
dm_local_to_global_x,
snes_solve,
Expand Down Expand Up @@ -1415,7 +1524,7 @@ def _execute_solve(self):
pre_solve += (
petsc_call(c.name, [dm, target_bglob]),
petsc_call('DMCreateLocalVector', [dm, Byref(target_xloc)]),
self.timedep.replace_array(t),
self.timedep.place_array(t),
petsc_call(
'DMLocalToGlobal',
[dm, target_xloc, insert_vals, target_xglob]
Expand Down Expand Up @@ -1485,23 +1594,7 @@ def _origin_to_moddim_mapper(self, iters):
def uxreplace_time(self, body):
return body

def replace_array(self, target):
"""
VecReplaceArray() is a PETSc function that allows replacing the array
of a `Vec` with a user provided array.
https://petsc.org/release/manualpages/Vec/VecReplaceArray/

This function is used to replace the array of the PETSc solution `Vec`
with the array from the `Function` object representing the target.

Examples
--------
>>> target
f1(x, y)
>>> call = replace_array(target)
>>> print(call)
PetscCall(VecReplaceArray(xlocal0,f1_vec->data));
"""
def place_array(self, target):
sobjs = self.sobjs

field_from_ptr = FieldFromPointer(
Expand Down Expand Up @@ -1608,29 +1701,27 @@ def _origin_to_moddim_mapper(self, iters):
mapper[d] = d
return mapper

def replace_array(self, target):
def place_array(self, target):
"""
In the case that the actual target is time-dependent e.g a `TimeFunction`,
a pointer to the first element in the array that will be updated during
the time step is passed to VecReplaceArray().
the time step is passed to VecPlaceArray().

Examples
--------
>>> target
f1(time + dt, x, y)
>>> calls = replace_array(target)
>>> calls = place_array(target)
>>> print(List(body=calls))
PetscCall(VecGetSize(xlocal0,&(localsize0)));
float * f1_ptr0 = (time + 1)*localsize0 + (float*)(f1_vec->data);
PetscCall(VecReplaceArray(xlocal0,f1_ptr0));
PetscCall(VecPlaceArray(xlocal0,f1_ptr0));

>>> target
f1(t + dt, x, y)
>>> calls = replace_array(target)
>>> calls = place_array(target)
>>> print(List(body=calls))
PetscCall(VecGetSize(xlocal0,&(localsize0)));
float * f1_ptr0 = t1*localsize0 + (float*)(f1_vec->data);
PetscCall(VecReplaceArray(xlocal0,f1_ptr0));
PetscCall(VecPlaceArray(xlocal0,f1_ptr0));
"""
sobjs = self.sobjs

Expand All @@ -1654,7 +1745,7 @@ def replace_array(self, target):
),
petsc_call('VecPlaceArray', [xlocal, start_ptr])
)
return super().replace_array(target)
return super().place_array(target)

def assign_time_iters(self, struct):
"""
Expand Down
2 changes: 1 addition & 1 deletion devito/petsc/types/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def dofs(self):
def _C_free(self):
return petsc_call('DMDestroy', [Byref(self.function)])

# TODO: This is growing out of hand so switch to an enumeration or something?
# TODO: Switch to an enumeration?
@property
def _C_free_priority(self):
return 4
Expand Down
2 changes: 2 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def test_extended_sympy_arithmetic():
o = Object(name='o', dtype=c_void_p)
bar = FieldFromPointer('bar', o)
# TODO: Edit/fix/update according to PR #2513
# The order changed due to adding the dtype property
# to FieldFromPointer
assert ccode(-1 + bar) == 'o->bar - 1'


Expand Down
Loading