Skip to content

Commit 139bee9

Browse files
committed
lint; type hints and docstrings
1 parent 5263aff commit 139bee9

File tree

2 files changed

+132
-48
lines changed

2 files changed

+132
-48
lines changed

firedrake/bcs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
import ufl
88
from ufl import as_ufl, as_tensor
9-
from finat.ufl import VectorElement, EnrichedElement
10-
from finat.physically_mapped import DirectlyDefinedElement, PhysicallyMappedElement
9+
from finat.ufl import VectorElement
1110
import finat
1211

1312
import pyop2 as op2
@@ -364,7 +363,7 @@ def function_arg(self, g):
364363
try:
365364
self._function_arg = firedrake.Function(V)
366365
interpolator = get_interpolator(firedrake.interpolate(g, V))
367-
# Call this here to check if the element supports interpolation
366+
# Call this here to check if the element supports interpolation
368367
interpolator._build_callable()
369368
self._function_arg_update = lambda: interpolator.assemble(tensor=self._function_arg)
370369
except (ValueError, NotImplementedError):

firedrake/interpolation.py

Lines changed: 130 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from firedrake.ufl_expr import Argument, Coargument, action
3232
from firedrake.cofunction import Cofunction
3333
from firedrake.function import Function
34-
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology
34+
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology, MeshGeometry
3535
from firedrake.petsc import PETSc
3636
from firedrake.halo import _get_mtype as get_dat_mpi_type
3737
from firedrake.functionspaceimpl import WithGeometry
@@ -268,7 +268,7 @@ def _build_callable(self, output: Function | Cofunction | MatrixBase | None = No
268268
def assemble(
269269
self, tensor: Function | Cofunction | MatrixBase | None = None
270270
) -> Function | Cofunction | MatrixBase | Number:
271-
"""Assemble the interpolation. The result depends on the rank (number of arguments)
271+
"""Assemble the interpolation. The result depends on the rank (number of arguments)
272272
of the :class:`Interpolate` expression:
273273
274274
* rank-2: assemble the operator and return a matrix
@@ -279,7 +279,7 @@ def assemble(
279279
----------
280280
tensor : Function | Cofunction | MatrixBase, optional
281281
Pre-allocated storage to receive the interpolated result. For rank-2
282-
expressions this is expected to be a subclass of
282+
expressions this is expected to be a subclass of
283283
:class:`~firedrake.matrix.MatrixBase` whose
284284
``petscmat`` will be populated. For lower-rank expressions this is
285285
a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`.
@@ -492,7 +492,7 @@ def callable():
492492
f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan
493493

494494
def callable():
495-
assemble(action(self.point_eval_input_ordering, f_point_eval),
495+
assemble(action(self.point_eval_input_ordering, f_point_eval),
496496
tensor=f_point_eval_input_ordering)
497497

498498
# We assign these values to the output function
@@ -619,10 +619,8 @@ def _build_callable(self, output=None) -> None:
619619
for indices, sub_expr in expressions.items():
620620
if isinstance(sub_expr, ufl.ZeroBaseForm):
621621
continue
622-
arguments = sub_expr.arguments()
623-
sub_space = sub_expr.argument_slots()[0].function_space().dual()
624622
sub_tensor = tensor[indices[0]] if self.rank == 1 else tensor
625-
loops.extend(build_interpolation_callables(sub_space, sub_tensor, sub_expr, self.subset, arguments, self.access, bcs=self.bcs))
623+
loops.extend(build_interpolation_callables(sub_expr, sub_tensor, self.subset, self.access, bcs=self.bcs))
626624

627625
if self.bcs and self.rank == 1:
628626
loops.extend(partial(bc.apply, f) for bc in self.bcs)
@@ -682,45 +680,49 @@ def callable():
682680
# Leave mat inside a callable so we can access the handle
683681
# property. If matfree is True, then the handle is a PETSc SF
684682
# pretending to be a PETSc Mat. If matfree is False, then this
685-
# will be a PETSc Mat representing the equivalent permutation
686-
# matrix
683+
# will be a PETSc Mat representing the equivalent permutation matrix
684+
687685
def callable():
688686
return self.mat
689687

690688
self.callable = callable
691689

692690

693691
@utils.known_pyop2_safe
694-
def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, arguments, access, bcs=None) -> tuple[Callable, ...]:
695-
"""Builds callables to perform interpolation.
692+
def build_interpolation_callables(
693+
expr: ufl.Interpolate,
694+
tensor: op2.Dat | op2.Mat | op2.Global,
695+
access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC],
696+
subset: op2.Subset | None = None,
697+
bcs: Iterable[BCBase] | None = None
698+
) -> tuple[Callable, ...]:
699+
"""Returns tuple of callables which calculate the interpolation.
696700
697701
Parameters
698702
----------
699-
V : WithGeometry
700-
_description_
701-
tensor : _type_
702-
_description_
703-
expr : _type_
704-
_description_
705-
subset : _type_
706-
_description_
707-
arguments : _type_
708-
_description_
709-
access : _type_
710-
_description_
711-
bcs : _type_, optional
712-
_description_, by default None
703+
expr : ufl.Interpolate
704+
The symbolic interpolation expression.
705+
tensor : op2.Dat | op2.Mat | op2.Global
706+
Object to hold the result of the interpolation.
707+
access : Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC]
708+
op2 access descriptor
709+
subset : op2.Subset | None, optional
710+
An optional subset to apply the interpolation over, by default None.
711+
bcs : Iterable[BCBase] | None, optional
712+
An optional list of boundary conditions to zero-out in the
713+
output function space. Interpolator rows or columns which are
714+
associated with boundary condition nodes are zeroed out when this is
715+
specified. By default None, by default None.
713716
714717
Returns
715718
-------
716719
tuple[Callable, ...]
717-
Tuple of callables
718-
719-
"""
720+
Tuple of callables which perform the interpolation.
721+
"""
720722
if not isinstance(expr, ufl.Interpolate):
721723
raise ValueError("Expecting to interpolate a ufl.Interpolate")
722724
dual_arg, operand = expr.argument_slots()
723-
725+
V = dual_arg.function_space().dual()
724726
try:
725727
to_element = create_element(V.ufl_element())
726728
except KeyError:
@@ -825,6 +827,7 @@ def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, argumen
825827
copyin = ()
826828
copyout = ()
827829

830+
arguments = expr.arguments()
828831
if isinstance(tensor, op2.Global):
829832
parloop_args.append(tensor(access))
830833
elif isinstance(tensor, op2.Dat):
@@ -897,7 +900,7 @@ def build_interpolation_callables(V: WithGeometry, tensor, expr, subset, argumen
897900
return copyin + callables + (parloop_compute_callable, ) + copyout
898901

899902

900-
def get_interp_node_map(source_mesh, target_mesh, fs):
903+
def get_interp_node_map(source_mesh: MeshGeometry, target_mesh: MeshGeometry, fs: WithGeometry) -> op2.Map | None:
901904
"""Return the map between cells of the target mesh and nodes of the function space.
902905
903906
If the function space is defined on the source mesh then the node map is composed
@@ -999,7 +1002,7 @@ def rebuild_te(element, expr_cell, rt_var_name):
9991002
transpose=element._transpose)
10001003

10011004

1002-
def compose_map_and_cache(map1, map2):
1005+
def compose_map_and_cache(map1: op2.Map, map2: op2.Map | None) -> op2.ComposedMap | None:
10031006
"""
10041007
Retrieve a :class:`pyop2.ComposedMap` map from the cache of map1
10051008
using map2 as the cache key. The composed map maps from the iterset
@@ -1022,7 +1025,7 @@ def compose_map_and_cache(map1, map2):
10221025
return cmap
10231026

10241027

1025-
def vom_cell_parent_node_map_extruded(vertex_only_mesh, extruded_cell_node_map):
1028+
def vom_cell_parent_node_map_extruded(vertex_only_mesh: MeshGeometry, extruded_cell_node_map: op2.Map) -> op2.Map:
10261029
"""Build a map from the cells of a vertex only mesh to the nodes of the
10271030
nodes on the source mesh where the source mesh is extruded.
10281031
@@ -1210,13 +1213,25 @@ def mpi_type(self):
12101213
def mpi_type(self, val):
12111214
self._mpi_type = val
12121215

1213-
def expr_as_coeff(self, source_vec=None):
1214-
"""
1215-
Return a coefficient that corresponds to the expression used at
1216+
def expr_as_coeff(self, source_vec: PETSc.Vec | None = None) -> Function:
1217+
"""Return a coefficient that corresponds to the expression used at
12161218
construction, where the expression has been interpolated into the P0DG
12171219
function space on the source vertex-only mesh.
12181220
12191221
Will fail if there are no arguments.
1222+
1223+
Parameters
1224+
----------
1225+
source_vec : PETSc.Vec | None, optional
1226+
Optional vector used to replace arguments in the expression.
1227+
By default None.
1228+
1229+
Returns
1230+
-------
1231+
Function
1232+
A Function representing the expression as a coefficient on the
1233+
source vertex-only mesh.
1234+
12201235
"""
12211236
# Since we always output a coefficient when we don't have arguments in
12221237
# the expression, we should evaluate the expression on the source mesh
@@ -1243,7 +1258,16 @@ def expr_as_coeff(self, source_vec=None):
12431258
coeff = firedrake.Function(P0DG).interpolate(coeff_expr)
12441259
return coeff
12451260

1246-
def reduce(self, source_vec, target_vec):
1261+
def reduce(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None:
1262+
"""Reduce data in source_vec using the PETSc SF.
1263+
1264+
Parameters
1265+
----------
1266+
source_vec : PETSc.Vec
1267+
The vector to reduce.
1268+
target_vec : PETSc.Vec
1269+
The vector to store the result in.
1270+
"""
12471271
source_arr = source_vec.getArray(readonly=True)
12481272
target_arr = target_vec.getArray()
12491273
self.sf.reduceBegin(
@@ -1259,7 +1283,16 @@ def reduce(self, source_vec, target_vec):
12591283
MPI.REPLACE,
12601284
)
12611285

1262-
def broadcast(self, source_vec, target_vec):
1286+
def broadcast(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None:
1287+
"""Broadcast data in source_vec using the PETSc SF.
1288+
1289+
Parameters
1290+
----------
1291+
source_vec : PETSc.Vec
1292+
The vector to broadcast.
1293+
target_vec : PETSc.Vec
1294+
The vector to store the result in.
1295+
"""
12631296
source_arr = source_vec.getArray(readonly=True)
12641297
target_arr = target_vec.getArray()
12651298
self.sf.bcastBegin(
@@ -1275,19 +1308,56 @@ def broadcast(self, source_vec, target_vec):
12751308
MPI.REPLACE,
12761309
)
12771310

1278-
def mult(self, mat, source_vec, target_vec):
1279-
# need to evaluate expression before doing mult
1311+
def mult(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None:
1312+
"""Applies the interpolation operator.
1313+
1314+
Parameters
1315+
----------
1316+
mat : PETSc.Mat
1317+
Required by petsc4py but unused.
1318+
source_vec : PETSc.Vec
1319+
The vector to interpolate.
1320+
target_vec : PETSc.Vec
1321+
The vector to store the result in.
1322+
"""
1323+
# Need to convert the expression into a coefficient
1324+
# so that we can broadcast/reduce it
12801325
coeff = self.expr_as_coeff(source_vec)
12811326
with coeff.dat.vec_ro as coeff_vec:
12821327
if self.forward_reduce:
12831328
self.reduce(coeff_vec, target_vec)
12841329
else:
12851330
self.broadcast(coeff_vec, target_vec)
12861331

1287-
def multHermitian(self, mat, source_vec, target_vec):
1332+
def multHermitian(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None:
1333+
"""Applies the adjoint of the interpolation operator.
1334+
Since ``VomOntoVomMat`` represents a permutation, it is
1335+
real-valued and thus the adjoint is the transpose.
1336+
1337+
Parameters
1338+
----------
1339+
mat : PETSc.Mat
1340+
Required by petsc4py but unused.
1341+
source_vec : PETSc.Vec
1342+
The vector to adjoint interpolate.
1343+
target_vec : PETSc.Vec
1344+
The vector to store the result in.
1345+
"""
12881346
self.multTranspose(mat, source_vec, target_vec)
12891347

1290-
def multTranspose(self, mat, source_vec, target_vec):
1348+
def multTranspose(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None:
1349+
"""Applies the tranpose of the interpolation operator. Called by `self.multHermitian`.
1350+
1351+
Parameters
1352+
----------
1353+
mat : PETSc.Mat
1354+
Required by petsc4py but unused.
1355+
source_vec : PETSc.Vec
1356+
The vector to transpose interpolate.
1357+
target_vec : PETSc.Vec
1358+
The vector to store the result in.
1359+
1360+
"""
12911361
# can only do adjoint if our expression exclusively contains a
12921362
# single argument, making the application of the adjoint operator
12931363
# straightforward (haven't worked out how to do this otherwise!)
@@ -1314,9 +1384,17 @@ def multTranspose(self, mat, source_vec, target_vec):
13141384
target_vec.zeroEntries()
13151385
self.reduce(source_vec, target_vec)
13161386

1317-
def _create_permutation_mat(self):
1387+
def _create_permutation_mat(self) -> PETSc.Mat:
13181388
"""Creates the PETSc matrix that represents the interpolation operator from a vertex-only mesh to
1319-
its input ordering vertex-only mesh"""
1389+
its input ordering vertex-only mesh.
1390+
1391+
Returns
1392+
-------
1393+
PETSc.Mat
1394+
PETSc seqaij matrix
1395+
"""
1396+
# To create the permutation matrix we broadcast an array of indices contiguous across
1397+
# all ranks and then use these indices to set the values of the matrix directly.
13201398
mat = PETSc.Mat().createAIJ((self.target_size, self.source_size), nnz=1, comm=self.V.comm)
13211399
mat.setUp()
13221400
start = sum(self._local_sizes[:self.V.comm.rank])
@@ -1333,7 +1411,14 @@ def _create_permutation_mat(self):
13331411
mat.transpose()
13341412
return mat
13351413

1336-
def _wrap_python_mat(self):
1414+
def _wrap_python_mat(self) -> PETSc.Mat:
1415+
"""Wraps this object as a PETSc Mat. Used for matfree interpolation.
1416+
1417+
Returns
1418+
-------
1419+
PETSc.Mat
1420+
A PETSc Mat of type python with this object as its context.
1421+
"""
13371422
mat = PETSc.Mat().create(comm=self.V.comm)
13381423
if self.forward_reduce:
13391424
mat_size = (self.source_size, self.target_size)
@@ -1401,7 +1486,6 @@ def __iter__(self):
14011486
return iter(self._sub_interpolators)
14021487

14031488
def _build_callable(self, output=None):
1404-
"""Assemble the operator."""
14051489
V_dest = self.expr.function_space() or self.target_space
14061490
f = output or Function(V_dest)
14071491
if self.rank == 2:
@@ -1420,6 +1504,7 @@ def callable():
14201504
return f
14211505
else:
14221506
assert self.rank == 0
1507+
14231508
def callable():
14241509
result = sum(self[i].assemble() for i in self)
14251510
assert isinstance(result, Number)

0 commit comments

Comments
 (0)