Skip to content

Commit e400fc9

Browse files
committed
cleanup
1 parent 997e638 commit e400fc9

File tree

1 file changed

+36
-27
lines changed

1 file changed

+36
-27
lines changed

firedrake/interpolation.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -262,30 +262,29 @@ class Interpolator(abc.ABC):
262262
"""
263263

264264
def __new__(cls, expr, V, **kwargs):
265-
if isinstance(expr, ufl.Interpolate):
266-
# MixedFunctionSpace is only implemented for the primal 1-form.
267-
# Are we a 2-form or a dual 1-form?
268-
arguments = expr.arguments()
269-
if any(not isinstance(a, Coargument) for a in arguments):
270-
# Do we have mixed source or target spaces?
271-
spaces = [a.function_space() for a in arguments]
272-
if len(spaces) < 2:
273-
spaces.append(V)
274-
if any(len(space) > 1 for space in spaces):
275-
return object.__new__(MixedInterpolator)
276-
expr, = expr.ufl_operands
265+
if not isinstance(expr, ufl.Interpolate):
266+
expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space())
267+
268+
spaces = [a.function_space() for a in expr.arguments()]
269+
has_mixed_spaces = any(len(space) > 1 for space in spaces)
270+
if len(spaces) == 2 and has_mixed_spaces:
271+
return object.__new__(MixedInterpolator)
277272

273+
operand, = expr.ufl_operands
278274
target_mesh = as_domain(V)
279-
source_mesh = extract_unique_domain(expr) or target_mesh
275+
source_mesh = extract_unique_domain(operand) or target_mesh
280276
submesh_interp_implemented = \
281277
all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) and \
282278
target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \
283279
target_mesh.topological_dimension() == source_mesh.topological_dimension()
284280
if target_mesh is source_mesh or submesh_interp_implemented:
285281
return object.__new__(SameMeshInterpolator)
286282
else:
283+
needs_adjoint = not isinstance(expr.arguments()[0], Coargument)
287284
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
288285
return object.__new__(SameMeshInterpolator)
286+
elif has_mixed_spaces and needs_adjoint:
287+
return object.__new__(MixedInterpolator)
289288
else:
290289
return object.__new__(CrossMeshInterpolator)
291290

@@ -301,8 +300,7 @@ def __init__(
301300
matfree: bool = True
302301
):
303302
if not isinstance(expr, ufl.Interpolate):
304-
fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space()
305-
expr = interpolate(expr, fs)
303+
expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space())
306304
dual_arg, operand = expr.argument_slots()
307305
self.ufl_interpolate = expr
308306
self.expr = operand
@@ -414,8 +412,7 @@ def assemble(self, tensor=None, **kwargs):
414412
Iu = self._interpolate(**kwargs)
415413
return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor)
416414
else:
417-
return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint,
418-
**kwargs)
415+
return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, **kwargs)
419416

420417

421418
class DofNotDefinedError(Exception):
@@ -989,21 +986,33 @@ def callable():
989986
if access == op2.INC:
990987
loops.append(tensor.zero)
991988

992-
if rank == 0 and len(V) > 1:
993-
dual_arg, operand = expr.argument_slots()
989+
dual_arg, operand = expr.argument_slots()
990+
# Any arguments in the operand may be from a MixedFunctoinSpace
991+
# We need to split the target space V and generate separate kernels
992+
if len(V) == 1:
993+
expressions = {(0,): expr}
994+
elif isinstance(dual_arg, Coargument):
995+
# Split in the coargument
996+
expressions = dict(firedrake.formmanipulation.split_form(expr))
997+
else:
998+
# Split in the cofunction: split_form can only split in the coargument
999+
# Replace the cofunction with a coargument to construct the Jacobian
9941000
interp = expr._ufl_expr_reconstruct_(operand, V)
1001+
# Split the Jacobian into blocks
9951002
interp_split = dict(firedrake.formmanipulation.split_form(interp))
1003+
# Split the cofunction
9961004
dual_split = dict(firedrake.formmanipulation.split_form(dual_arg))
997-
expressions = {i: action(interp_split[i], dual_split[i]) for i in dual_split}
998-
elif len(V) > 1:
999-
expressions = dict(firedrake.formmanipulation.split_form(expr))
1000-
else:
1001-
expressions = {(0,): expr}
1005+
# Combine the splits by taking their action
1006+
expressions = {i: action(interp_split[i], dual_split[i[-1:]]) for i in interp_split}
10021007

10031008
# Interpolate each sub expression into each function space
1004-
for (i,), sub_expr in expressions.items():
1005-
sub_tensor = tensor[i] if (rank == 1 and len(V) > 1) else tensor
1006-
loops.extend(_interpolator(V[i], sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
1009+
for indices, sub_expr in expressions.items():
1010+
if isinstance(sub_expr, ufl.ZeroBaseForm):
1011+
continue
1012+
arguments = sub_expr.arguments()
1013+
sub_space = sub_expr.argument_slots()[0].function_space().dual()
1014+
sub_tensor = tensor[indices[0]] if rank == 1 else tensor
1015+
loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
10071016

10081017
if bcs and rank == 1:
10091018
loops.extend(partial(bc.apply, f) for bc in bcs)

0 commit comments

Comments
 (0)