@@ -305,10 +305,7 @@ def assemble(
305305 return tensor or firedrake .AssembledMatrix (self .expr_args , self .bcs , res )
306306 else :
307307 assert isinstance (tensor , Function | Cofunction | None )
308- if tensor and isinstance (result , Function | Cofunction ):
309- tensor .assign (result )
310- return tensor
311- return result
308+ return tensor .assign (result ) if tensor else result
312309
313310
314311class DofNotDefinedError (Exception ):
@@ -454,13 +451,15 @@ def _build_callable(self, output=None):
454451 else :
455452 symbolic = action (self .point_eval_input_ordering , self .point_eval )
456453 self .handle = assemble (symbolic ).petscmat
457- self .callable = lambda : self
454+ def callable () -> CrossMeshInterpolator :
455+ return self
458456 else :
459457 if self .expr .is_adjoint :
460458 assert self .rank == 1
461459 # f_src is a cofunction on V_dest.dual
462460 cofunc = self .dual_arg
463461 assert isinstance (cofunc , Cofunction )
462+
464463 # Our first adjoint operation is to assign the dat values to a
465464 # P0DG cofunction on our input ordering VOM.
466465 f_input_ordering = Cofunction (self .P0DG_vom_input_ordering .dual ())
@@ -471,7 +470,7 @@ def _build_callable(self, output=None):
471470 # We don't worry about skipping over missing points here
472471 # because we're going from the input ordering VOM to the original VOM
473472 # and all points from the input ordering VOM are in the original.
474- def callable ():
473+ def callable () -> Cofunction :
475474 f_src_at_src_node_coords = assemble (action (self .point_eval_input_ordering , f_input_ordering ))
476475 assemble (action (self .point_eval , f_src_at_src_node_coords ), tensor = f )
477476 return f
@@ -491,7 +490,7 @@ def callable():
491490 # them later.
492491 f_point_eval_input_ordering .dat .data_wo [:] = numpy .nan
493492
494- def callable ():
493+ def callable () -> Function | Number :
495494 assemble (action (self .point_eval_input_ordering , f_point_eval ),
496495 tensor = f_point_eval_input_ordering )
497496
@@ -508,7 +507,7 @@ def callable():
508507 return assemble (action (self .dual_arg , f ))
509508 else :
510509 return f
511- self .callable = callable
510+ self .callable = callable
512511
513512
514513class SameMeshInterpolator (Interpolator ):
@@ -645,7 +644,6 @@ def _build_callable(self, output=None):
645644 tensor = None
646645 else :
647646 f = output or self ._get_tensor ()
648- assert isinstance (f , Function | Cofunction )
649647 tensor = f .dat
650648 # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
651649 # data, including the correct data size and dimensional information
@@ -656,13 +654,16 @@ def _build_callable(self, output=None):
656654 self .mat .mpi_type = get_dat_mpi_type (f .dat )[0 ]
657655 if self .expr .is_adjoint :
658656 assert isinstance (self .dual_arg , ufl .Cofunction )
657+ assert isinstance (f , Cofunction )
659658
660- def callable ():
659+ def callable () -> Cofunction :
661660 with self .dual_arg .dat .vec_ro as source_vec , f .dat .vec_wo as target_vec :
662661 self .mat .handle .multHermitian (source_vec , target_vec )
663662 return f
664663 else :
665- def callable ():
664+ assert isinstance (f , Function )
665+
666+ def callable () -> Function :
666667 coeff = self .mat .expr_as_coeff ()
667668 with coeff .dat .vec_ro as coeff_vec , f .dat .vec_wo as target_vec :
668669 self .mat .handle .mult (coeff_vec , target_vec )
@@ -682,7 +683,7 @@ def callable():
682683 # pretending to be a PETSc Mat. If matfree is False, then this
683684 # will be a PETSc Mat representing the equivalent permutation matrix
684685
685- def callable ():
686+ def callable () -> VomOntoVomMat :
686687 return self .mat
687688
688689 self .callable = callable
@@ -1494,19 +1495,15 @@ def _build_callable(self, output=None):
14941495 for i in self :
14951496 self [i ]._build_callable ()
14961497 blocks [i ] = self [i ].callable ().handle
1497- petscmat = PETSc .Mat ().createNest (blocks )
1498- tensor = firedrake . AssembledMatrix ( self . expr_args , self . bcs , petscmat )
1499- callable = lambda : tensor . M
1498+ self . handle = PETSc .Mat ().createNest (blocks )
1499+ def callable () -> MixedInterpolator :
1500+ return self
15001501 elif self .rank == 1 :
1501- def callable ():
1502+ def callable () -> Function | Cofunction :
15021503 for k , sub_tensor in enumerate (f .subfunctions ):
15031504 sub_tensor .assign (sum (self [i ].assemble () for i in self if i [0 ] == k ))
15041505 return f
15051506 else :
1506- assert self .rank == 0
1507-
1508- def callable ():
1509- result = sum (self [i ].assemble () for i in self )
1510- assert isinstance (result , Number )
1511- return result
1507+ def callable () -> Number :
1508+ return sum (self [i ].assemble () for i in self )
15121509 self .callable = callable
0 commit comments