@@ -87,12 +87,18 @@ class InterpolateOptions:
8787 If ``False``, then construct the permutation matrix for interpolating
8888 between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
8989 and reduce operations.
90+ bcs : Iterable[BCBase] | None, optional
91+ An optional list of boundary conditions to zero-out in the
92+ output function space. Interpolator rows or columns which are
93+ associated with boundary condition nodes are zeroed out when this is
94+ specified. By default None.
9095 """
9196 subset : op2 .Subset | None = None
9297 access : Literal [op2 .WRITE , op2 .MIN , op2 .MAX , op2 .INC ] | None = None
9398 allow_missing_dofs : bool = False
9499 default_missing_val : float | None = None
95100 matfree : bool = True
101+ bcs : Iterable [BCBase ] | None = None
96102
97103
98104class Interpolate (ufl .Interpolate ):
@@ -162,18 +168,13 @@ def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs) -> Interpo
162168 return Interpolate (expr , V , ** kwargs )
163169
164170
165- def get_interpolator (expr : Interpolate , bcs : Iterable [ BCBase ] | None = None ) -> "Interpolator" :
171+ def get_interpolator (expr : Interpolate ) -> "Interpolator" :
166172 """Create an Interpolator.
167173
168174 Parameters
169175 ----------
170176 expr : Interpolate
171177 Symbolic interpolation expression.
172- bcs : Iterable[BCBase] | None, optional
173- An optional list of boundary conditions to zero-out in the
174- output function space. Interpolator rows or columns which are
175- associated with boundary condition nodes are zeroed out when this is
176- specified. By default None.
177178
178179 Returns
179180 -------
@@ -183,7 +184,7 @@ def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) ->
183184 arguments = expr .arguments ()
184185 has_mixed_arguments = any (len (arg .function_space ()) > 1 for arg in arguments )
185186 if len (arguments ) == 2 and has_mixed_arguments :
186- return MixedInterpolator (expr , bcs = bcs )
187+ return MixedInterpolator (expr )
187188
188189 operand , = expr .ufl_operands
189190 target_mesh = expr .target_space .mesh ()
@@ -194,24 +195,24 @@ def get_interpolator(expr: Interpolate, bcs: Iterable[BCBase] | None = None) ->
194195 and target_mesh .topological_dimension () == source_mesh .topological_dimension ()
195196 )
196197 if target_mesh is source_mesh or submesh_interp_implemented :
197- return SameMeshInterpolator (expr , bcs = bcs )
198+ return SameMeshInterpolator (expr )
198199
199200 target_topology = target_mesh .topology
200201 source_topology = source_mesh .topology
201202
202203 if isinstance (target_topology , VertexOnlyMeshTopology ):
203204 if isinstance (source_topology , VertexOnlyMeshTopology ):
204- return VomOntoVomInterpolator (expr , bcs = bcs )
205+ return VomOntoVomInterpolator (expr )
205206 if target_mesh .geometric_dimension () != source_mesh .geometric_dimension ():
206207 raise ValueError ("Cannot interpolate onto a mesh of a different geometric dimension" )
207208 if not hasattr (target_mesh , "_parent_mesh" ) or target_mesh ._parent_mesh is not source_mesh :
208209 raise ValueError ("Can only interpolate across meshes where the source mesh is the parent of the target" )
209- return SameMeshInterpolator (expr , bcs = bcs )
210+ return SameMeshInterpolator (expr )
210211
211212 if has_mixed_arguments or len (expr .target_space ) > 1 :
212- return MixedInterpolator (expr , bcs = bcs )
213+ return MixedInterpolator (expr )
213214
214- return CrossMeshInterpolator (expr , bcs = bcs )
215+ return CrossMeshInterpolator (expr )
215216
216217
217218class Interpolator (abc .ABC ):
@@ -222,14 +223,9 @@ class Interpolator(abc.ABC):
222223 ----------
223224 expr : Interpolate
224225 The symbolic interpolation expression.
225- bcs : Iterable[BCBase], optional
226- An optional list of boundary conditions to zero-out in the
227- output function space. Interpolator rows or columns which are
228- associated with boundary condition nodes are zeroed out when this is
229- specified. By default None.
230226
231227 """
232- def __init__ (self , expr : Interpolate , bcs : Iterable [ BCBase ] | None = None ):
228+ def __init__ (self , expr : Interpolate ):
233229 dual_arg , operand = expr .argument_slots ()
234230 self .expr = expr
235231 self .expr_args = expr .arguments ()
@@ -245,7 +241,7 @@ def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None):
245241 self .allow_missing_dofs = expr .options .allow_missing_dofs
246242 self .default_missing_val = expr .options .default_missing_val
247243 self .matfree = expr .options .matfree
248- self .bcs = bcs
244+ self .bcs = expr . options . bcs
249245 self .callable = None
250246 self .access = expr .options .access
251247
@@ -345,8 +341,8 @@ class CrossMeshInterpolator(Interpolator):
345341 """
346342
347343 @no_annotations
348- def __init__ (self , expr : Interpolate , bcs : Iterable [ BCBase ] | None = None ):
349- super ().__init__ (expr , bcs )
344+ def __init__ (self , expr : Interpolate ):
345+ super ().__init__ (expr )
350346 if self .access and self .access != op2 .WRITE :
351347 raise NotImplementedError (
352348 "Access other than op2.WRITE not implemented for cross-mesh interpolation."
@@ -520,8 +516,8 @@ class SameMeshInterpolator(Interpolator):
520516 """
521517
522518 @no_annotations
523- def __init__ (self , expr , bcs = None ):
524- super ().__init__ (expr , bcs = bcs )
519+ def __init__ (self , expr ):
520+ super ().__init__ (expr )
525521 subset = self .subset
526522 if subset is None :
527523 target = self .target_mesh .topology
@@ -594,9 +590,6 @@ def _build_callable(self, output=None) -> None:
594590
595591 loops = []
596592
597- if self .access == op2 .INC :
598- loops .append (tensor .zero )
599-
600593 # Arguments in the operand are allowed to be from a MixedFunctionSpace
601594 # We need to split the target space V and generate separate kernels
602595 if self .rank == 2 :
@@ -635,8 +628,8 @@ def callable(loops, f):
635628
636629class VomOntoVomInterpolator (SameMeshInterpolator ):
637630
638- def __init__ (self , expr : Interpolate , bcs = None ):
639- super ().__init__ (expr , bcs = bcs )
631+ def __init__ (self , expr : Interpolate ):
632+ super ().__init__ (expr )
640633
641634 def _build_callable (self , output = None ):
642635 self .mat = VomOntoVomMat (self )
@@ -899,7 +892,10 @@ def _build_interpolation_callables(
899892 if isinstance (tensor , op2 .Mat ):
900893 return parloop_compute_callable , tensor .assemble
901894 else :
902- return copyin + callables + (parloop_compute_callable , ) + copyout
895+ extra = copyin + callables
896+ if access == op2 .INC :
897+ extra += (tensor .zero ,)
898+ return extra + (parloop_compute_callable , ) + copyout
903899
904900
905901def get_interp_node_map (source_mesh : MeshGeometry , target_mesh : MeshGeometry , fs : WithGeometry ) -> op2 .Map | None :
@@ -1446,11 +1442,9 @@ class MixedInterpolator(Interpolator):
14461442 V
14471443 The :class:`.FunctionSpace` or :class:`.Function` to
14481444 interpolate into.
1449- bcs
1450- A list of boundary conditions.
14511445 """
1452- def __init__ (self , expr , bcs = None ):
1453- super ().__init__ (expr , bcs = bcs )
1446+ def __init__ (self , expr ):
1447+ super ().__init__ (expr )
14541448
14551449 # We need a Coargument in order to split the Interpolate
14561450 needs_action = not any (isinstance (a , Coargument ) for a in self .expr_args )
@@ -1467,17 +1461,17 @@ def __init__(self, expr, bcs=None):
14671461 continue
14681462 vi , _ = form .argument_slots ()
14691463 Vtarget = vi .function_space ().dual ()
1470- if bcs and self .rank != 0 :
1464+ if self . bcs and self .rank != 0 :
14711465 args = form .arguments ()
14721466 Vsource = args [1 - vi .number ()].function_space ()
1473- sub_bcs = [bc for bc in bcs if bc .function_space () in {Vsource , Vtarget }]
1467+ sub_bcs = [bc for bc in self . bcs if bc .function_space () in {Vsource , Vtarget }]
14741468 else :
14751469 sub_bcs = None
14761470 if needs_action :
14771471 # Take the action of each sub-cofunction against each block
14781472 form = action (form , dual_split [indices [- 1 :]])
1479-
1480- Isub [indices ] = get_interpolator (form , bcs = sub_bcs )
1473+ form . options . bcs = sub_bcs
1474+ Isub [indices ] = get_interpolator (form )
14811475
14821476 self ._sub_interpolators = Isub
14831477
0 commit comments