
Commit

Fix up things in the ensemble tests (#3725)
* Fix parallel assertion

* Apply suggestions from code review

Co-authored-by: Josh Hope-Collins <joshua.hope-collins13@imperial.ac.uk>

---------

Co-authored-by: Josh Hope-Collins <joshua.hope-collins13@imperial.ac.uk>
2 people authored and pbrubeck committed Oct 8, 2024
1 parent 83ee1fe commit 1c32625
Showing 1 changed file with 53 additions and 17 deletions.
70 changes: 53 additions & 17 deletions tests/regression/test_ensembleparallelism.py
@@ -25,6 +25,24 @@
pytest.param(False, id="nonblocking")]


def parallel_assert(assertion, subset=None, msg=""):
    """Evaluate `assertion` on every rank and raise on all ranks if it fails anywhere.

    If `subset` is given, only the ranks in `subset` evaluate the assertion;
    the remaining ranks pass trivially.

    TODO: move this functionality to pytest-mpi.
    """
    if subset:
        if MPI.COMM_WORLD.rank in subset:
            evaluation = assertion()
        else:
            evaluation = True
    else:
        evaluation = assertion()
    all_results = MPI.COMM_WORLD.allgather(evaluation)
    if not all(all_results):
        raise AssertionError(
            "Parallel assertion failed on ranks: "
            f"{[ii for ii, b in enumerate(all_results) if not b]}\n" + msg
        )


# unique profile on each mixed function component on each ensemble rank
def function_profile(x, y, rank, cpt):
return sin(cpt + (rank+1)*pi*x)*cos(cpt + (rank+1)*pi*y)
@@ -102,7 +120,7 @@ def test_ensemble_allreduce(ensemble, mesh, W, urank, urank_sum, blocking):
requests = ensemble.iallreduce(urank, u_reduce)
MPI.Request.Waitall(requests)

assert errornorm(urank_sum, u_reduce) < 1e-12
parallel_assert(lambda: errornorm(urank_sum, u_reduce) < 1e-12)


@pytest.mark.parallel(nprocs=2)
@@ -154,7 +172,7 @@ def test_comm_manager_allreduce(blocking):
f5 = Function(V5)

with f4.dat.vec_ro as v4, f5.dat.vec_ro as v5:
assert v4.getSizes() == v5.getSizes()
parallel_assert(lambda: v4.getSizes() == v5.getSizes())

with pytest.raises(ValueError):
allreduce(f4, f5)
@@ -182,10 +200,19 @@ def test_ensemble_reduce(ensemble, mesh, W, urank, urank_sum, root, blocking):
MPI.Request.Waitall(requests)

# only u_reduce on rank root should be modified
if ensemble.ensemble_comm.rank == root:
assert errornorm(urank_sum, u_reduce) < 1e-12
else:
assert errornorm(Function(W).assign(10), u_reduce) < 1e-12
error = errornorm(urank_sum, u_reduce)
root_ranks = {ii + root*ensemble.comm.size for ii in range(ensemble.comm.size)}
parallel_assert(
lambda: error < 1e-12,
subset=root_ranks,
msg=f"{error = :.5f}"
)
error = errornorm(Function(W).assign(10), u_reduce)
parallel_assert(
lambda: error < 1e-12,
subset=set(range(COMM_WORLD.size)) - root_ranks,
msg=f"{error = :.5f}"
)

# check that u_reduce dat vector is still synchronised
spatial_rank = ensemble.comm.rank
@@ -194,7 +221,9 @@ def test_ensemble_reduce(ensemble, mesh, W, urank, urank_sum, root, blocking):
with u_reduce.dat.vec as v:
states[spatial_rank] = v.stateGet()
ensemble.comm.Allgather(MPI.IN_PLACE, states)
assert len(set(states)) == 1
parallel_assert(
lambda: len(set(states)) == 1,
)


@pytest.mark.parallel(nprocs=2)
@@ -245,7 +274,7 @@ def test_comm_manager_reduce(blocking):
f5 = Function(V5)

with f4.dat.vec_ro as v4, f5.dat.vec_ro as v5:
assert v4.getSizes() == v5.getSizes()
parallel_assert(lambda: v4.getSizes() == v5.getSizes())

with pytest.raises(ValueError):
reduction(f4, f5)
@@ -273,7 +302,7 @@ def test_ensemble_bcast(ensemble, mesh, W, urank, root, blocking):
# broadcasted function
u_correct = unique_function(mesh, root, W)

assert errornorm(u_correct, urank) < 1e-12
parallel_assert(lambda: errornorm(u_correct, urank) < 1e-12)


@pytest.mark.parallel(nprocs=6)
@@ -297,22 +326,29 @@ def test_send_and_recv(ensemble, mesh, W, blocking):
if ensemble_rank == rank0:
send_requests = send(usend, dest=rank1, tag=rank0)
recv_requests = recv(urecv, source=rank1, tag=rank1)

if not blocking:
MPI.Request.waitall(send_requests)
MPI.Request.waitall(recv_requests)

assert errornorm(urecv, usend) < 1e-12

error = errornorm(urecv, usend)
elif ensemble_rank == rank1:
recv_requests = recv(urecv, source=rank0, tag=rank0)
send_requests = send(usend, dest=rank0, tag=rank1)

if not blocking:
MPI.Request.waitall(send_requests)
MPI.Request.waitall(recv_requests)
error = errornorm(urecv, usend)
else:
error = 0

assert errornorm(urecv, usend) < 1e-12
# Test send/recv between first two spatial comms
# ie: ensemble.ensemble_comm.rank == 0 and 1
root_ranks = {ii + rank0*ensemble.comm.size for ii in range(ensemble.comm.size)}
root_ranks |= {ii + rank1*ensemble.comm.size for ii in range(ensemble.comm.size)}
parallel_assert(
lambda: error < 1e-12,
subset=root_ranks,
msg=f"{error = :.5f}"
)


@pytest.mark.parallel(nprocs=6)
@@ -339,7 +375,7 @@ def test_sendrecv(ensemble, mesh, W, urank, blocking):
if not blocking:
MPI.Request.Waitall(requests)

assert errornorm(urecv, u_expect) < 1e-12
parallel_assert(lambda: errornorm(urecv, u_expect) < 1e-12)


@pytest.mark.parallel(nprocs=6)
@@ -378,4 +414,4 @@ def test_ensemble_solvers(ensemble, W, urank, urank_sum):
usum = Function(W)
ensemble.allreduce(u_separate, usum)

assert errornorm(u_combined, usum) < 1e-8
parallel_assert(lambda: errornorm(u_combined, usum) < 1e-8)
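
For reference, below is a minimal standalone sketch of how the parallel_assert helper introduced in this commit behaves outside of these tests. The script name and the toy assertions are illustrative assumptions, not part of the commit; the helper body mirrors the one added in the diff. It can be run with, e.g., mpiexec -n 4 python parallel_assert_demo.py.

# parallel_assert_demo.py -- illustrative only; the helper mirrors the one added in this commit.
from mpi4py import MPI


def parallel_assert(assertion, subset=None, msg=""):
    if subset:
        evaluation = assertion() if MPI.COMM_WORLD.rank in subset else True
    else:
        evaluation = assertion()
    # Gather every rank's result so all ranks agree on pass/fail.
    all_results = MPI.COMM_WORLD.allgather(evaluation)
    if not all(all_results):
        raise AssertionError(
            "Parallel assertion failed on ranks: "
            f"{[ii for ii, b in enumerate(all_results) if not b]}\n" + msg
        )


if __name__ == "__main__":
    rank = MPI.COMM_WORLD.rank
    # Checked on every rank: always passes.
    parallel_assert(lambda: rank >= 0)
    # Checked only on ranks 0 and 1; the other ranks report success trivially.
    parallel_assert(lambda: rank < 2, subset={0, 1}, msg=f"{rank = }")

Because the per-rank results are allgathered before anything is raised, either every rank raises or none does, which avoids the deadlocks that a bare assert can cause when only some ranks fail while the others continue into the next collective.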
