3131from firedrake .ufl_expr import Argument , Coargument , action
3232from firedrake .cofunction import Cofunction
3333from firedrake .function import Function
34- from firedrake .mesh import MissingPointsBehaviour , VertexOnlyMeshMissingPointsError , VertexOnlyMeshTopology
34+ from firedrake .mesh import MissingPointsBehaviour , VertexOnlyMeshMissingPointsError , VertexOnlyMeshTopology , MeshGeometry
3535from firedrake .petsc import PETSc
3636from firedrake .halo import _get_mtype as get_dat_mpi_type
3737from firedrake .functionspaceimpl import WithGeometry
@@ -268,7 +268,7 @@ def _build_callable(self, output: Function | Cofunction | MatrixBase | None = No
268268 def assemble (
269269 self , tensor : Function | Cofunction | MatrixBase | None = None
270270 ) -> Function | Cofunction | MatrixBase | Number :
271- """Assemble the interpolation. The result depends on the rank (number of arguments)
271+ """Assemble the interpolation. The result depends on the rank (number of arguments)
272272 of the :class:`Interpolate` expression:
273273
274274 * rank-2: assemble the operator and return a matrix
@@ -279,7 +279,7 @@ def assemble(
279279 ----------
280280 tensor : Function | Cofunction | MatrixBase, optional
281281 Pre-allocated storage to receive the interpolated result. For rank-2
282- expressions this is expected to be a subclass of
282+ expressions this is expected to be a subclass of
283283 :class:`~firedrake.matrix.MatrixBase` whose
284284 ``petscmat`` will be populated. For lower-rank expressions this is
285285 a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`.
@@ -492,7 +492,7 @@ def callable():
492492 f_point_eval_input_ordering .dat .data_wo [:] = numpy .nan
493493
494494 def callable ():
495- assemble (action (self .point_eval_input_ordering , f_point_eval ),
495+ assemble (action (self .point_eval_input_ordering , f_point_eval ),
496496 tensor = f_point_eval_input_ordering )
497497
498498 # We assign these values to the output function
@@ -619,10 +619,8 @@ def _build_callable(self, output=None) -> None:
619619 for indices , sub_expr in expressions .items ():
620620 if isinstance (sub_expr , ufl .ZeroBaseForm ):
621621 continue
622- arguments = sub_expr .arguments ()
623- sub_space = sub_expr .argument_slots ()[0 ].function_space ().dual ()
624622 sub_tensor = tensor [indices [0 ]] if self .rank == 1 else tensor
625- loops .extend (build_interpolation_callables (sub_space , sub_tensor , sub_expr , self .subset , arguments , self .access , bcs = self .bcs ))
623+ loops .extend (build_interpolation_callables (sub_expr , sub_tensor , self .subset , self .access , bcs = self .bcs ))
626624
627625 if self .bcs and self .rank == 1 :
628626 loops .extend (partial (bc .apply , f ) for bc in self .bcs )
@@ -682,45 +680,49 @@ def callable():
682680 # Leave mat inside a callable so we can access the handle
683681 # property. If matfree is True, then the handle is a PETSc SF
684682 # pretending to be a PETSc Mat. If matfree is False, then this
685- # will be a PETSc Mat representing the equivalent permutation
686- # matrix
683+ # will be a PETSc Mat representing the equivalent permutation matrix
684+
687685 def callable ():
688686 return self .mat
689687
690688 self .callable = callable
691689
692690
693691@utils .known_pyop2_safe
694- def build_interpolation_callables (V : WithGeometry , tensor , expr , subset , arguments , access , bcs = None ) -> tuple [Callable , ...]:
695- """Builds callables to perform interpolation.
692+ def build_interpolation_callables (
693+ expr : ufl .Interpolate ,
694+ tensor : op2 .Dat | op2 .Mat | op2 .Global ,
695+ access : Literal [op2 .WRITE , op2 .MIN , op2 .MAX , op2 .INC ],
696+ subset : op2 .Subset | None = None ,
697+ bcs : Iterable [BCBase ] | None = None
698+ ) -> tuple [Callable , ...]:
699+ """Returns tuple of callables which calculate the interpolation.
696700
697701 Parameters
698702 ----------
699- V : WithGeometry
700- _description_
701- tensor : _type_
702- _description_
703- expr : _type_
704- _description_
705- subset : _type_
706- _description_
707- arguments : _type_
708- _description_
709- access : _type_
710- _description_
711- bcs : _type_, optional
712- _description_, by default None
703+ expr : ufl.Interpolate
704+ The symbolic interpolation expression.
705+ tensor : op2.Dat | op2.Mat | op2.Global
706+ Object to hold the result of the interpolation.
707+ access : Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC]
708+ op2 access descriptor
709+ subset : op2.Subset | None, optional
710+ An optional subset to apply the interpolation over, by default None.
711+ bcs : Iterable[BCBase] | None, optional
712+ An optional list of boundary conditions to zero-out in the
713+ output function space. Interpolator rows or columns which are
714+ associated with boundary condition nodes are zeroed out when this is
715+ specified. By default None, by default None.
713716
714717 Returns
715718 -------
716719 tuple[Callable, ...]
717- Tuple of callables
718-
719- """
720+ Tuple of callables which perform the interpolation.
721+ """
720722 if not isinstance (expr , ufl .Interpolate ):
721723 raise ValueError ("Expecting to interpolate a ufl.Interpolate" )
722724 dual_arg , operand = expr .argument_slots ()
723-
725+ V = dual_arg . function_space (). dual ()
724726 try :
725727 to_element = create_element (V .ufl_element ())
726728 except KeyError :
@@ -825,6 +827,7 @@ def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, argumen
825827 copyin = ()
826828 copyout = ()
827829
830+ arguments = expr .arguments ()
828831 if isinstance (tensor , op2 .Global ):
829832 parloop_args .append (tensor (access ))
830833 elif isinstance (tensor , op2 .Dat ):
@@ -897,7 +900,7 @@ def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, argumen
897900 return copyin + callables + (parloop_compute_callable , ) + copyout
898901
899902
900- def get_interp_node_map (source_mesh , target_mesh , fs ) :
903+ def get_interp_node_map (source_mesh : MeshGeometry , target_mesh : MeshGeometry , fs : WithGeometry ) -> op2 . Map | None :
901904 """Return the map between cells of the target mesh and nodes of the function space.
902905
903906 If the function space is defined on the source mesh then the node map is composed
@@ -999,7 +1002,7 @@ def rebuild_te(element, expr_cell, rt_var_name):
9991002 transpose = element ._transpose )
10001003
10011004
1002- def compose_map_and_cache (map1 , map2 ) :
1005+ def compose_map_and_cache (map1 : op2 . Map , map2 : op2 . Map | None ) -> op2 . ComposedMap | None :
10031006 """
10041007 Retrieve a :class:`pyop2.ComposedMap` map from the cache of map1
10051008 using map2 as the cache key. The composed map maps from the iterset
@@ -1022,7 +1025,7 @@ def compose_map_and_cache(map1, map2):
10221025 return cmap
10231026
10241027
1025- def vom_cell_parent_node_map_extruded (vertex_only_mesh , extruded_cell_node_map ) :
1028+ def vom_cell_parent_node_map_extruded (vertex_only_mesh : MeshGeometry , extruded_cell_node_map : op2 . Map ) -> op2 . Map :
10261029 """Build a map from the cells of a vertex only mesh to the nodes of the
10271030 nodes on the source mesh where the source mesh is extruded.
10281031
@@ -1210,13 +1213,25 @@ def mpi_type(self):
12101213 def mpi_type (self , val ):
12111214 self ._mpi_type = val
12121215
1213- def expr_as_coeff (self , source_vec = None ):
1214- """
1215- Return a coefficient that corresponds to the expression used at
1216+ def expr_as_coeff (self , source_vec : PETSc .Vec | None = None ) -> Function :
1217+ """Return a coefficient that corresponds to the expression used at
12161218 construction, where the expression has been interpolated into the P0DG
12171219 function space on the source vertex-only mesh.
12181220
12191221 Will fail if there are no arguments.
1222+
1223+ Parameters
1224+ ----------
1225+ source_vec : PETSc.Vec | None, optional
1226+ Optional vector used to replace arguments in the expression.
1227+ By default None.
1228+
1229+ Returns
1230+ -------
1231+ Function
1232+ A Function representing the expression as a coefficient on the
1233+ source vertex-only mesh.
1234+
12201235 """
12211236 # Since we always output a coefficient when we don't have arguments in
12221237 # the expression, we should evaluate the expression on the source mesh
@@ -1243,7 +1258,16 @@ def expr_as_coeff(self, source_vec=None):
12431258 coeff = firedrake .Function (P0DG ).interpolate (coeff_expr )
12441259 return coeff
12451260
1246- def reduce (self , source_vec , target_vec ):
1261+ def reduce (self , source_vec : PETSc .Vec , target_vec : PETSc .Vec ) -> None :
1262+ """Reduce data in source_vec using the PETSc SF.
1263+
1264+ Parameters
1265+ ----------
1266+ source_vec : PETSc.Vec
1267+ The vector to reduce.
1268+ target_vec : PETSc.Vec
1269+ The vector to store the result in.
1270+ """
12471271 source_arr = source_vec .getArray (readonly = True )
12481272 target_arr = target_vec .getArray ()
12491273 self .sf .reduceBegin (
@@ -1259,7 +1283,16 @@ def reduce(self, source_vec, target_vec):
12591283 MPI .REPLACE ,
12601284 )
12611285
1262- def broadcast (self , source_vec , target_vec ):
1286+ def broadcast (self , source_vec : PETSc .Vec , target_vec : PETSc .Vec ) -> None :
1287+ """Broadcast data in source_vec using the PETSc SF.
1288+
1289+ Parameters
1290+ ----------
1291+ source_vec : PETSc.Vec
1292+ The vector to broadcast.
1293+ target_vec : PETSc.Vec
1294+ The vector to store the result in.
1295+ """
12631296 source_arr = source_vec .getArray (readonly = True )
12641297 target_arr = target_vec .getArray ()
12651298 self .sf .bcastBegin (
@@ -1275,19 +1308,56 @@ def broadcast(self, source_vec, target_vec):
12751308 MPI .REPLACE ,
12761309 )
12771310
1278- def mult (self , mat , source_vec , target_vec ):
1279- # need to evaluate expression before doing mult
1311+ def mult (self , mat : PETSc .Mat , source_vec : PETSc .Vec , target_vec : PETSc .Vec ) -> None :
1312+ """Applies the interpolation operator.
1313+
1314+ Parameters
1315+ ----------
1316+ mat : PETSc.Mat
1317+ Required by petsc4py but unused.
1318+ source_vec : PETSc.Vec
1319+ The vector to interpolate.
1320+ target_vec : PETSc.Vec
1321+ The vector to store the result in.
1322+ """
1323+ # Need to convert the expression into a coefficient
1324+ # so that we can broadcast/reduce it
12801325 coeff = self .expr_as_coeff (source_vec )
12811326 with coeff .dat .vec_ro as coeff_vec :
12821327 if self .forward_reduce :
12831328 self .reduce (coeff_vec , target_vec )
12841329 else :
12851330 self .broadcast (coeff_vec , target_vec )
12861331
1287- def multHermitian (self , mat , source_vec , target_vec ):
1332+ def multHermitian (self , mat : PETSc .Mat , source_vec : PETSc .Vec , target_vec : PETSc .Vec ) -> None :
1333+ """Applies the adjoint of the interpolation operator.
1334+ Since ``VomOntoVomMat`` represents a permutation, it is
1335+ real-valued and thus the adjoint is the transpose.
1336+
1337+ Parameters
1338+ ----------
1339+ mat : PETSc.Mat
1340+ Required by petsc4py but unused.
1341+ source_vec : PETSc.Vec
1342+ The vector to adjoint interpolate.
1343+ target_vec : PETSc.Vec
1344+ The vector to store the result in.
1345+ """
12881346 self .multTranspose (mat , source_vec , target_vec )
12891347
1290- def multTranspose (self , mat , source_vec , target_vec ):
1348+ def multTranspose (self , mat : PETSc .Mat , source_vec : PETSc .Vec , target_vec : PETSc .Vec ) -> None :
1349+ """Applies the tranpose of the interpolation operator. Called by `self.multHermitian`.
1350+
1351+ Parameters
1352+ ----------
1353+ mat : PETSc.Mat
1354+ Required by petsc4py but unused.
1355+ source_vec : PETSc.Vec
1356+ The vector to transpose interpolate.
1357+ target_vec : PETSc.Vec
1358+ The vector to store the result in.
1359+
1360+ """
12911361 # can only do adjoint if our expression exclusively contains a
12921362 # single argument, making the application of the adjoint operator
12931363 # straightforward (haven't worked out how to do this otherwise!)
@@ -1314,9 +1384,17 @@ def multTranspose(self, mat, source_vec, target_vec):
13141384 target_vec .zeroEntries ()
13151385 self .reduce (source_vec , target_vec )
13161386
1317- def _create_permutation_mat (self ):
1387+ def _create_permutation_mat (self ) -> PETSc . Mat :
13181388 """Creates the PETSc matrix that represents the interpolation operator from a vertex-only mesh to
1319- its input ordering vertex-only mesh"""
1389+ its input ordering vertex-only mesh.
1390+
1391+ Returns
1392+ -------
1393+ PETSc.Mat
1394+ PETSc seqaij matrix
1395+ """
1396+ # To create the permutation matrix we broadcast an array of indices contiguous across
1397+ # all ranks and then use these indices to set the values of the matrix directly.
13201398 mat = PETSc .Mat ().createAIJ ((self .target_size , self .source_size ), nnz = 1 , comm = self .V .comm )
13211399 mat .setUp ()
13221400 start = sum (self ._local_sizes [:self .V .comm .rank ])
@@ -1333,7 +1411,14 @@ def _create_permutation_mat(self):
13331411 mat .transpose ()
13341412 return mat
13351413
1336- def _wrap_python_mat (self ):
1414+ def _wrap_python_mat (self ) -> PETSc .Mat :
1415+ """Wraps this object as a PETSc Mat. Used for matfree interpolation.
1416+
1417+ Returns
1418+ -------
1419+ PETSc.Mat
1420+ A PETSc Mat of type python with this object as its context.
1421+ """
13371422 mat = PETSc .Mat ().create (comm = self .V .comm )
13381423 if self .forward_reduce :
13391424 mat_size = (self .source_size , self .target_size )
@@ -1401,7 +1486,6 @@ def __iter__(self):
14011486 return iter (self ._sub_interpolators )
14021487
14031488 def _build_callable (self , output = None ):
1404- """Assemble the operator."""
14051489 V_dest = self .expr .function_space () or self .target_space
14061490 f = output or Function (V_dest )
14071491 if self .rank == 2 :
@@ -1420,6 +1504,7 @@ def callable():
14201504 return f
14211505 else :
14221506 assert self .rank == 0
1507+
14231508 def callable ():
14241509 result = sum (self [i ].assemble () for i in self )
14251510 assert isinstance (result , Number )
0 commit comments