Skip to content

Commit 15bdd53

Browse files
committed
pass bcs to interpolate, zero cofunction fix
fix
1 parent 6b87a6b commit 15bdd53

File tree

2 files changed

+33
-39
lines changed

2 files changed

+33
-39
lines changed

firedrake/interpolation.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,18 @@ class InterpolateOptions:
8787
If ``False``, then construct the permutation matrix for interpolating
8888
between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
8989
and reduce operations.
90+
bcs : Iterable[BCBase] | None, optional
91+
An optional list of boundary conditions to zero-out in the
92+
output function space. Interpolator rows or columns which are
93+
associated with boundary condition nodes are zeroed out when this is
94+
specified. By default None.
9095
"""
9196
subset: op2.Subset | None = None
9297
access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None
9398
allow_missing_dofs: bool = False
9499
default_missing_val: float | None = None
95100
matfree: bool = True
101+
bcs: Iterable[BCBase] | None = None
96102

97103

98104
class Interpolate(ufl.Interpolate):
@@ -162,18 +168,13 @@ def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs) -> Interpo
162168
return Interpolate(expr, V, **kwargs)
163169

164170

165-
def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) -> "Interpolator":
171+
def get_interpolator(expr: Interpolate) -> "Interpolator":
166172
"""Create an Interpolator.
167173
168174
Parameters
169175
----------
170176
expr : Interpolate
171177
Symbolic interpolation expression.
172-
bcs : Iterable[BCBase] | None, optional
173-
An optional list of boundary conditions to zero-out in the
174-
output function space. Interpolator rows or columns which are
175-
associated with boundary condition nodes are zeroed out when this is
176-
specified. By default None.
177178
178179
Returns
179180
-------
@@ -183,7 +184,7 @@ def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) ->
183184
arguments = expr.arguments()
184185
has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments)
185186
if len(arguments) == 2 and has_mixed_arguments:
186-
return MixedInterpolator(expr, bcs=bcs)
187+
return MixedInterpolator(expr)
187188

188189
operand, = expr.ufl_operands
189190
target_mesh = expr.target_space.mesh()
@@ -194,24 +195,24 @@ def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) ->
194195
and target_mesh.topological_dimension() == source_mesh.topological_dimension()
195196
)
196197
if target_mesh is source_mesh or submesh_interp_implemented:
197-
return SameMeshInterpolator(expr, bcs=bcs)
198+
return SameMeshInterpolator(expr)
198199

199200
target_topology = target_mesh.topology
200201
source_topology = source_mesh.topology
201202

202203
if isinstance(target_topology, VertexOnlyMeshTopology):
203204
if isinstance(source_topology, VertexOnlyMeshTopology):
204-
return VomOntoVomInterpolator(expr, bcs=bcs)
205+
return VomOntoVomInterpolator(expr)
205206
if target_mesh.geometric_dimension() != source_mesh.geometric_dimension():
206207
raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension")
207208
if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh:
208209
raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target")
209-
return SameMeshInterpolator(expr, bcs=bcs)
210+
return SameMeshInterpolator(expr)
210211

211212
if has_mixed_arguments or len(expr.target_space) > 1:
212-
return MixedInterpolator(expr, bcs=bcs)
213+
return MixedInterpolator(expr)
213214

214-
return CrossMeshInterpolator(expr, bcs=bcs)
215+
return CrossMeshInterpolator(expr)
215216

216217

217218
class Interpolator(abc.ABC):
@@ -222,14 +223,9 @@ class Interpolator(abc.ABC):
222223
----------
223224
expr : Interpolate
224225
The symbolic interpolation expression.
225-
bcs : Iterable[BCBase], optional
226-
An optional list of boundary conditions to zero-out in the
227-
output function space. Interpolator rows or columns which are
228-
associated with boundary condition nodes are zeroed out when this is
229-
specified. By default None.
230226
231227
"""
232-
def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None):
228+
def __init__(self, expr: Interpolate):
233229
dual_arg, operand = expr.argument_slots()
234230
self.expr = expr
235231
self.expr_args = expr.arguments()
@@ -245,7 +241,7 @@ def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None):
245241
self.allow_missing_dofs = expr.options.allow_missing_dofs
246242
self.default_missing_val = expr.options.default_missing_val
247243
self.matfree = expr.options.matfree
248-
self.bcs = bcs
244+
self.bcs = expr.options.bcs
249245
self.callable = None
250246
self.access = expr.options.access
251247

@@ -345,8 +341,8 @@ class CrossMeshInterpolator(Interpolator):
345341
"""
346342

347343
@no_annotations
348-
def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None):
349-
super().__init__(expr, bcs)
344+
def __init__(self, expr: Interpolate):
345+
super().__init__(expr)
350346
if self.access and self.access != op2.WRITE:
351347
raise NotImplementedError(
352348
"Access other than op2.WRITE not implemented for cross-mesh interpolation."
@@ -520,8 +516,8 @@ class SameMeshInterpolator(Interpolator):
520516
"""
521517

522518
@no_annotations
523-
def __init__(self, expr, bcs=None):
524-
super().__init__(expr, bcs=bcs)
519+
def __init__(self, expr):
520+
super().__init__(expr)
525521
subset = self.subset
526522
if subset is None:
527523
target = self.target_mesh.topology
@@ -594,9 +590,6 @@ def _build_callable(self, output=None) -> None:
594590

595591
loops = []
596592

597-
if self.access == op2.INC:
598-
loops.append(tensor.zero)
599-
600593
# Arguments in the operand are allowed to be from a MixedFunctionSpace
601594
# We need to split the target space V and generate separate kernels
602595
if self.rank == 2:
@@ -635,8 +628,8 @@ def callable(loops, f):
635628

636629
class VomOntoVomInterpolator(SameMeshInterpolator):
637630

638-
def __init__(self, expr: Interpolate, bcs=None):
639-
super().__init__(expr, bcs=bcs)
631+
def __init__(self, expr: Interpolate):
632+
super().__init__(expr)
640633

641634
def _build_callable(self, output=None):
642635
self.mat = VomOntoVomMat(self)
@@ -899,7 +892,10 @@ def _build_interpolation_callables(
899892
if isinstance(tensor, op2.Mat):
900893
return parloop_compute_callable, tensor.assemble
901894
else:
902-
return copyin + callables + (parloop_compute_callable, ) + copyout
895+
extra = copyin + callables
896+
if access == op2.INC:
897+
extra += (tensor.zero,)
898+
return extra + (parloop_compute_callable, ) + copyout
903899

904900

905901
def get_interp_node_map(source_mesh: MeshGeometry, target_mesh: MeshGeometry, fs: WithGeometry) -> op2.Map | None:
@@ -1446,11 +1442,9 @@ class MixedInterpolator(Interpolator):
14461442
V
14471443
The :class:`.FunctionSpace` or :class:`.Function` to
14481444
interpolate into.
1449-
bcs
1450-
A list of boundary conditions.
14511445
"""
1452-
def __init__(self, expr, bcs=None):
1453-
super().__init__(expr, bcs=bcs)
1446+
def __init__(self, expr):
1447+
super().__init__(expr)
14541448

14551449
# We need a Coargument in order to split the Interpolate
14561450
needs_action = not any(isinstance(a, Coargument) for a in self.expr_args)
@@ -1467,17 +1461,17 @@ def __init__(self, expr, bcs=None):
14671461
continue
14681462
vi, _ = form.argument_slots()
14691463
Vtarget = vi.function_space().dual()
1470-
if bcs and self.rank != 0:
1464+
if self.bcs and self.rank != 0:
14711465
args = form.arguments()
14721466
Vsource = args[1 - vi.number()].function_space()
1473-
sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}]
1467+
sub_bcs = [bc for bc in self.bcs if bc.function_space() in {Vsource, Vtarget}]
14741468
else:
14751469
sub_bcs = None
14761470
if needs_action:
14771471
# Take the action of each sub-cofunction against each block
14781472
form = action(form, dual_split[indices[-1:]])
1479-
1480-
Isub[indices] = get_interpolator(form, bcs=sub_bcs)
1473+
form.options.bcs = sub_bcs
1474+
Isub[indices] = get_interpolator(form)
14811475

14821476
self._sub_interpolators = Isub
14831477

firedrake/preconditioners/hiptmair.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from firedrake.preconditioners.hypre_ams import chop
1111
from firedrake.preconditioners.facet_split import restrict
1212
from firedrake.parameters import parameters
13-
from firedrake.interpolation import Interpolator
13+
from firedrake.interpolation import interpolate
1414
from ufl.algorithms.ad import expand_derivatives
1515
import firedrake.dmhooks as dmhooks
1616
import firedrake.utils as utils
@@ -202,7 +202,7 @@ def coarsen(self, pc):
202202

203203
coarse_space_bcs = tuple(coarse_space_bcs)
204204
if G_callback is None:
205-
interp_petscmat = chop(Interpolator(dminus(trial), V, bcs=bcs + coarse_space_bcs).callable().handle)
205+
interp_petscmat = chop(assemble(interpolate(dminus(trial), V, bcs=bcs + coarse_space_bcs)).mat())
206206
else:
207207
interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs)
208208

0 commit comments

Comments
 (0)