Skip to content

Commit 93476e4

Browse files
committed
suggestions
1 parent b9ab2e0 commit 93476e4

File tree

2 files changed

+22
-25
lines changed

2 files changed

+22
-25
lines changed

firedrake/bcs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# A module implementing strong (Dirichlet) boundary conditions.
22
import numpy as np
33

4-
import functools
4+
from functools import partial, reduce
55
import itertools
66

77
import ufl
@@ -167,7 +167,7 @@ def hermite_stride(bcnodes):
167167
# Edge conditions have only been tested with Lagrange elements.
168168
# Need to expand the list.
169169
bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss)))
170-
bcnodes1 = functools.reduce(np.intersect1d, bcnodes1)
170+
bcnodes1 = reduce(np.intersect1d, bcnodes1)
171171
bcnodes.append(bcnodes1)
172172
return np.concatenate(bcnodes)
173173

@@ -363,7 +363,7 @@ def function_arg(self, g):
363363
interpolator = get_interpolator(firedrake.interpolate(g, V))
364364
# Call this here to check if the element supports interpolation
365365
interpolator._build_callable()
366-
self._function_arg_update = lambda: interpolator.assemble(tensor=self._function_arg)
366+
self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg)
367367
except (ValueError, NotImplementedError):
368368
# Element doesn't implement interpolation
369369
self._function_arg = firedrake.Function(V).project(g)

firedrake/interpolation.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,7 @@ def assemble(
305305
return tensor or firedrake.AssembledMatrix(self.expr_args, self.bcs, res)
306306
else:
307307
assert isinstance(tensor, Function | Cofunction | None)
308-
if tensor and isinstance(result, Function | Cofunction):
309-
tensor.assign(result)
310-
return tensor
311-
return result
308+
return tensor.assign(result) if tensor else result
312309

313310

314311
class DofNotDefinedError(Exception):
@@ -454,13 +451,15 @@ def _build_callable(self, output=None):
454451
else:
455452
symbolic = action(self.point_eval_input_ordering, self.point_eval)
456453
self.handle = assemble(symbolic).petscmat
457-
self.callable = lambda: self
454+
def callable() -> CrossMeshInterpolator:
455+
return self
458456
else:
459457
if self.expr.is_adjoint:
460458
assert self.rank == 1
461459
# f_src is a cofunction on V_dest.dual
462460
cofunc = self.dual_arg
463461
assert isinstance(cofunc, Cofunction)
462+
464463
# Our first adjoint operation is to assign the dat values to a
465464
# P0DG cofunction on our input ordering VOM.
466465
f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual())
@@ -471,7 +470,7 @@ def _build_callable(self, output=None):
471470
# We don't worry about skipping over missing points here
472471
# because we're going from the input ordering VOM to the original VOM
473472
# and all points from the input ordering VOM are in the original.
474-
def callable():
473+
def callable() -> Cofunction:
475474
f_src_at_src_node_coords = assemble(action(self.point_eval_input_ordering, f_input_ordering))
476475
assemble(action(self.point_eval, f_src_at_src_node_coords), tensor=f)
477476
return f
@@ -491,7 +490,7 @@ def callable():
491490
# them later.
492491
f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan
493492

494-
def callable():
493+
def callable() -> Function | Number:
495494
assemble(action(self.point_eval_input_ordering, f_point_eval),
496495
tensor=f_point_eval_input_ordering)
497496

@@ -508,7 +507,7 @@ def callable():
508507
return assemble(action(self.dual_arg, f))
509508
else:
510509
return f
511-
self.callable = callable
510+
self.callable = callable
512511

513512

514513
class SameMeshInterpolator(Interpolator):
@@ -645,7 +644,6 @@ def _build_callable(self, output=None):
645644
tensor = None
646645
else:
647646
f = output or self._get_tensor()
648-
assert isinstance(f, Function | Cofunction)
649647
tensor = f.dat
650648
# NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
651649
# data, including the correct data size and dimensional information
@@ -656,13 +654,16 @@ def _build_callable(self, output=None):
656654
self.mat.mpi_type = get_dat_mpi_type(f.dat)[0]
657655
if self.expr.is_adjoint:
658656
assert isinstance(self.dual_arg, ufl.Cofunction)
657+
assert isinstance(f, Cofunction)
659658

660-
def callable():
659+
def callable() -> Cofunction:
661660
with self.dual_arg.dat.vec_ro as source_vec, f.dat.vec_wo as target_vec:
662661
self.mat.handle.multHermitian(source_vec, target_vec)
663662
return f
664663
else:
665-
def callable():
664+
assert isinstance(f, Function)
665+
666+
def callable() -> Function:
666667
coeff = self.mat.expr_as_coeff()
667668
with coeff.dat.vec_ro as coeff_vec, f.dat.vec_wo as target_vec:
668669
self.mat.handle.mult(coeff_vec, target_vec)
@@ -682,7 +683,7 @@ def callable():
682683
# pretending to be a PETSc Mat. If matfree is False, then this
683684
# will be a PETSc Mat representing the equivalent permutation matrix
684685

685-
def callable():
686+
def callable() -> VomOntoVomMat:
686687
return self.mat
687688

688689
self.callable = callable
@@ -1494,19 +1495,15 @@ def _build_callable(self, output=None):
14941495
for i in self:
14951496
self[i]._build_callable()
14961497
blocks[i] = self[i].callable().handle
1497-
petscmat = PETSc.Mat().createNest(blocks)
1498-
tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat)
1499-
callable = lambda: tensor.M
1498+
self.handle = PETSc.Mat().createNest(blocks)
1499+
def callable() -> MixedInterpolator:
1500+
return self
15001501
elif self.rank == 1:
1501-
def callable():
1502+
def callable() -> Function | Cofunction:
15021503
for k, sub_tensor in enumerate(f.subfunctions):
15031504
sub_tensor.assign(sum(self[i].assemble() for i in self if i[0] == k))
15041505
return f
15051506
else:
1506-
assert self.rank == 0
1507-
1508-
def callable():
1509-
result = sum(self[i].assemble() for i in self)
1510-
assert isinstance(result, Number)
1511-
return result
1507+
def callable() -> Number:
1508+
return sum(self[i].assemble() for i in self)
15121509
self.callable = callable

0 commit comments

Comments
 (0)