Skip to content

Commit

Permalink
Fix some non-SPMD dat accesses in VOM tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Sep 12, 2024
1 parent a135d96 commit 6df0e3a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
7 changes: 7 additions & 0 deletions tests/vertexonly/test_vertex_only_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,14 @@ def test_input_ordering_missing_point():
# put data on the input ordering
P0DG_input_ordering = FunctionSpace(vm.input_ordering, "DG", 0)
data_input_ordering = Function(P0DG_input_ordering)

if vm.comm.rank == 0:
data_input_ordering.dat.data_wo[:] = data
# Accessing data_ro [*here] is collective, hence this redundant call
_ = len(data_input_ordering.dat.data_ro)
else:
data_input_ordering.dat.data_wo[:] = []
# [*here]
assert not len(data_input_ordering.dat.data_ro)

# shouldn't have any halos
Expand All @@ -348,6 +352,9 @@ def test_input_ordering_missing_point():
data_input_ordering.interpolate(data_on_vm)
if vm.comm.rank == 0:
assert np.allclose(data_input_ordering.dat.data_ro[0:3], 2*data[0:3])
# [*here]
assert np.allclose(data_input_ordering.dat.data_ro[3], data[3])
else:
assert not len(data_input_ordering.dat.data_ro)
# Accessing data_ro [*here] is collective, hence this redundant call
_ = len(data_input_ordering.dat.data_ro)
31 changes: 29 additions & 2 deletions tests/vertexonly/test_vertex_only_mesh_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,33 +144,54 @@ def verify_vertexonly_mesh(m, vm, inputvertexcoords, name):
total_cells = MPI.COMM_WORLD.allreduce(len(vm.coordinates.dat.data_ro), op=MPI.SUM)
total_in_bounds = MPI.COMM_WORLD.allreduce(len(in_bounds), op=MPI.SUM)
skip_in_bounds_checks = False
local_cells = len(vm.coordinates.dat.data_ro)
if total_cells != total_in_bounds:
assert MPI.COMM_WORLD.size > 1 # i.e. we're in parallel
assert total_cells < total_in_bounds # i.e. some points are duplicated
local_cells = len(vm.coordinates.dat.data_ro)
local_in_bounds = len(in_bounds)
if not local_cells == local_in_bounds and local_in_bounds > 0:
assert max(ref_cell_dists_l1) > 0.5*m.tolerance
# This assertion needs to happen in parallel!
assertion = (max(ref_cell_dists_l1) > 0.5*m.tolerance)
skip_in_bounds_checks = True
else:
assertion = True
else:
assertion = True
# FIXME: Replace with parallel assert when it's merged into pytest-mpi
assert min(MPI.COMM_WORLD.allgather([assertion]))

# Correct local coordinates (though not guaranteed to be in same order)
if not skip_in_bounds_checks:
# Correct local coordinates (though not guaranteed to be in same order)
# [*here]
np.allclose(np.sort(vm.coordinates.dat.data_ro), np.sort(inputvertexcoords[in_bounds]))
else:
# Accessing data_ro [*here] is collective, hence this redundant call
_ = len(vm.coordinates.dat.data_ro)
# Correct parent topology
assert vm._parent_mesh is m
assert vm.topology._parent_mesh is m.topology
# Correct generic cell properties
if not skip_in_bounds_checks:
# [*here]
assert vm.cell_closure.shape == (len(vm.coordinates.dat.data_ro_with_halos), 1)
else:
# Accessing data_ro [*here] is collective, hence this redundant call
_ = len(vm.coordinates.dat.data_ro_with_halos)
with pytest.raises(AttributeError):
vm.exterior_facets()
with pytest.raises(AttributeError):
vm.interior_facets()
with pytest.raises(AttributeError):
vm.cell_to_facets
if not skip_in_bounds_checks:
# [*here]
assert vm.num_cells() == vm.cell_closure.shape[0] == len(vm.coordinates.dat.data_ro_with_halos) == vm.cell_set.total_size
assert vm.cell_set.size == len(inputvertexcoords[in_bounds]) == len(vm.coordinates.dat.data_ro)
else:
# Accessing data_ro and data_ro_with_halos [*here] is collective, hence this redundant call
_ = len(vm.coordinates.dat.data_ro_with_halos)
_ = len(vm.coordinates.dat.data_ro)
assert vm.num_facets() == 0
assert vm.num_faces() == vm.num_entities(2) == 0
assert vm.num_edges() == vm.num_entities(1) == 0
Expand Down Expand Up @@ -257,11 +278,17 @@ def test_generate_cell_midpoints(parentmesh, redundant):
out_of_mesh_point = np.full((1, parentmesh.geometric_dimension()), np.inf)
for i in range(max_len):
if i < len(vm.coordinates.dat.data_ro):
# [*here]
cell_num = parentmesh.locate_cell(vm.coordinates.dat.data_ro[i])
else:
cell_num = parentmesh.locate_cell(out_of_mesh_point) # should return None
# Accessing data_ro [*here] is collective, hence this redundant call
_ = len(vm.coordinates.dat.data_ro)
if cell_num is not None:
assert (f.dat.data_ro[cell_num] == vm.coordinates.dat.data_ro[i]).all()
else:
_ = len(f.dat.data_ro)
_ = len(vm.coordinates.dat.data_ro)

# Have correct pyop2 labels as implied by cell set sizes
if parentmesh.extruded:
Expand Down

0 comments on commit 6df0e3a

Please sign in to comment.