1+ from functools import wraps
12import weakref
23from itertools import zip_longest
34
45from firedrake .petsc import PETSc
6+ from firedrake .function import Function
7+ from firedrake .cofunction import Cofunction
58from pyop2 .mpi import MPI , internal_comm
69
710__all__ = ("Ensemble" , )
811
912
13+ def _ensemble_mpi_dispatch (func ):
14+ """
15+ This wrapper checks if any arg or kwarg of the wrapped
16+ ensemble method is a Function or Cofunction, and if so
17+ it calls the specialised Firedrake implementation.
18+ Otherwise the standard mpi4py implementation is called.
19+ """
20+ @wraps (func )
21+ def _mpi_dispatch (self , * args , ** kwargs ):
22+ if any (isinstance (arg , (Function , Cofunction ))
23+ for arg in [* args , * kwargs .values ()]):
24+ return func (self , * args , ** kwargs )
25+ else :
26+ mpicall = getattr (
27+ self .ensemble_comm ,
28+ func .__name__ )
29+ return mpicall (* args , ** kwargs )
30+ return _mpi_dispatch
31+
32+
1033class Ensemble (object ):
1134 def __init__ (self , comm , M , ** kwargs ):
1235 """
@@ -79,6 +102,7 @@ def _check_function(self, f, g=None):
79102 raise ValueError ("Mismatching function spaces for functions" )
80103
81104 @PETSc .Log .EventDecorator ()
105+ @_ensemble_mpi_dispatch
82106 def allreduce (self , f , f_reduced , op = MPI .SUM ):
83107 """
84108 Allreduce a function f into f_reduced over ``ensemble_comm`` .
@@ -96,6 +120,7 @@ def allreduce(self, f, f_reduced, op=MPI.SUM):
96120 return f_reduced
97121
98122 @PETSc .Log .EventDecorator ()
123+ @_ensemble_mpi_dispatch
99124 def iallreduce (self , f , f_reduced , op = MPI .SUM ):
100125 """
101126 Allreduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` .
@@ -113,6 +138,7 @@ def iallreduce(self, f, f_reduced, op=MPI.SUM):
113138 for fdat , rdat in zip (f .dat , f_reduced .dat )]
114139
115140 @PETSc .Log .EventDecorator ()
141+ @_ensemble_mpi_dispatch
116142 def reduce (self , f , f_reduced , op = MPI .SUM , root = 0 ):
117143 """
118144 Reduce a function f into f_reduced over ``ensemble_comm`` to rank root
@@ -136,6 +162,7 @@ def reduce(self, f, f_reduced, op=MPI.SUM, root=0):
136162 return f_reduced
137163
138164 @PETSc .Log .EventDecorator ()
165+ @_ensemble_mpi_dispatch
139166 def ireduce (self , f , f_reduced , op = MPI .SUM , root = 0 ):
140167 """
141168 Reduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` to rank root
@@ -154,6 +181,7 @@ def ireduce(self, f, f_reduced, op=MPI.SUM, root=0):
154181 for fdat , rdat in zip (f .dat , f_reduced .dat )]
155182
156183 @PETSc .Log .EventDecorator ()
184+ @_ensemble_mpi_dispatch
157185 def bcast (self , f , root = 0 ):
158186 """
159187 Broadcast a function f over ``ensemble_comm`` from rank root
@@ -169,6 +197,7 @@ def bcast(self, f, root=0):
169197 return f
170198
171199 @PETSc .Log .EventDecorator ()
200+ @_ensemble_mpi_dispatch
172201 def ibcast (self , f , root = 0 ):
173202 """
174203 Broadcast (non-blocking) a function f over ``ensemble_comm`` from rank root
@@ -184,6 +213,7 @@ def ibcast(self, f, root=0):
184213 for dat in f .dat ]
185214
186215 @PETSc .Log .EventDecorator ()
216+ @_ensemble_mpi_dispatch
187217 def send (self , f , dest , tag = 0 ):
188218 """
189219 Send (blocking) a function f over ``ensemble_comm`` to another
@@ -199,6 +229,7 @@ def send(self, f, dest, tag=0):
199229 self ._ensemble_comm .Send (dat .data_ro , dest = dest , tag = tag )
200230
201231 @PETSc .Log .EventDecorator ()
232+ @_ensemble_mpi_dispatch
202233 def recv (self , f , source = MPI .ANY_SOURCE , tag = MPI .ANY_TAG , statuses = None ):
203234 """
204235 Receive (blocking) a function f over ``ensemble_comm`` from
@@ -215,8 +246,10 @@ def recv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, statuses=None):
215246 raise ValueError ("Need to provide enough status objects for all parts of the Function" )
216247 for dat , status in zip_longest (f .dat , statuses or (), fillvalue = None ):
217248 self ._ensemble_comm .Recv (dat .data , source = source , tag = tag , status = status )
249+ return f
218250
219251 @PETSc .Log .EventDecorator ()
252+ @_ensemble_mpi_dispatch
220253 def isend (self , f , dest , tag = 0 ):
221254 """
222255 Send (non-blocking) a function f over ``ensemble_comm`` to another
@@ -233,6 +266,7 @@ def isend(self, f, dest, tag=0):
233266 for dat in f .dat ]
234267
235268 @PETSc .Log .EventDecorator ()
269+ @_ensemble_mpi_dispatch
236270 def irecv (self , f , source = MPI .ANY_SOURCE , tag = MPI .ANY_TAG ):
237271 """
238272 Receive (non-blocking) a function f over ``ensemble_comm`` from
@@ -249,6 +283,7 @@ def irecv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG):
249283 for dat in f .dat ]
250284
251285 @PETSc .Log .EventDecorator ()
286+ @_ensemble_mpi_dispatch
252287 def sendrecv (self , fsend , dest , sendtag = 0 , frecv = None , source = MPI .ANY_SOURCE , recvtag = MPI .ANY_TAG , status = None ):
253288 """
254289 Send (blocking) a function fsend and receive a function frecv over ``ensemble_comm`` to another
@@ -270,8 +305,10 @@ def sendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, re
270305 self ._ensemble_comm .Sendrecv (sendvec , dest , sendtag = sendtag ,
271306 recvbuf = recvvec , source = source , recvtag = recvtag ,
272307 status = status )
308+ return frecv
273309
274310 @PETSc .Log .EventDecorator ()
311+ @_ensemble_mpi_dispatch
275312 def isendrecv (self , fsend , dest , sendtag = 0 , frecv = None , source = MPI .ANY_SOURCE , recvtag = MPI .ANY_TAG ):
276313 """
277314 Send a function fsend and receive a function frecv over ``ensemble_comm`` to another
0 commit comments