@@ -262,8 +262,9 @@ class Interpolator(abc.ABC):
262262 """
263263
264264 def __new__ (cls , expr , V , ** kwargs ):
265+ V_target = V if isinstance (V , ufl .FunctionSpace ) else V .function_space ()
265266 if not isinstance (expr , ufl .Interpolate ):
266- expr = interpolate (expr , V if isinstance ( V , ufl . FunctionSpace ) else V . function_space () )
267+ expr = interpolate (expr , V_target )
267268
268269 spaces = [a .function_space () for a in expr .arguments ()]
269270 has_mixed_spaces = any (len (space ) > 1 for space in spaces )
@@ -280,10 +281,9 @@ def __new__(cls, expr, V, **kwargs):
280281 if target_mesh is source_mesh or submesh_interp_implemented :
281282 return object .__new__ (SameMeshInterpolator )
282283 else :
283- needs_adjoint = not isinstance (expr .arguments ()[0 ], Coargument )
284284 if isinstance (target_mesh .topology , VertexOnlyMeshTopology ):
285285 return object .__new__ (SameMeshInterpolator )
286- elif has_mixed_spaces and needs_adjoint :
286+ elif has_mixed_spaces or len ( V_target ) > 1 :
287287 return object .__new__ (MixedInterpolator )
288288 else :
289289 return object .__new__ (CrossMeshInterpolator )
@@ -506,62 +506,24 @@ def __init__(
506506 self .src_mesh = src_mesh
507507 self .dest_mesh = dest_mesh
508508
509- self .sub_interpolators = []
510-
511509 # Create a VOM at the nodes of V_dest in src_mesh. We don't include halo
512510 # node coordinates because interpolation doesn't usually include halos.
513511 # NOTE: it is very important to set redundant=False, otherwise the
514512 # input ordering VOM will only contain the points on rank 0!
515513 # QUESTION: Should any of the below have annotation turned off?
516514 ufl_scalar_element = V_dest .ufl_element ()
517515 if isinstance (ufl_scalar_element , finat .ufl .MixedElement ):
518- if all (
519- ufl_scalar_element .sub_elements [0 ] == e
520- for e in ufl_scalar_element .sub_elements
521- ):
522- # For a VectorElement or TensorElement the correct
523- # VectorFunctionSpace equivalent is built from the scalar
524- # sub-element.
525- ufl_scalar_element = ufl_scalar_element .sub_elements [0 ]
526- if ufl_scalar_element .reference_value_shape != ():
527- raise NotImplementedError (
528- "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
529- )
530- else :
531- # Build and save an interpolator for each sub-element
532- # separately for MixedFunctionSpaces. NOTE: since we can't have
533- # expressions for MixedFunctionSpaces we know that the input
534- # argument ``expr`` must be a Function. V_dest can be a Function
535- # or a FunctionSpace, and subfunctions works for both.
536- if self .nargs == 1 :
537- # Arguments don't have a subfunctions property so I have to
538- # make them myself. NOTE: this will not be correct when we
539- # start allowing interpolators created from an expression
540- # with arguments, as opposed to just being the argument.
541- expr_subfunctions = [
542- firedrake .TestFunction (V_src_sub_func )
543- for V_src_sub_func in self .expr .function_space ().subspaces
544- ]
545- elif self .nargs > 1 :
546- raise NotImplementedError (
547- "Can't yet create an interpolator from an expression with multiple arguments."
548- )
549- else :
550- expr_subfunctions = self .expr .subfunctions
551- if len (expr_subfunctions ) != len (V_dest .subspaces ):
552- raise NotImplementedError (
553- "Can't interpolate from a non-mixed function space into a mixed function space."
554- )
555- for input_sub_func , target_subspace in zip (
556- expr_subfunctions , V_dest .subspaces
557- ):
558- self .sub_interpolators .append (
559- interpolate (
560- input_sub_func , target_subspace , subset = subset ,
561- access = access , allow_missing_dofs = allow_missing_dofs
562- )
563- )
564- return
516+ if type (ufl_scalar_element ) == finat .ufl .MixedElement :
517+ raise NotImplementedError ("Need a MixedInterpolator" )
518+
519+ # For a VectorElement or TensorElement the correct
520+ # VectorFunctionSpace equivalent is built from the scalar
521+ # sub-element.
522+ ufl_scalar_element = ufl_scalar_element .sub_elements [0 ]
523+ if ufl_scalar_element .reference_value_shape != ():
524+ raise NotImplementedError (
525+ "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
526+ )
565527
566528 from firedrake .assemble import assemble
567529 V_dest_vec = firedrake .VectorFunctionSpace (dest_mesh , ufl_scalar_element )
@@ -672,21 +634,6 @@ def _interpolate(
672634 else :
673635 output = firedrake .Function (V_dest )
674636
675- if len (self .sub_interpolators ):
676- # MixedFunctionSpace case
677- for sub_interpolate , f_src_sub_func , output_sub_func in zip (
678- self .sub_interpolators , f_src .subfunctions , output .subfunctions
679- ):
680- if f_src is self .expr :
681- # f_src is already contained in self.point_eval_interpolate,
682- # so the sub_interpolators are already prepared to interpolate
683- # without needing to be given a Function
684- assert not self .nargs
685- assemble (sub_interpolate , tensor = output_sub_func )
686- else :
687- assemble (action (sub_interpolate , f_src_sub_func ), tensor = output_sub_func )
688- return output
689-
690637 if not adjoint :
691638 if f_src is self .expr :
692639 # f_src is already contained in self.point_eval_interpolate
@@ -1748,7 +1695,9 @@ def __init__(self, expr, V, bcs=None, **kwargs):
17481695 expr = self .ufl_interpolate
17491696 self .arguments = expr .arguments ()
17501697 rank = len (self .arguments )
1751- if rank < 2 :
1698+
1699+ needs_action = len ([a for a in self .arguments if isinstance (a , Coargument )]) == 0
1700+ if needs_action :
17521701 dual_arg , operand = expr .argument_slots ()
17531702 # Split the dual argument
17541703 dual_split = dict (firedrake .formmanipulation .split_form (dual_arg ))
@@ -1768,7 +1717,7 @@ def __init__(self, expr, V, bcs=None, **kwargs):
17681717 sub_bcs = [bc for bc in bcs if bc .function_space () in {Vsource , Vtarget }]
17691718 else :
17701719 sub_bcs = None
1771- if rank < 2 :
1720+ if needs_action :
17721721 # Take the action of each sub-cofunction against each block
17731722 form = action (form , dual_split [indices [1 :]])
17741723
@@ -1805,8 +1754,9 @@ def _interpolate(self, *function, output=None, adjoint=False, **kwargs):
18051754
18061755 if rank == 1 :
18071756 for k , sub_tensor in enumerate (output .subfunctions ):
1808- sub_tensor .assign (sum (self [i , j ].assemble (** kwargs ) for ( i , j ) in self if i == k ))
1757+ sub_tensor .assign (sum (self [i ].assemble (** kwargs ) for i in self if i [ 0 ] == k ))
18091758 elif rank == 2 :
18101759 for k , sub_tensor in enumerate (output .subfunctions ):
1811- sub_tensor .assign (sum (self [i , j ]._interpolate (* function , adjoint = adjoint , ** kwargs ) for (i , j ) in self if i == k ))
1760+ sub_tensor .assign (sum (self [i ]._interpolate (* function , adjoint = adjoint , ** kwargs )
1761+ for i in self if i [0 ] == k ))
18121762 return output
0 commit comments