Skip to content

Commit 9a703c3

Browse files
authored
Stop renumbering arguments in the expression to interpolate (#4582)
1 parent 2140635 commit 9a703c3

File tree

22 files changed

+85
-112
lines changed

22 files changed

+85
-112
lines changed

demos/boussinesq/boussinesq.py.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ implements a boundary condition that fixes a field at a single point. ::
184184

185185
# Take the basis function with the largest abs value at bc_point
186186
v = TestFunction(V)
187-
F = assemble(Interpolate(inner(v, v), Fvom))
187+
F = assemble(interpolate(inner(v, v), Fvom))
188188
with F.dat.vec as Fvec:
189189
max_index, _ = Fvec.max()
190190
nodes = V.dof_dset.lgmap.applyInverse([max_index])

demos/multicomponent/multicomponent.py.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ mathematically valid to do this)::
523523

524524
# Take the basis function with the largest abs value at bc_point
525525
v = TestFunction(V)
526-
F = assemble(Interpolate(inner(v, v), Fvom))
526+
F = assemble(interpolate(inner(v, v), Fvom))
527527
with F.dat.vec as Fvec:
528528
max_index, _ = Fvec.max()
529529
nodes = V.dof_dset.lgmap.applyInverse([max_index])

firedrake/cofunction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def interpolate(self,
340340
Parameters
341341
----------
342342
expression
343-
A dual UFL expression to interpolate.
343+
A UFL BaseForm to adjoint interpolate.
344344
ad_block_tag
345345
An optional string for tagging the resulting assemble
346346
block on the Pyadjoint tape.
@@ -353,9 +353,9 @@ def interpolate(self,
353353
firedrake.cofunction.Cofunction
354354
Returns `self`
355355
"""
356-
from firedrake import interpolation, assemble
356+
from firedrake import interpolate, assemble
357357
v, = self.arguments()
358-
interp = interpolation.Interpolate(v, expression, **kwargs)
358+
interp = interpolate(v, expression, **kwargs)
359359
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)
360360

361361
@property

firedrake/external_operators/point_expr_operator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import firedrake.ufl_expr as ufl_expr
77
from firedrake.assemble import assemble
8-
from firedrake.interpolation import Interpolate
8+
from firedrake.interpolation import interpolate
99
from firedrake.external_operators import AbstractExternalOperator, assemble_method
1010

1111

@@ -58,7 +58,7 @@ def assemble_operator(self, *args, **kwargs):
5858
V = self.function_space()
5959
expr = as_ufl(self.expr(*self.ufl_operands))
6060
if len(V) < 2:
61-
interp = Interpolate(expr, self.function_space())
61+
interp = interpolate(expr, self.function_space())
6262
return assemble(interp)
6363
# Interpolation of UFL expressions for mixed functions is not yet supported
6464
# -> `Function.assign` might be enough in some cases.
@@ -72,7 +72,7 @@ def assemble_operator(self, *args, **kwargs):
7272
def assemble_Jacobian_action(self, *args, **kwargs):
7373
V = self.function_space()
7474
expr = as_ufl(self.expr(*self.ufl_operands))
75-
interp = Interpolate(expr, V)
75+
interp = interpolate(expr, V)
7676

7777
u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
7878
w = self.argument_slots()[-1]
@@ -83,7 +83,7 @@ def assemble_Jacobian_action(self, *args, **kwargs):
8383
def assemble_Jacobian(self, *args, assembly_opts, **kwargs):
8484
V = self.function_space()
8585
expr = as_ufl(self.expr(*self.ufl_operands))
86-
interp = Interpolate(expr, V)
86+
interp = interpolate(expr, V)
8787

8888
u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
8989
jac = ufl_expr.derivative(interp, u)
@@ -99,7 +99,7 @@ def assemble_Jacobian_adjoint(self, *args, assembly_opts, **kwargs):
9999
def assemble_Jacobian_adjoint_action(self, *args, **kwargs):
100100
V = self.function_space()
101101
expr = as_ufl(self.expr(*self.ufl_operands))
102-
interp = Interpolate(expr, V)
102+
interp = interpolate(expr, V)
103103

104104
u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
105105
ustar = self.argument_slots()[0]

firedrake/function.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,9 @@ def interpolate(self,
382382
firedrake.function.Function
383383
Returns `self`
384384
"""
385-
from firedrake import interpolation, assemble
385+
from firedrake import interpolate, assemble
386386
V = self.function_space()
387-
interp = interpolation.Interpolate(expression, V, **kwargs)
387+
interp = interpolate(expression, V, **kwargs)
388388
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)
389389

390390
def zero(self, subset=None):
@@ -715,7 +715,7 @@ def __init__(self, domain, point):
715715
self.point = point
716716

717717
def __str__(self):
718-
return "domain %s does not contain point %s" % (self.domain, self.point)
718+
return f"Domain {self.domain} does not contain point {self.point}"
719719

720720

721721
class PointEvaluator:

firedrake/interpolation.py

Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
import abc
55
import warnings
66
from collections.abc import Iterable
7-
from typing import Literal
87
from functools import partial, singledispatch
9-
from typing import Hashable
8+
from typing import Hashable, Literal
109

1110
import FIAT
1211
import ufl
1312
import finat.ufl
14-
from ufl.algorithms import extract_arguments, extract_coefficients, replace
13+
from ufl.algorithms import extract_arguments, extract_coefficients
1514
from ufl.domain import as_domain, extract_unique_domain
1615

1716
from pyop2 import op2
@@ -25,13 +24,11 @@
2524
import finat
2625

2726
import firedrake
28-
import firedrake.bcs
2927
from firedrake import tsfc_interface, utils, functionspaceimpl
3028
from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint
3129
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology
3230
from firedrake.petsc import PETSc
3331
from firedrake.halo import _get_mtype as get_dat_mpi_type
34-
from firedrake.cofunction import Cofunction
3532
from mpi4py import MPI
3633

3734
from pyadjoint import stop_annotating, no_annotations
@@ -48,7 +45,7 @@
4845

4946
class Interpolate(ufl.Interpolate):
5047

51-
def __init__(self, expr, v,
48+
def __init__(self, expr, V,
5249
subset=None,
5350
access=None,
5451
allow_missing_dofs=False,
@@ -60,7 +57,7 @@ def __init__(self, expr, v,
6057
----------
6158
expr : ufl.core.expr.Expr or ufl.BaseForm
6259
The UFL expression to interpolate.
63-
v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
60+
V : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
6461
The function space to interpolate into or the coargument defined
6562
on the dual of the function space to interpolate into.
6663
subset : pyop2.types.set.Subset
@@ -95,20 +92,18 @@ def __init__(self, expr, v,
9592
between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
9693
and reduce operations.
9794
"""
98-
# Check function space
9995
expr = ufl.as_ufl(expr)
100-
if isinstance(v, functionspaceimpl.WithGeometry):
101-
expr_args = extract_arguments(expr)
102-
is_adjoint = len(expr_args) and expr_args[0].number() == 0
103-
v = Argument(v.dual(), 1 if is_adjoint else 0)
96+
if isinstance(V, functionspaceimpl.WithGeometry):
97+
expr_args = expr.arguments()[1:] if isinstance(expr, ufl.BaseForm) else extract_arguments(expr)
98+
expr_arg_numbers = {arg.number() for arg in expr_args}
99+
# Need to create a Firedrake Argument so that it has a .function_space() method
100+
V = Argument(V.dual(), 1 if expr_arg_numbers == {0} else 0)
104101

105-
V = v.arguments()[0].function_space()
106-
if len(expr.ufl_shape) != len(V.value_shape):
107-
raise RuntimeError(f'Rank mismatch: Expression rank {len(expr.ufl_shape)}, FunctionSpace rank {len(V.value_shape)}')
102+
target_shape = V.arguments()[0].function_space().value_shape
103+
if expr.ufl_shape != target_shape:
104+
raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {target_shape}.")
108105

109-
if expr.ufl_shape != V.value_shape:
110-
raise RuntimeError('Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}')
111-
super().__init__(expr, v)
106+
super().__init__(expr, V)
112107

113108
# -- Interpolate data (e.g. `subset` or `access`) -- #
114109
self.interp_data = {"subset": subset,
@@ -174,32 +169,10 @@ def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, def
174169
reduction (hence using MIN will compute the MIN between the
175170
existing values and any new values).
176171
"""
177-
if isinstance(V, (Cofunction, Coargument)):
178-
dual_arg = V
179-
elif isinstance(V, ufl.BaseForm):
180-
rank = len(V.arguments())
181-
if rank == 1:
182-
dual_arg = V
183-
else:
184-
raise TypeError(f"Expected a one-form, provided form had {rank} arguments")
185-
elif isinstance(V, functionspaceimpl.WithGeometry):
186-
dual_arg = Coargument(V.dual(), 0)
187-
expr_args = extract_arguments(ufl.as_ufl(expr))
188-
if expr_args and expr_args[0].number() == 0:
189-
warnings.warn("Passing argument numbered 0 in expression for forward interpolation is deprecated. "
190-
"Use a TrialFunction in the expression.")
191-
v, = expr_args
192-
expr = replace(expr, {v: v.reconstruct(number=1)})
193-
else:
194-
raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}")
195-
196-
interp = Interpolate(expr, dual_arg,
197-
subset=subset, access=access,
198-
allow_missing_dofs=allow_missing_dofs,
199-
default_missing_val=default_missing_val,
200-
matfree=matfree)
201-
202-
return interp
172+
return Interpolate(
173+
expr, V, subset=subset, access=access, allow_missing_dofs=allow_missing_dofs,
174+
default_missing_val=default_missing_val, matfree=matfree
175+
)
203176

204177

205178
class Interpolator(abc.ABC):
@@ -528,7 +501,7 @@ def __init__(
528501

529502
from firedrake.assemble import assemble
530503
V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element)
531-
f_dest_node_coords = Interpolate(dest_mesh.coordinates, V_dest_vec)
504+
f_dest_node_coords = interpolate(dest_mesh.coordinates, V_dest_vec)
532505
f_dest_node_coords = assemble(f_dest_node_coords)
533506
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim)
534507
try:
@@ -553,15 +526,15 @@ def __init__(
553526
else:
554527
fs_type = partial(firedrake.TensorFunctionSpace, shape=shape)
555528
P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0)
556-
self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom)
529+
self.point_eval_interpolate = interpolate(self.expr_renumbered, P0DG_vom)
557530
# The parallel decomposition of the nodes of V_dest in the DESTINATION
558531
# mesh (dest_mesh) is retrieved using the input_ordering attribute of the
559532
# VOM. This again is an interpolation operation, which, under the hood
560533
# is a PETSc SF reduce.
561534
P0DG_vom_i_o = fs_type(
562535
self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0
563536
)
564-
self.to_input_ordering_interpolate = Interpolate(
537+
self.to_input_ordering_interpolate = interpolate(
565538
firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o
566539
)
567540
# The P0DG function outputted by the above interpolation has the

firedrake/mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4274,7 +4274,7 @@ def _parent_mesh_embedding(
42744274
# nessesary, to other processes.
42754275
P0DG = functionspace.FunctionSpace(parent_mesh, "DG", 0)
42764276
with stop_annotating():
4277-
visible_ranks = interpolation.Interpolate(
4277+
visible_ranks = interpolation.interpolate(
42784278
constant.Constant(parent_mesh.comm.rank), P0DG
42794279
)
42804280
visible_ranks = assemble(visible_ranks).dat.data_ro_with_halos.real

firedrake/mg/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def physical_node_locations(V):
143143
Vc = V.collapse().reconstruct(element=finat.ufl.VectorElement(element, dim=mesh.geometric_dimension))
144144

145145
# FIXME: This is unsafe for DG coordinates and CG target spaces.
146-
locations = firedrake.assemble(firedrake.Interpolate(firedrake.SpatialCoordinate(mesh), Vc))
146+
locations = firedrake.assemble(firedrake.interpolate(firedrake.SpatialCoordinate(mesh), Vc))
147147
return cache.setdefault(key, locations)
148148

149149

firedrake/preconditioners/gtmg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from firedrake.petsc import PETSc
55
from firedrake.preconditioners.base import PCBase
66
from firedrake.parameters import parameters
7-
from firedrake.interpolation import Interpolate
7+
from firedrake.interpolation import interpolate
88
from firedrake.solving_utils import _SNESContext
99
from firedrake.matrix_free.operators import ImplicitMatrixContext
1010
import firedrake.dmhooks as dmhooks
@@ -155,7 +155,7 @@ def initialize(self, pc):
155155
# Create interpolation matrix from coarse space to fine space
156156
fine_space = ctx.J.arguments()[0].function_space()
157157
coarse_test, coarse_trial = coarse_operator.arguments()
158-
interp = assemble(Interpolate(coarse_trial, fine_space))
158+
interp = assemble(interpolate(coarse_trial, fine_space))
159159
interp_petscmat = interp.petscmat
160160
restr_petscmat = appctx.get("restriction_matrix", None)
161161

firedrake/preconditioners/hypre_ads.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from firedrake.preconditioners.base import PCBase
22
from firedrake.petsc import PETSc
33
from firedrake.function import Function
4-
from firedrake.ufl_expr import TestFunction
4+
from firedrake.ufl_expr import TrialFunction
55
from firedrake.dmhooks import get_function_space
66
from firedrake.preconditioners.hypre_ams import chop
77
from firedrake.interpolation import interpolate
@@ -31,12 +31,12 @@ def initialize(self, obj):
3131
NC1 = V.reconstruct(family="N1curl" if mesh.ufl_cell().is_simplex else "NCE", degree=1)
3232
G_callback = appctx.get("get_gradient", None)
3333
if G_callback is None:
34-
G = chop(assemble(interpolate(grad(TestFunction(P1)), NC1)).petscmat)
34+
G = chop(assemble(interpolate(grad(TrialFunction(P1)), NC1)).petscmat)
3535
else:
3636
G = G_callback(P1, NC1)
3737
C_callback = appctx.get("get_curl", None)
3838
if C_callback is None:
39-
C = chop(assemble(interpolate(curl(TestFunction(NC1)), V)).petscmat)
39+
C = chop(assemble(interpolate(curl(TrialFunction(NC1)), V)).petscmat)
4040
else:
4141
C = C_callback(NC1, V)
4242

0 commit comments

Comments
 (0)