@@ -89,18 +89,12 @@ class InterpolateOptions:
8989 If ``False``, then construct the permutation matrix for interpolating
9090 between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
9191 and reduce operations.
92- bcs : Iterable[DirichletBC] or None
93- An optional list of boundary conditions to zero-out in the
94- output function space. Interpolator rows or columns which are
95- associated with boundary condition nodes are zeroed out when this is
96- specified. By default None.
9792 """
9893 subset : op2 .Subset | None = None
9994 access : Literal [op2 .WRITE , op2 .MIN , op2 .MAX , op2 .INC ] | None = None
10095 allow_missing_dofs : bool = False
10196 default_missing_val : float | None = None
10297 matfree : bool = True
103- bcs : Iterable [DirichletBC ] | None = None
10498
10599
106100class Interpolate (ufl .Interpolate ):
@@ -190,25 +184,34 @@ class Interpolator(abc.ABC):
190184 def __init__ (self , expr : Interpolate ):
191185 dual_arg , operand = expr .argument_slots ()
192186 self .expr = expr
187+ """The symbolic UFL Interpolate expression."""
193188 self .expr_args = expr .arguments ()
189+ """Arguments of the Interpolate expression."""
194190 self .rank = len (self .expr_args )
191+ """Number of arguments in the Interpolate expression."""
195192 self .operand = operand
193+ """The primal argument slot of the Interpolate expression."""
196194 self .dual_arg = dual_arg
195+ """The dual argument slot of the Interpolate expression."""
197196 self .target_space = dual_arg .function_space ().dual ()
197+ """The primal space we are interpolating into."""
198198 self .target_mesh = self .target_space .mesh ()
199+ """The domain we are interpolating into."""
199200 self .source_mesh = extract_unique_domain (operand ) or self .target_mesh
201+ """The domain we are interpolating from."""
202+ self .callable = None
203+ """The function which performs the interpolation."""
200204
201205 # Interpolation options
202206 self .subset = expr .options .subset
203207 self .allow_missing_dofs = expr .options .allow_missing_dofs
204208 self .default_missing_val = expr .options .default_missing_val
205209 self .matfree = expr .options .matfree
206- self .bcs = expr .options .bcs
207210 self .callable = None
208211 self .access = expr .options .access
209212
210213 @abc .abstractmethod
211- def _build_callable (self , tensor : Function | Cofunction | MatrixBase | None = None ) -> None :
214+ def _build_callable (self , tensor : Function | Cofunction | MatrixBase | None = None , bcs : Iterable [ DirichletBC ] | None = None ) -> None :
212215 """Builds callable to perform interpolation. Stored in ``self.callable``.
213216
214217 If ``self.rank == 2``, then ``self.callable()`` must return an object with a ``handle``
@@ -220,11 +223,18 @@ def _build_callable(self, tensor: Function | Cofunction | MatrixBase | None = No
220223 ----------
221224 tensor
222225 Optional tensor to store the result in, by default None.
226+ bcs
227+ An optional list of boundary conditions to zero-out in the
228+ output function space. Interpolator rows or columns which are
229+ associated with boundary condition nodes are zeroed out when this is
230+ specified. By default None.
223231 """
224232 pass
225233
226234 def assemble (
227- self , tensor : Function | Cofunction | MatrixBase | None = None
235+ self ,
236+ tensor : Function | Cofunction | MatrixBase | None = None ,
237+ bcs : Iterable [DirichletBC ] | None = None
228238 ) -> Function | Cofunction | MatrixBase | Number :
229239 """Assemble the interpolation. The result depends on the rank (number of arguments)
230240 of the :class:`Interpolate` expression:
@@ -235,20 +245,24 @@ def assemble(
235245
236246 Parameters
237247 ----------
238- tensor : Function | Cofunction | MatrixBase
248+ tensor
239249 Optional tensor to store the interpolated result. For rank-2
240250 expressions this is expected to be a subclass of
241251 :class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions
242252 this is a :class:`~firedrake.function.Function` or :class:`~firedrake.cofunction.Cofunction`,
243253 for forward and adjoint interpolation respectively.
244-
254+ bcs
255+ An optional list of boundary conditions to zero-out in the
256+ output function space. Interpolator rows or columns which are
257+ associated with boundary condition nodes are zeroed out when this is
258+ specified. By default None.
245259 Returns
246260 -------
247261 Function | Cofunction | MatrixBase | numbers.Number
248262 The function, cofunction, matrix, or scalar resulting from the
249263 interpolation.
250264 """
251- self ._build_callable (tensor = tensor )
265+ self ._build_callable (tensor = tensor , bcs = bcs )
252266 result = self .callable ()
253267 if self .rank == 2 :
254268 # Assembling the operator
@@ -258,7 +272,7 @@ def assemble(
258272 if tensor :
259273 petsc_mat .copy (tensor .petscmat )
260274 return tensor
261- return firedrake .AssembledMatrix (self .expr_args , self . bcs , petsc_mat )
275+ return firedrake .AssembledMatrix (self .expr_args , bcs , petsc_mat )
262276 else :
263277 assert isinstance (tensor , Function | Cofunction | None )
264278 return tensor .assign (result ) if tensor else result
@@ -357,8 +371,6 @@ def __init__(self, expr: Interpolate):
357371 )
358372 else :
359373 self .access = op2 .WRITE
360- if self .bcs :
361- raise NotImplementedError ("bcs not implemented for cross-mesh interpolation." )
362374 if self .target_space .ufl_element ().mapping () != "identity" :
363375 # Identity mapping between reference cell and physical coordinates
364376 # implies point evaluation nodes. A more general version would
@@ -440,8 +452,10 @@ def _build_symbolic_expressions(self) -> None:
440452 arg = Argument (self .P0DG_vom , 0 if self .expr .is_adjoint else 1 )
441453 self .point_eval_input_ordering = interpolate (arg , self .P0DG_vom_input_ordering , matfree = matfree )
442454
443- def _build_callable (self , tensor = None ):
455+ def _build_callable (self , tensor = None , bcs = None ):
444456 from firedrake .assemble import assemble
457+ if bcs :
458+ raise NotImplementedError ("bcs not implemented for cross-mesh interpolation." )
445459 # self.expr.function_space() is None in the 0-form case
446460 V_dest = self .expr .function_space () or self .target_space
447461 f = tensor or Function (V_dest )
@@ -590,7 +604,7 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction:
590604 raise ValueError (f"Cannot interpolate an expression with { self .rank } arguments" )
591605 return f
592606
593- def _build_callable (self , tensor = None ) -> None :
607+ def _build_callable (self , tensor = None , bcs = None ) :
594608 f = tensor or self ._get_tensor ()
595609 op2_tensor = f if isinstance (f , op2 .Mat ) else f .dat
596610
@@ -618,10 +632,10 @@ def _build_callable(self, tensor=None) -> None:
618632 # Interpolate each sub expression into each function space
619633 for indices , sub_expr in expressions .items ():
620634 sub_op2_tensor = op2_tensor [indices [0 ]] if self .rank == 1 else op2_tensor
621- loops .extend (_build_interpolation_callables (sub_expr , sub_op2_tensor , self .access , self .subset , self . bcs ))
635+ loops .extend (_build_interpolation_callables (sub_expr , sub_op2_tensor , self .access , self .subset , bcs ))
622636
623- if self . bcs and self .rank == 1 :
624- loops .extend (partial (bc .apply , f ) for bc in self . bcs )
637+ if bcs and self .rank == 1 :
638+ loops .extend (partial (bc .apply , f ) for bc in bcs )
625639
626640 def callable (loops , f ):
627641 for l in loops :
@@ -636,20 +650,16 @@ class VomOntoVomInterpolator(SameMeshInterpolator):
636650 def __init__ (self , expr : Interpolate ):
637651 super ().__init__ (expr )
638652
639- def _build_callable (self , tensor = None ):
653+ def _build_callable (self , tensor = None , bcs = None ):
654+ if bcs :
655+ raise NotImplementedError ("bcs not implemented for vom-to-vom interpolation." )
640656 self .mat = VomOntoVomMat (self )
641- if self .rank == 2 :
642- # We make our own linear operator for this case using PETSc SFs
643- op2_tensor = None
644- else :
657+ if self .rank == 1 :
645658 f = tensor or self ._get_tensor ()
646- op2_tensor = f .dat
647- # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
648- # data, including the correct data size and dimensional information
649- # (so for vector function spaces in 2 dimensions we might need a
650- # concatenation of 2 MPI.DOUBLE types when we are in real mode)
651- if op2_tensor is not None :
652- assert self .rank == 1
659+ # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
660+ # data, including the correct data size and dimensional information
661+ # (so for vector function spaces in 2 dimensions we might need a
662+ # concatenation of 2 MPI.DOUBLE types when we are in real mode)
653663 self .mat .mpi_type = get_dat_mpi_type (f .dat )[0 ]
654664 if self .expr .is_adjoint :
655665 assert isinstance (self .dual_arg , ufl .Cofunction )
@@ -667,8 +677,7 @@ def callable() -> Function:
667677 with coeff .dat .vec_ro as coeff_vec , f .dat .vec_wo as target_vec :
668678 self .mat .handle .mult (coeff_vec , target_vec )
669679 return f
670- else :
671- assert self .rank == 2
680+ elif self .rank == 2 :
672681 # we know we will be outputting either a function or a cofunction,
673682 # both of which will use a dat as a data carrier. At present, the
674683 # data type does not depend on function space dimension, so we can
@@ -1471,6 +1480,22 @@ def __init__(self, expr: Interpolate):
14711480 """
14721481 super ().__init__ (expr )
14731482
1483+ def _get_sub_interpolators (self , bcs : Iterable [DirichletBC ] | None = None ) -> dict [tuple [int , int ], tuple [Interpolator , list [DirichletBC ]]]:
1484+ """Gets `Interpolator`s for each sub-Interpolate in the mixed expression.
1485+
1486+ Returns
1487+ -------
1488+ dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]]
1489+ A map from block index tuples to `Interpolator`s and bcs.
1490+ """
1491+ # Get the primal spaces
1492+ spaces = tuple (
1493+ a .function_space ().dual () if isinstance (a , Coargument ) else a .function_space () for a in self .expr_args
1494+ )
1495+ # TODO consider a stricter equality test for indexed MixedFunctionSpace
1496+ # See https://github.com/firedrakeproject/firedrake/issues/4668
1497+ space_equals = lambda V1 , V2 : V1 == V2 and V1 .parent == V2 .parent and V1 .index == V2 .index
1498+
14741499 # We need a Coargument in order to split the Interpolate
14751500 needs_action = not any (isinstance (a , Coargument ) for a in self .expr_args )
14761501 if needs_action :
@@ -1480,33 +1505,31 @@ def __init__(self, expr: Interpolate):
14801505 self .expr = self .expr ._ufl_expr_reconstruct_ (self .operand , self .target_space )
14811506
14821507 # Get sub-interpolators for each block
1483- self . Isub : dict [tuple [int , int ], Interpolator ] = {}
1508+ Isub : dict [tuple [int , int ], tuple [ Interpolator , list [ DirichletBC ]] ] = {}
14841509 for indices , form in firedrake .formmanipulation .split_form (self .expr ):
14851510 if isinstance (form , ufl .ZeroBaseForm ):
14861511 # Ensure block sparsity
14871512 continue
1488- vi , _ = form .argument_slots ()
1489- Vtarget = vi .function_space ().dual ()
1490- if self .bcs and self .rank != 0 :
1491- args = form .arguments ()
1492- Vsource = args [1 - vi .number ()].function_space ()
1493- sub_bcs = [bc for bc in self .bcs if bc .function_space () in {Vsource , Vtarget }]
1494- else :
1495- sub_bcs = None
1513+ sub_bcs = []
1514+ for space , index in zip (spaces , indices ):
1515+ subspace = space .sub (index )
1516+ sub_bcs .extend (bc for bc in bcs if space_equals (bc .function_space (), subspace ))
14961517 if needs_action :
14971518 # Take the action of each sub-cofunction against each block
14981519 form = action (form , dual_split [indices [- 1 :]])
1499- form .options .bcs = sub_bcs
1500- self .Isub [indices ] = get_interpolator (form )
1520+ Isub [indices ] = (get_interpolator (form ), sub_bcs )
1521+
1522+ return Isub
15011523
1502- def _build_callable (self , tensor = None ):
1524+ def _build_callable (self , tensor = None , bcs = None ):
1525+ Isub = self ._get_sub_interpolators (bcs = bcs )
15031526 V_dest = self .expr .function_space () or self .target_space
15041527 f = tensor or Function (V_dest )
15051528 if self .rank == 2 :
15061529 shape = tuple (len (a .function_space ()) for a in self .expr_args )
15071530 blocks = numpy .full (shape , PETSc .Mat (), dtype = object )
1508- for indices , interp in self . Isub .items ():
1509- interp ._build_callable ()
1531+ for indices , ( interp , sub_bcs ) in Isub .items ():
1532+ interp ._build_callable (bcs = sub_bcs )
15101533 blocks [indices ] = interp .callable ().handle
15111534 self .handle = PETSc .Mat ().createNest (blocks )
15121535
@@ -1516,10 +1539,10 @@ def callable() -> MixedInterpolator:
15161539 def callable () -> Function | Cofunction :
15171540 for k , sub_tensor in enumerate (f .subfunctions ):
15181541 sub_tensor .assign (sum (
1519- interp .assemble () for indices , interp in self . Isub .items () if indices [0 ] == k
1542+ interp .assemble (bcs = sub_bcs ) for indices , ( interp , sub_bcs ) in Isub .items () if indices [0 ] == k
15201543 ))
15211544 return f
15221545 else :
15231546 def callable () -> Number :
1524- return sum (interp .assemble () for interp in self . Isub .values ())
1547+ return sum (interp .assemble (bcs = sub_bcs ) for ( interp , sub_bcs ) in Isub .values ())
15251548 self .callable = callable
0 commit comments