44import abc
55import warnings
66from collections .abc import Iterable
7- from typing import Literal
87from functools import partial , singledispatch
9- from typing import Hashable
8+ from typing import Hashable , Literal
109
1110import FIAT
1211import ufl
1312import finat .ufl
14- from ufl .algorithms import extract_arguments , extract_coefficients , replace
13+ from ufl .algorithms import extract_arguments , extract_coefficients
1514from ufl .domain import as_domain , extract_unique_domain
1615
1716from pyop2 import op2
2524import finat
2625
2726import firedrake
28- import firedrake .bcs
2927from firedrake import tsfc_interface , utils , functionspaceimpl
3028from firedrake .ufl_expr import Argument , Coargument , action , adjoint as expr_adjoint
3129from firedrake .mesh import MissingPointsBehaviour , VertexOnlyMeshMissingPointsError , VertexOnlyMeshTopology
3230from firedrake .petsc import PETSc
3331from firedrake .halo import _get_mtype as get_dat_mpi_type
34- from firedrake .cofunction import Cofunction
3532from mpi4py import MPI
3633
3734from pyadjoint import stop_annotating , no_annotations
4845
4946class Interpolate (ufl .Interpolate ):
5047
51- def __init__ (self , expr , v ,
48+ def __init__ (self , expr , V ,
5249 subset = None ,
5350 access = None ,
5451 allow_missing_dofs = False ,
@@ -60,7 +57,7 @@ def __init__(self, expr, v,
6057 ----------
6158 expr : ufl.core.expr.Expr or ufl.BaseForm
6259 The UFL expression to interpolate.
63- v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
60+ V : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
6461 The function space to interpolate into or the coargument defined
6562 on the dual of the function space to interpolate into.
6663 subset : pyop2.types.set.Subset
@@ -95,20 +92,18 @@ def __init__(self, expr, v,
9592 between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
9693 and reduce operations.
9794 """
98- # Check function space
9995 expr = ufl .as_ufl (expr )
100- if isinstance (v , functionspaceimpl .WithGeometry ):
101- expr_args = extract_arguments (expr )
102- is_adjoint = len (expr_args ) and expr_args [0 ].number () == 0
103- v = Argument (v .dual (), 1 if is_adjoint else 0 )
96+ if isinstance (V , functionspaceimpl .WithGeometry ):
97+ expr_args = expr .arguments ()[1 :] if isinstance (expr , ufl .BaseForm ) else extract_arguments (expr )
98+ expr_arg_numbers = {arg .number () for arg in expr_args }
99+ # Need to create a Firedrake Argument so that it has a .function_space() method
100+ V = Argument (V .dual (), 1 if expr_arg_numbers == {0 } else 0 )
104101
105- V = v .arguments ()[0 ].function_space ()
106- if len ( expr .ufl_shape ) != len ( V . value_shape ) :
107- raise RuntimeError ( f'Rank mismatch: Expression rank { len ( expr .ufl_shape ) } , FunctionSpace rank { len ( V . value_shape ) } ' )
102+ target_shape = V .arguments ()[0 ].function_space (). value_shape
103+ if expr .ufl_shape != target_shape :
104+ raise ValueError ( f"Shape mismatch: Expression shape { expr .ufl_shape } , FunctionSpace shape { target_shape } ." )
108105
109- if expr .ufl_shape != V .value_shape :
110- raise RuntimeError ('Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}' )
111- super ().__init__ (expr , v )
106+ super ().__init__ (expr , V )
112107
113108 # -- Interpolate data (e.g. `subset` or `access`) -- #
114109 self .interp_data = {"subset" : subset ,
@@ -174,32 +169,10 @@ def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, def
174169 reduction (hence using MIN will compute the MIN between the
175170 existing values and any new values).
176171 """
177- if isinstance (V , (Cofunction , Coargument )):
178- dual_arg = V
179- elif isinstance (V , ufl .BaseForm ):
180- rank = len (V .arguments ())
181- if rank == 1 :
182- dual_arg = V
183- else :
184- raise TypeError (f"Expected a one-form, provided form had { rank } arguments" )
185- elif isinstance (V , functionspaceimpl .WithGeometry ):
186- dual_arg = Coargument (V .dual (), 0 )
187- expr_args = extract_arguments (ufl .as_ufl (expr ))
188- if expr_args and expr_args [0 ].number () == 0 :
189- warnings .warn ("Passing argument numbered 0 in expression for forward interpolation is deprecated. "
190- "Use a TrialFunction in the expression." )
191- v , = expr_args
192- expr = replace (expr , {v : v .reconstruct (number = 1 )})
193- else :
194- raise TypeError (f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a { type (V ).__name__ } " )
195-
196- interp = Interpolate (expr , dual_arg ,
197- subset = subset , access = access ,
198- allow_missing_dofs = allow_missing_dofs ,
199- default_missing_val = default_missing_val ,
200- matfree = matfree )
201-
202- return interp
172+ return Interpolate (
173+ expr , V , subset = subset , access = access , allow_missing_dofs = allow_missing_dofs ,
174+ default_missing_val = default_missing_val , matfree = matfree
175+ )
203176
204177
205178class Interpolator (abc .ABC ):
@@ -528,7 +501,7 @@ def __init__(
528501
529502 from firedrake .assemble import assemble
530503 V_dest_vec = firedrake .VectorFunctionSpace (dest_mesh , ufl_scalar_element )
531- f_dest_node_coords = Interpolate (dest_mesh .coordinates , V_dest_vec )
504+ f_dest_node_coords = interpolate (dest_mesh .coordinates , V_dest_vec )
532505 f_dest_node_coords = assemble (f_dest_node_coords )
533506 dest_node_coords = f_dest_node_coords .dat .data_ro .reshape (- 1 , dest_mesh_gdim )
534507 try :
@@ -553,15 +526,15 @@ def __init__(
553526 else :
554527 fs_type = partial (firedrake .TensorFunctionSpace , shape = shape )
555528 P0DG_vom = fs_type (self .vom_dest_node_coords_in_src_mesh , "DG" , 0 )
556- self .point_eval_interpolate = Interpolate (self .expr_renumbered , P0DG_vom )
529+ self .point_eval_interpolate = interpolate (self .expr_renumbered , P0DG_vom )
557530 # The parallel decomposition of the nodes of V_dest in the DESTINATION
558531 # mesh (dest_mesh) is retrieved using the input_ordering attribute of the
559532 # VOM. This again is an interpolation operation, which, under the hood
560533 # is a PETSc SF reduce.
561534 P0DG_vom_i_o = fs_type (
562535 self .vom_dest_node_coords_in_src_mesh .input_ordering , "DG" , 0
563536 )
564- self .to_input_ordering_interpolate = Interpolate (
537+ self .to_input_ordering_interpolate = interpolate (
565538 firedrake .TrialFunction (P0DG_vom ), P0DG_vom_i_o
566539 )
567540 # The P0DG function outputted by the above interpolation has the
0 commit comments