Commit 6f1737c

Ensemble methods dispatch to MPI methods for non-Firedrake types (#4580)
Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
1 parent 80f9a50 commit 6f1737c

File tree: 3 files changed, +171 −4 lines changed
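In effect, every collective Ensemble method (allreduce, bcast, send, recv, sendrecv, and their non-blocking variants) now falls back to the matching mpi4py call on ensemble_comm whenever its arguments are not Firedrake Function or Cofunction objects. A minimal usage sketch (the two-member ensemble setup here is illustrative, not part of this commit):

    from firedrake import *

    # Split COMM_WORLD into ensemble members of 1 process each.
    ensemble = Ensemble(COMM_WORLD, 1)

    # Firedrake types still take the specialised per-dat path:
    mesh = UnitSquareMesh(2, 2, comm=ensemble.comm)
    V = FunctionSpace(mesh, "CG", 1)
    f = Function(V).assign(ensemble.ensemble_rank)
    f_sum = ensemble.allreduce(f, Function(V))

    # Plain Python data now dispatches to mpi4py's Comm.allreduce:
    rank_sum = ensemble.allreduce(ensemble.ensemble_rank)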

firedrake/ensemble/ensemble.py

Lines changed: 37 additions & 0 deletions
@@ -1,12 +1,35 @@
+from functools import wraps
 import weakref
 from itertools import zip_longest

 from firedrake.petsc import PETSc
+from firedrake.function import Function
+from firedrake.cofunction import Cofunction
 from pyop2.mpi import MPI, internal_comm

 __all__ = ("Ensemble", )


+def _ensemble_mpi_dispatch(func):
+    """
+    This wrapper checks if any arg or kwarg of the wrapped
+    ensemble method is a Function or Cofunction, and if so
+    it calls the specialised Firedrake implementation.
+    Otherwise the standard mpi4py implementation is called.
+    """
+    @wraps(func)
+    def _mpi_dispatch(self, *args, **kwargs):
+        if any(isinstance(arg, (Function, Cofunction))
+               for arg in [*args, *kwargs.values()]):
+            return func(self, *args, **kwargs)
+        else:
+            mpicall = getattr(
+                self.ensemble_comm,
+                func.__name__)
+            return mpicall(*args, **kwargs)
+    return _mpi_dispatch
+
+
 class Ensemble(object):
     def __init__(self, comm, M, **kwargs):
         """
@@ -79,6 +102,7 @@ def _check_function(self, f, g=None):
             raise ValueError("Mismatching function spaces for functions")

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def allreduce(self, f, f_reduced, op=MPI.SUM):
         """
         Allreduce a function f into f_reduced over ``ensemble_comm`` .
@@ -96,6 +120,7 @@ def allreduce(self, f, f_reduced, op=MPI.SUM):
         return f_reduced

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def iallreduce(self, f, f_reduced, op=MPI.SUM):
         """
         Allreduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` .
@@ -113,6 +138,7 @@ def iallreduce(self, f, f_reduced, op=MPI.SUM):
                 for fdat, rdat in zip(f.dat, f_reduced.dat)]

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def reduce(self, f, f_reduced, op=MPI.SUM, root=0):
         """
         Reduce a function f into f_reduced over ``ensemble_comm`` to rank root
@@ -136,6 +162,7 @@ def reduce(self, f, f_reduced, op=MPI.SUM, root=0):
         return f_reduced

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def ireduce(self, f, f_reduced, op=MPI.SUM, root=0):
         """
         Reduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` to rank root
@@ -154,6 +181,7 @@ def ireduce(self, f, f_reduced, op=MPI.SUM, root=0):
                 for fdat, rdat in zip(f.dat, f_reduced.dat)]

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def bcast(self, f, root=0):
         """
         Broadcast a function f over ``ensemble_comm`` from rank root
@@ -169,6 +197,7 @@ def bcast(self, f, root=0):
         return f

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def ibcast(self, f, root=0):
         """
         Broadcast (non-blocking) a function f over ``ensemble_comm`` from rank root
@@ -184,6 +213,7 @@ def ibcast(self, f, root=0):
                 for dat in f.dat]

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def send(self, f, dest, tag=0):
         """
         Send (blocking) a function f over ``ensemble_comm`` to another
@@ -199,6 +229,7 @@ def send(self, f, dest, tag=0):
             self._ensemble_comm.Send(dat.data_ro, dest=dest, tag=tag)

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def recv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, statuses=None):
         """
         Receive (blocking) a function f over ``ensemble_comm`` from
@@ -215,8 +246,10 @@ def recv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, statuses=None):
             raise ValueError("Need to provide enough status objects for all parts of the Function")
         for dat, status in zip_longest(f.dat, statuses or (), fillvalue=None):
             self._ensemble_comm.Recv(dat.data, source=source, tag=tag, status=status)
+        return f

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def isend(self, f, dest, tag=0):
         """
         Send (non-blocking) a function f over ``ensemble_comm`` to another
@@ -233,6 +266,7 @@ def isend(self, f, dest, tag=0):
                 for dat in f.dat]

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def irecv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG):
         """
         Receive (non-blocking) a function f over ``ensemble_comm`` from
@@ -249,6 +283,7 @@ def irecv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG):
                 for dat in f.dat]

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def sendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, recvtag=MPI.ANY_TAG, status=None):
         """
         Send (blocking) a function fsend and receive a function frecv over ``ensemble_comm`` to another
@@ -270,8 +305,10 @@ def sendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, re
             self._ensemble_comm.Sendrecv(sendvec, dest, sendtag=sendtag,
                                          recvbuf=recvvec, source=source, recvtag=recvtag,
                                          status=status)
+        return frecv

     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, recvtag=MPI.ANY_TAG):
         """
         Send a function fsend and receive a function frecv over ``ensemble_comm`` to another
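
Besides the decorators, note the two small behavioural additions: recv now returns f and sendrecv returns frecv, so the Firedrake path follows the same return convention as the mpi4py calls it can dispatch to. A hedged sketch of the two paths through recv (the ensemble setup and the peer rank `other` are assumed placeholders, not part of the diff):

    # Firedrake path: f is filled in place and now also returned.
    f = ensemble.recv(f, source=other, tag=0)

    # mpi4py path: no buffer argument needed; the dispatched
    # Comm.recv returns the received Python object directly.
    data = ensemble.recv(source=other, tag=0)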

tests/firedrake/ensemble/test_ensemble.py

Lines changed: 1 addition & 4 deletions
@@ -3,9 +3,6 @@
 import pytest
 from pytest_mpi.parallel_assert import parallel_assert

-from operator import mul
-from functools import reduce
-

 max_ncpts = 2

@@ -60,7 +57,7 @@ def W(request, mesh):
     if COMM_WORLD.size == 1:
         return
     V = FunctionSpace(mesh, "CG", 1)
-    return reduce(mul, [V for _ in range(request.param)])
+    return MixedFunctionSpace([V for _ in range(request.param)])


 # initialise unique function on each rank
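
For context on the fixture change: multiplying Firedrake function spaces with `*` (which is what `reduce(mul, ...)` spelled out) also builds a mixed space, so the replacement simply states the intent directly. A minimal sketch, assuming any mesh:

    from firedrake import *

    mesh = UnitSquareMesh(2, 2)
    V = FunctionSpace(mesh, "CG", 1)

    # These two spellings construct the same two-component mixed space:
    W1 = V * V
    W2 = MixedFunctionSpace([V, V])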
Lines changed: 133 additions & 0 deletions (new file)
@@ -0,0 +1,133 @@
+from firedrake import *
+import pytest
+from pytest_mpi.parallel_assert import parallel_assert
+
+
+min_root = 0
+max_root = 1
+
+roots = []
+roots.extend([pytest.param(None, id="root_none")])
+roots.extend([pytest.param(i, id=f"root_{i}")
+              for i in range(min_root, max_root + 1)])
+
+blocking = [
+    pytest.param(True, id="blocking"),
+    pytest.param(False, id="nonblocking")
+]
+
+sendrecv_pairs = [
+    pytest.param((0, 1), id="ranks01"),
+    pytest.param((1, 2), id="ranks12"),
+    pytest.param((2, 0), id="ranks20")
+]
+
+
+@pytest.fixture(scope="module")
+def ensemble():
+    if COMM_WORLD.size == 1:
+        return
+    return Ensemble(COMM_WORLD, 1)
+
+
+@pytest.mark.parallel(nprocs=2)
+def test_ensemble_allreduce(ensemble):
+    rank = ensemble.ensemble_rank
+    result = ensemble.allreduce(rank+1)
+    expected = sum([r+1 for r in range(ensemble.ensemble_size)])
+    parallel_assert(
+        result == expected,
+        msg=f"{result=} does not match {expected=}")
+
+
+@pytest.mark.parallel(nprocs=2)
+@pytest.mark.parametrize("root", roots)
+def test_ensemble_reduce(ensemble, root):
+    rank = ensemble.ensemble_rank
+
+    # check default root=0 works
+    if root is None:
+        result = ensemble.reduce(rank+1)
+        root = 0
+    else:
+        result = ensemble.reduce(rank+1, root=root)
+
+    expected = sum([r+1 for r in range(ensemble.ensemble_size)])
+
+    parallel_assert(
+        result == expected,
+        participating=(rank == root),
+        msg=f"{result=} does not match {expected=} on rank {root=}"
+    )
+    parallel_assert(
+        result is None,
+        participating=(rank != root),
+        msg=f"Unexpected {result=} on non-root rank"
+    )
+
+
+@pytest.mark.parallel(nprocs=2)
+@pytest.mark.parametrize("root", roots)
+def test_ensemble_bcast(ensemble, root):
+    rank = ensemble.ensemble_rank
+
+    # check default root=0 works
+    if root is None:
+        result = ensemble.bcast(rank+1)
+        root = 0
+    else:
+        result = ensemble.bcast(rank+1, root=root)
+
+    expected = root + 1
+
+    parallel_assert(result == expected)
+
+
+@pytest.mark.parallel(nprocs=3)
+@pytest.mark.parametrize("ranks", sendrecv_pairs)
+def test_send_and_recv(ensemble, ranks):
+    rank = ensemble.ensemble_rank
+
+    rank0, rank1 = ranks
+
+    send_data = rank + 1
+
+    if rank == rank0:
+        recv_expected = rank1 + 1
+
+        ensemble.send(send_data, dest=rank1, tag=rank0)
+        recv_data = ensemble.recv(source=rank1, tag=rank1)
+
+    elif rank == rank1:
+        recv_expected = rank0 + 1
+
+        recv_data = ensemble.recv(source=rank0, tag=rank0)
+        ensemble.send(send_data, dest=rank0, tag=rank1)
+
+    else:
+        recv_expected = None
+        recv_data = None
+
+    # Test send/recv between the pair of ranks under test;
+    # the third rank does not participate in this exchange.
+    parallel_assert(
+        recv_data == recv_expected,
+        participating=rank in (rank0, rank1),
+    )
+
+
+@pytest.mark.parallel(nprocs=3)
+def test_sendrecv(ensemble):
+    rank = ensemble.ensemble_rank
+    size = ensemble.ensemble_size
+    src_rank = (rank - 1) % size
+    dst_rank = (rank + 1) % size
+
+    send_data = rank + 1
+    recv_expected = src_rank + 1
+
+    recv_result = ensemble.sendrecv(
+        send_data, dst_rank, sendtag=rank,
+        source=src_rank, recvtag=src_rank)
+
+    parallel_assert(recv_result == recv_expected)
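
The blocking/nonblocking parametrisation defined at the top of this file suggests the non-blocking variants take the same dispatch path; with non-Firedrake data they should hand back mpi4py request objects. A hedged sketch, reusing the ring pattern from test_sendrecv above:

    # isend/irecv on plain objects fall through to mpi4py's
    # Comm.isend/Comm.irecv, which return Request objects.
    send_req = ensemble.isend(rank + 1, dest=dst_rank, tag=rank)
    recv_req = ensemble.irecv(source=src_rank, tag=src_rank)
    recv_data = recv_req.wait()  # wait() yields the received object
    send_req.wait()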
