Skip to content

Commit fee366a

Browse files
committed
Implement missing functionality in CrossMeshInterpolator
1 parent e400fc9 commit fee366a

File tree

2 files changed

+28
-75
lines changed

2 files changed

+28
-75
lines changed

firedrake/interpolation.py

Lines changed: 21 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,9 @@ class Interpolator(abc.ABC):
262262
"""
263263

264264
def __new__(cls, expr, V, **kwargs):
265+
V_target = V if isinstance(V, ufl.FunctionSpace) else V.function_space()
265266
if not isinstance(expr, ufl.Interpolate):
266-
expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space())
267+
expr = interpolate(expr, V_target)
267268

268269
spaces = [a.function_space() for a in expr.arguments()]
269270
has_mixed_spaces = any(len(space) > 1 for space in spaces)
@@ -280,10 +281,9 @@ def __new__(cls, expr, V, **kwargs):
280281
if target_mesh is source_mesh or submesh_interp_implemented:
281282
return object.__new__(SameMeshInterpolator)
282283
else:
283-
needs_adjoint = not isinstance(expr.arguments()[0], Coargument)
284284
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
285285
return object.__new__(SameMeshInterpolator)
286-
elif has_mixed_spaces and needs_adjoint:
286+
elif has_mixed_spaces or len(V_target) > 1:
287287
return object.__new__(MixedInterpolator)
288288
else:
289289
return object.__new__(CrossMeshInterpolator)
@@ -506,62 +506,24 @@ def __init__(
506506
self.src_mesh = src_mesh
507507
self.dest_mesh = dest_mesh
508508

509-
self.sub_interpolators = []
510-
511509
# Create a VOM at the nodes of V_dest in src_mesh. We don't include halo
512510
# node coordinates because interpolation doesn't usually include halos.
513511
# NOTE: it is very important to set redundant=False, otherwise the
514512
# input ordering VOM will only contain the points on rank 0!
515513
# QUESTION: Should any of the below have annotation turned off?
516514
ufl_scalar_element = V_dest.ufl_element()
517515
if isinstance(ufl_scalar_element, finat.ufl.MixedElement):
518-
if all(
519-
ufl_scalar_element.sub_elements[0] == e
520-
for e in ufl_scalar_element.sub_elements
521-
):
522-
# For a VectorElement or TensorElement the correct
523-
# VectorFunctionSpace equivalent is built from the scalar
524-
# sub-element.
525-
ufl_scalar_element = ufl_scalar_element.sub_elements[0]
526-
if ufl_scalar_element.reference_value_shape != ():
527-
raise NotImplementedError(
528-
"Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
529-
)
530-
else:
531-
# Build and save an interpolator for each sub-element
532-
# separately for MixedFunctionSpaces. NOTE: since we can't have
533-
# expressions for MixedFunctionSpaces we know that the input
534-
# argument ``expr`` must be a Function. V_dest can be a Function
535-
# or a FunctionSpace, and subfunctions works for both.
536-
if self.nargs == 1:
537-
# Arguments don't have a subfunctions property so I have to
538-
# make them myself. NOTE: this will not be correct when we
539-
# start allowing interpolators created from an expression
540-
# with arguments, as opposed to just being the argument.
541-
expr_subfunctions = [
542-
firedrake.TestFunction(V_src_sub_func)
543-
for V_src_sub_func in self.expr.function_space().subspaces
544-
]
545-
elif self.nargs > 1:
546-
raise NotImplementedError(
547-
"Can't yet create an interpolator from an expression with multiple arguments."
548-
)
549-
else:
550-
expr_subfunctions = self.expr.subfunctions
551-
if len(expr_subfunctions) != len(V_dest.subspaces):
552-
raise NotImplementedError(
553-
"Can't interpolate from a non-mixed function space into a mixed function space."
554-
)
555-
for input_sub_func, target_subspace in zip(
556-
expr_subfunctions, V_dest.subspaces
557-
):
558-
self.sub_interpolators.append(
559-
interpolate(
560-
input_sub_func, target_subspace, subset=subset,
561-
access=access, allow_missing_dofs=allow_missing_dofs
562-
)
563-
)
564-
return
516+
if type(ufl_scalar_element) == finat.ufl.MixedElement:
517+
raise NotImplementedError("Need a MixedInterpolator")
518+
519+
# For a VectorElement or TensorElement the correct
520+
# VectorFunctionSpace equivalent is built from the scalar
521+
# sub-element.
522+
ufl_scalar_element = ufl_scalar_element.sub_elements[0]
523+
if ufl_scalar_element.reference_value_shape != ():
524+
raise NotImplementedError(
525+
"Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
526+
)
565527

566528
from firedrake.assemble import assemble
567529
V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element)
@@ -672,21 +634,6 @@ def _interpolate(
672634
else:
673635
output = firedrake.Function(V_dest)
674636

675-
if len(self.sub_interpolators):
676-
# MixedFunctionSpace case
677-
for sub_interpolate, f_src_sub_func, output_sub_func in zip(
678-
self.sub_interpolators, f_src.subfunctions, output.subfunctions
679-
):
680-
if f_src is self.expr:
681-
# f_src is already contained in self.point_eval_interpolate,
682-
# so the sub_interpolators are already prepared to interpolate
683-
# without needing to be given a Function
684-
assert not self.nargs
685-
assemble(sub_interpolate, tensor=output_sub_func)
686-
else:
687-
assemble(action(sub_interpolate, f_src_sub_func), tensor=output_sub_func)
688-
return output
689-
690637
if not adjoint:
691638
if f_src is self.expr:
692639
# f_src is already contained in self.point_eval_interpolate
@@ -1748,7 +1695,9 @@ def __init__(self, expr, V, bcs=None, **kwargs):
17481695
expr = self.ufl_interpolate
17491696
self.arguments = expr.arguments()
17501697
rank = len(self.arguments)
1751-
if rank < 2:
1698+
1699+
needs_action = len([a for a in self.arguments if isinstance(a, Coargument)]) == 0
1700+
if needs_action:
17521701
dual_arg, operand = expr.argument_slots()
17531702
# Split the dual argument
17541703
dual_split = dict(firedrake.formmanipulation.split_form(dual_arg))
@@ -1768,7 +1717,7 @@ def __init__(self, expr, V, bcs=None, **kwargs):
17681717
sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}]
17691718
else:
17701719
sub_bcs = None
1771-
if rank < 2:
1720+
if needs_action:
17721721
# Take the action of each sub-cofunction against each block
17731722
form = action(form, dual_split[indices[1:]])
17741723

@@ -1805,8 +1754,9 @@ def _interpolate(self, *function, output=None, adjoint=False, **kwargs):
18051754

18061755
if rank == 1:
18071756
for k, sub_tensor in enumerate(output.subfunctions):
1808-
sub_tensor.assign(sum(self[i, j].assemble(**kwargs) for (i, j) in self if i == k))
1757+
sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k))
18091758
elif rank == 2:
18101759
for k, sub_tensor in enumerate(output.subfunctions):
1811-
sub_tensor.assign(sum(self[i, j]._interpolate(*function, adjoint=adjoint, **kwargs) for (i, j) in self if i == k))
1760+
sub_tensor.assign(sum(self[i]._interpolate(*function, adjoint=adjoint, **kwargs)
1761+
for i in self if i[0] == k))
18121762
return output

tests/firedrake/regression/test_interpolate_cross_mesh.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,15 @@ def test_interpolate_unitsquare_mixed():
299299
assert not np.allclose(f_src.dat.data_ro[0], cofunc_src.dat.data_ro[0])
300300
assert not np.allclose(f_src.dat.data_ro[1], cofunc_src.dat.data_ro[1])
301301

302-
# Can't go from non-mixed to mixed
302+
# Interpolate from non-mixed to mixed
303303
V_src_2 = VectorFunctionSpace(m_src, "CG", 1)
304304
assert V_src_2.value_shape == V_src.value_shape
305-
f_src_2 = Function(V_src_2)
306-
with pytest.raises(NotImplementedError):
307-
assemble(interpolate(f_src_2, V_dest))
305+
f_src_2 = Function(V_src_2).interpolate(SpatialCoordinate(m_src))
306+
result_mixed = assemble(interpolate(f_src_2, V_dest))
307+
308+
for i in range(len(V_dest)):
309+
expected = assemble(interpolate(f_src_2[i], V_dest[i]))
310+
assert np.allclose(result_mixed.dat.data_ro[i], expected.dat.data_ro)
308311

309312

310313
@pytest.mark.parallel([1, 3])

0 commit comments

Comments
 (0)