@@ -262,30 +262,29 @@ class Interpolator(abc.ABC):
262262 """
263263
264264 def __new__ (cls , expr , V , ** kwargs ):
265- if isinstance (expr , ufl .Interpolate ):
266- # MixedFunctionSpace is only implemented for the primal 1-form.
267- # Are we a 2-form or a dual 1-form?
268- arguments = expr .arguments ()
269- if any (not isinstance (a , Coargument ) for a in arguments ):
270- # Do we have mixed source or target spaces?
271- spaces = [a .function_space () for a in arguments ]
272- if len (spaces ) < 2 :
273- spaces .append (V )
274- if any (len (space ) > 1 for space in spaces ):
275- return object .__new__ (MixedInterpolator )
276- expr , = expr .ufl_operands
265+ if not isinstance (expr , ufl .Interpolate ):
266+ expr = interpolate (expr , V if isinstance (V , ufl .FunctionSpace ) else V .function_space ())
267+
268+ spaces = [a .function_space () for a in expr .arguments ()]
269+ has_mixed_spaces = any (len (space ) > 1 for space in spaces )
270+ if len (spaces ) == 2 and has_mixed_spaces :
271+ return object .__new__ (MixedInterpolator )
277272
273+ operand , = expr .ufl_operands
278274 target_mesh = as_domain (V )
279- source_mesh = extract_unique_domain (expr ) or target_mesh
275+ source_mesh = extract_unique_domain (operand ) or target_mesh
280276 submesh_interp_implemented = \
281277 all (isinstance (m .topology , firedrake .mesh .MeshTopology ) for m in [target_mesh , source_mesh ]) and \
282278 target_mesh .submesh_ancesters [- 1 ] is source_mesh .submesh_ancesters [- 1 ] and \
283279 target_mesh .topological_dimension () == source_mesh .topological_dimension ()
284280 if target_mesh is source_mesh or submesh_interp_implemented :
285281 return object .__new__ (SameMeshInterpolator )
286282 else :
283+ needs_adjoint = not isinstance (expr .arguments ()[0 ], Coargument )
287284 if isinstance (target_mesh .topology , VertexOnlyMeshTopology ):
288285 return object .__new__ (SameMeshInterpolator )
286+ elif has_mixed_spaces and needs_adjoint :
287+ return object .__new__ (MixedInterpolator )
289288 else :
290289 return object .__new__ (CrossMeshInterpolator )
291290
@@ -301,8 +300,7 @@ def __init__(
301300 matfree : bool = True
302301 ):
303302 if not isinstance (expr , ufl .Interpolate ):
304- fs = V if isinstance (V , ufl .FunctionSpace ) else V .function_space ()
305- expr = interpolate (expr , fs )
303+ expr = interpolate (expr , V if isinstance (V , ufl .FunctionSpace ) else V .function_space ())
306304 dual_arg , operand = expr .argument_slots ()
307305 self .ufl_interpolate = expr
308306 self .expr = operand
@@ -414,8 +412,7 @@ def assemble(self, tensor=None, **kwargs):
414412 Iu = self ._interpolate (** kwargs )
415413 return assemble (ufl .Action (* cofunctions , Iu ), tensor = tensor )
416414 else :
417- return self ._interpolate (* cofunctions , output = tensor , adjoint = needs_adjoint ,
418- ** kwargs )
415+ return self ._interpolate (* cofunctions , output = tensor , adjoint = needs_adjoint , ** kwargs )
419416
420417
421418class DofNotDefinedError (Exception ):
@@ -989,21 +986,33 @@ def callable():
989986 if access == op2 .INC :
990987 loops .append (tensor .zero )
991988
992- if rank == 0 and len (V ) > 1 :
993- dual_arg , operand = expr .argument_slots ()
989+ dual_arg , operand = expr .argument_slots ()
990+ # Any arguments in the operand may be from a MixedFunctoinSpace
991+ # We need to split the target space V and generate separate kernels
992+ if len (V ) == 1 :
993+ expressions = {(0 ,): expr }
994+ elif isinstance (dual_arg , Coargument ):
995+ # Split in the coargument
996+ expressions = dict (firedrake .formmanipulation .split_form (expr ))
997+ else :
998+ # Split in the cofunction: split_form can only split in the coargument
999+ # Replace the cofunction with a coargument to construct the Jacobian
9941000 interp = expr ._ufl_expr_reconstruct_ (operand , V )
1001+ # Split the Jacobian into blocks
9951002 interp_split = dict (firedrake .formmanipulation .split_form (interp ))
1003+ # Split the cofunction
9961004 dual_split = dict (firedrake .formmanipulation .split_form (dual_arg ))
997- expressions = {i : action (interp_split [i ], dual_split [i ]) for i in dual_split }
998- elif len (V ) > 1 :
999- expressions = dict (firedrake .formmanipulation .split_form (expr ))
1000- else :
1001- expressions = {(0 ,): expr }
1005+ # Combine the splits by taking their action
1006+ expressions = {i : action (interp_split [i ], dual_split [i [- 1 :]]) for i in interp_split }
10021007
10031008 # Interpolate each sub expression into each function space
1004- for (i ,), sub_expr in expressions .items ():
1005- sub_tensor = tensor [i ] if (rank == 1 and len (V ) > 1 ) else tensor
1006- loops .extend (_interpolator (V [i ], sub_tensor , sub_expr , subset , arguments , access , bcs = bcs ))
1009+ for indices , sub_expr in expressions .items ():
1010+ if isinstance (sub_expr , ufl .ZeroBaseForm ):
1011+ continue
1012+ arguments = sub_expr .arguments ()
1013+ sub_space = sub_expr .argument_slots ()[0 ].function_space ().dual ()
1014+ sub_tensor = tensor [indices [0 ]] if rank == 1 else tensor
1015+ loops .extend (_interpolator (sub_space , sub_tensor , sub_expr , subset , arguments , access , bcs = bcs ))
10071016
10081017 if bcs and rank == 1 :
10091018 loops .extend (partial (bc .apply , f ) for bc in bcs )
0 commit comments