Skip to content

Commit e9c679f

Browse files
committed
pass bcs to assemble instead of interpolate
1 parent 0fc8143 commit e9c679f

File tree

3 files changed

+75
-53
lines changed

3 files changed

+75
-53
lines changed

firedrake/assemble.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,9 +612,8 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
612612
rank = len(expr.arguments())
613613
if rank > 2:
614614
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
615-
616615
interpolator = get_interpolator(expr)
617-
return interpolator.assemble(tensor=tensor)
616+
return interpolator.assemble(tensor=tensor, bcs=bcs)
618617
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
619618
return tensor.assign(expr)
620619
elif tensor and isinstance(expr, ufl.ZeroBaseForm):

firedrake/interpolation.py

Lines changed: 73 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,12 @@ class InterpolateOptions:
8989
If ``False``, then construct the permutation matrix for interpolating
9090
between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
9191
and reduce operations.
92-
bcs : Iterable[DirichletBC] or None
93-
An optional list of boundary conditions to zero-out in the
94-
output function space. Interpolator rows or columns which are
95-
associated with boundary condition nodes are zeroed out when this is
96-
specified. By default None.
9792
"""
9893
subset: op2.Subset | None = None
9994
access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None
10095
allow_missing_dofs: bool = False
10196
default_missing_val: float | None = None
10297
matfree: bool = True
103-
bcs: Iterable[DirichletBC] | None = None
10498

10599

106100
class Interpolate(ufl.Interpolate):
@@ -190,25 +184,34 @@ class Interpolator(abc.ABC):
190184
def __init__(self, expr: Interpolate):
191185
dual_arg, operand = expr.argument_slots()
192186
self.expr = expr
187+
"""The symbolic UFL Interpolate expression."""
193188
self.expr_args = expr.arguments()
189+
"""Arguments of the Interpolate expression."""
194190
self.rank = len(self.expr_args)
191+
"""Number of arguments in the Interpolate expression."""
195192
self.operand = operand
193+
"""The primal argument slot of the Interpolate expression."""
196194
self.dual_arg = dual_arg
195+
"""The dual argument slot of the Interpolate expression."""
197196
self.target_space = dual_arg.function_space().dual()
197+
"""The primal space we are interpolating into."""
198198
self.target_mesh = self.target_space.mesh()
199+
"""The domain we are interpolating into."""
199200
self.source_mesh = extract_unique_domain(operand) or self.target_mesh
201+
"""The domain we are interpolating from."""
202+
self.callable = None
203+
"""The function which performs the interpolation."""
200204

201205
# Interpolation options
202206
self.subset = expr.options.subset
203207
self.allow_missing_dofs = expr.options.allow_missing_dofs
204208
self.default_missing_val = expr.options.default_missing_val
205209
self.matfree = expr.options.matfree
206-
self.bcs = expr.options.bcs
207210
self.callable = None
208211
self.access = expr.options.access
209212

210213
@abc.abstractmethod
211-
def _build_callable(self, tensor: Function | Cofunction | MatrixBase | None = None) -> None:
214+
def _build_callable(self, tensor: Function | Cofunction | MatrixBase | None = None, bcs: Iterable[DirichletBC] | None = None) -> None:
212215
"""Builds callable to perform interpolation. Stored in ``self.callable``.
213216
214217
If ``self.rank == 2``, then ``self.callable()`` must return an object with a ``handle``
@@ -220,11 +223,18 @@ def _build_callable(self, tensor: Function | Cofunction | MatrixBase | None = No
220223
----------
221224
tensor
222225
Optional tensor to store the result in, by default None.
226+
bcs
227+
An optional list of boundary conditions to zero-out in the
228+
output function space. Interpolator rows or columns which are
229+
associated with boundary condition nodes are zeroed out when this is
230+
specified. By default None.
223231
"""
224232
pass
225233

226234
def assemble(
227-
self, tensor: Function | Cofunction | MatrixBase | None = None
235+
self,
236+
tensor: Function | Cofunction | MatrixBase | None = None,
237+
bcs: Iterable[DirichletBC] | None = None
228238
) -> Function | Cofunction | MatrixBase | Number:
229239
"""Assemble the interpolation. The result depends on the rank (number of arguments)
230240
of the :class:`Interpolate` expression:
@@ -235,20 +245,24 @@ def assemble(
235245
236246
Parameters
237247
----------
238-
tensor : Function | Cofunction | MatrixBase
248+
tensor
239249
Optional tensor to store the interpolated result. For rank-2
240250
expressions this is expected to be a subclass of
241251
:class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions
242252
this is a :class:`~firedrake.function.Function` or :class:`~firedrake.cofunction.Cofunction`,
243253
for forward and adjoint interpolation respectively.
244-
254+
bcs
255+
An optional list of boundary conditions to zero-out in the
256+
output function space. Interpolator rows or columns which are
257+
associated with boundary condition nodes are zeroed out when this is
258+
specified. By default None.
245259
Returns
246260
-------
247261
Function | Cofunction | MatrixBase | numbers.Number
248262
The function, cofunction, matrix, or scalar resulting from the
249263
interpolation.
250264
"""
251-
self._build_callable(tensor=tensor)
265+
self._build_callable(tensor=tensor, bcs=bcs)
252266
result = self.callable()
253267
if self.rank == 2:
254268
# Assembling the operator
@@ -258,7 +272,7 @@ def assemble(
258272
if tensor:
259273
petsc_mat.copy(tensor.petscmat)
260274
return tensor
261-
return firedrake.AssembledMatrix(self.expr_args, self.bcs, petsc_mat)
275+
return firedrake.AssembledMatrix(self.expr_args, bcs, petsc_mat)
262276
else:
263277
assert isinstance(tensor, Function | Cofunction | None)
264278
return tensor.assign(result) if tensor else result
@@ -357,8 +371,6 @@ def __init__(self, expr: Interpolate):
357371
)
358372
else:
359373
self.access = op2.WRITE
360-
if self.bcs:
361-
raise NotImplementedError("bcs not implemented for cross-mesh interpolation.")
362374
if self.target_space.ufl_element().mapping() != "identity":
363375
# Identity mapping between reference cell and physical coordinates
364376
# implies point evaluation nodes. A more general version would
@@ -440,8 +452,10 @@ def _build_symbolic_expressions(self) -> None:
440452
arg = Argument(self.P0DG_vom, 0 if self.expr.is_adjoint else 1)
441453
self.point_eval_input_ordering = interpolate(arg, self.P0DG_vom_input_ordering, matfree=matfree)
442454

443-
def _build_callable(self, tensor=None):
455+
def _build_callable(self, tensor=None, bcs=None):
444456
from firedrake.assemble import assemble
457+
if bcs:
458+
raise NotImplementedError("bcs not implemented for cross-mesh interpolation.")
445459
# self.expr.function_space() is None in the 0-form case
446460
V_dest = self.expr.function_space() or self.target_space
447461
f = tensor or Function(V_dest)
@@ -590,7 +604,7 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction:
590604
raise ValueError(f"Cannot interpolate an expression with {self.rank} arguments")
591605
return f
592606

593-
def _build_callable(self, tensor=None) -> None:
607+
def _build_callable(self, tensor=None, bcs=None):
594608
f = tensor or self._get_tensor()
595609
op2_tensor = f if isinstance(f, op2.Mat) else f.dat
596610

@@ -618,10 +632,10 @@ def _build_callable(self, tensor=None) -> None:
618632
# Interpolate each sub expression into each function space
619633
for indices, sub_expr in expressions.items():
620634
sub_op2_tensor = op2_tensor[indices[0]] if self.rank == 1 else op2_tensor
621-
loops.extend(_build_interpolation_callables(sub_expr, sub_op2_tensor, self.access, self.subset, self.bcs))
635+
loops.extend(_build_interpolation_callables(sub_expr, sub_op2_tensor, self.access, self.subset, bcs))
622636

623-
if self.bcs and self.rank == 1:
624-
loops.extend(partial(bc.apply, f) for bc in self.bcs)
637+
if bcs and self.rank == 1:
638+
loops.extend(partial(bc.apply, f) for bc in bcs)
625639

626640
def callable(loops, f):
627641
for l in loops:
@@ -636,20 +650,16 @@ class VomOntoVomInterpolator(SameMeshInterpolator):
636650
def __init__(self, expr: Interpolate):
637651
super().__init__(expr)
638652

639-
def _build_callable(self, tensor=None):
653+
def _build_callable(self, tensor=None, bcs=None):
654+
if bcs:
655+
raise NotImplementedError("bcs not implemented for vom-to-vom interpolation.")
640656
self.mat = VomOntoVomMat(self)
641-
if self.rank == 2:
642-
# We make our own linear operator for this case using PETSc SFs
643-
op2_tensor = None
644-
else:
657+
if self.rank == 1:
645658
f = tensor or self._get_tensor()
646-
op2_tensor = f.dat
647-
# NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
648-
# data, including the correct data size and dimensional information
649-
# (so for vector function spaces in 2 dimensions we might need a
650-
# concatenation of 2 MPI.DOUBLE types when we are in real mode)
651-
if op2_tensor is not None:
652-
assert self.rank == 1
659+
# NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
660+
# data, including the correct data size and dimensional information
661+
# (so for vector function spaces in 2 dimensions we might need a
662+
# concatenation of 2 MPI.DOUBLE types when we are in real mode)
653663
self.mat.mpi_type = get_dat_mpi_type(f.dat)[0]
654664
if self.expr.is_adjoint:
655665
assert isinstance(self.dual_arg, ufl.Cofunction)
@@ -667,8 +677,7 @@ def callable() -> Function:
667677
with coeff.dat.vec_ro as coeff_vec, f.dat.vec_wo as target_vec:
668678
self.mat.handle.mult(coeff_vec, target_vec)
669679
return f
670-
else:
671-
assert self.rank == 2
680+
elif self.rank == 2:
672681
# we know we will be outputting either a function or a cofunction,
673682
# both of which will use a dat as a data carrier. At present, the
674683
# data type does not depend on function space dimension, so we can
@@ -1471,6 +1480,22 @@ def __init__(self, expr: Interpolate):
14711480
"""
14721481
super().__init__(expr)
14731482

1483+
def _get_sub_interpolators(self, bcs: Iterable[DirichletBC] | None = None) -> dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]]:
1484+
"""Gets `Interpolator`s for each sub-Interpolate in the mixed expression.
1485+
1486+
Returns
1487+
-------
1488+
dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]]
1489+
A map from block index tuples to `Interpolator`s and bcs.
1490+
"""
1491+
# Get the primal spaces
1492+
spaces = tuple(
1493+
a.function_space().dual() if isinstance(a, Coargument) else a.function_space() for a in self.expr_args
1494+
)
1495+
# TODO consider a stricter equality test for indexed MixedFunctionSpace
1496+
# See https://github.com/firedrakeproject/firedrake/issues/4668
1497+
space_equals = lambda V1, V2: V1 == V2 and V1.parent == V2.parent and V1.index == V2.index
1498+
14741499
# We need a Coargument in order to split the Interpolate
14751500
needs_action = not any(isinstance(a, Coargument) for a in self.expr_args)
14761501
if needs_action:
@@ -1480,33 +1505,31 @@ def __init__(self, expr: Interpolate):
14801505
self.expr = self.expr._ufl_expr_reconstruct_(self.operand, self.target_space)
14811506

14821507
# Get sub-interpolators for each block
1483-
self.Isub: dict[tuple[int, int], Interpolator] = {}
1508+
Isub: dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]] = {}
14841509
for indices, form in firedrake.formmanipulation.split_form(self.expr):
14851510
if isinstance(form, ufl.ZeroBaseForm):
14861511
# Ensure block sparsity
14871512
continue
1488-
vi, _ = form.argument_slots()
1489-
Vtarget = vi.function_space().dual()
1490-
if self.bcs and self.rank != 0:
1491-
args = form.arguments()
1492-
Vsource = args[1 - vi.number()].function_space()
1493-
sub_bcs = [bc for bc in self.bcs if bc.function_space() in {Vsource, Vtarget}]
1494-
else:
1495-
sub_bcs = None
1513+
sub_bcs = []
1514+
for space, index in zip(spaces, indices):
1515+
subspace = space.sub(index)
1516+
sub_bcs.extend(bc for bc in bcs if space_equals(bc.function_space(), subspace))
14961517
if needs_action:
14971518
# Take the action of each sub-cofunction against each block
14981519
form = action(form, dual_split[indices[-1:]])
1499-
form.options.bcs = sub_bcs
1500-
self.Isub[indices] = get_interpolator(form)
1520+
Isub[indices] = (get_interpolator(form), sub_bcs)
1521+
1522+
return Isub
15011523

1502-
def _build_callable(self, tensor=None):
1524+
def _build_callable(self, tensor=None, bcs=None):
1525+
Isub = self._get_sub_interpolators(bcs=bcs)
15031526
V_dest = self.expr.function_space() or self.target_space
15041527
f = tensor or Function(V_dest)
15051528
if self.rank == 2:
15061529
shape = tuple(len(a.function_space()) for a in self.expr_args)
15071530
blocks = numpy.full(shape, PETSc.Mat(), dtype=object)
1508-
for indices, interp in self.Isub.items():
1509-
interp._build_callable()
1531+
for indices, (interp, sub_bcs) in Isub.items():
1532+
interp._build_callable(bcs=sub_bcs)
15101533
blocks[indices] = interp.callable().handle
15111534
self.handle = PETSc.Mat().createNest(blocks)
15121535

@@ -1516,10 +1539,10 @@ def callable() -> MixedInterpolator:
15161539
def callable() -> Function | Cofunction:
15171540
for k, sub_tensor in enumerate(f.subfunctions):
15181541
sub_tensor.assign(sum(
1519-
interp.assemble() for indices, interp in self.Isub.items() if indices[0] == k
1542+
interp.assemble(bcs=sub_bcs) for indices, (interp, sub_bcs) in Isub.items() if indices[0] == k
15201543
))
15211544
return f
15221545
else:
15231546
def callable() -> Number:
1524-
return sum(interp.assemble() for interp in self.Isub.values())
1547+
return sum(interp.assemble(bcs=sub_bcs) for (interp, sub_bcs) in Isub.values())
15251548
self.callable = callable

firedrake/preconditioners/hiptmair.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def coarsen(self, pc):
202202

203203
coarse_space_bcs = tuple(coarse_space_bcs)
204204
if G_callback is None:
205-
interp_petscmat = chop(assemble(interpolate(dminus(trial), V, bcs=bcs + coarse_space_bcs)).mat())
205+
interp_petscmat = chop(assemble(interpolate(dminus(trial), V), bcs=bcs + coarse_space_bcs)).mat()
206206
else:
207207
interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs)
208208

0 commit comments

Comments
 (0)