Skip to content

Commit

Permalink
Simplify testing checks
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbowly committed Jul 17, 2024
1 parent edb81af commit 2536cbf
Showing 1 changed file with 7 additions and 20 deletions.
27 changes: 7 additions & 20 deletions tests/test_graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,25 @@
import numpy as np


def _sort_key(x):
return str(x)


def check_solution_pandas(solution, candidates):
# Checks whether the solution (`pd.Series`) matches any of the list of
# candidates (containing `dict`)
if any(
sorted(list(zip(solution.index.to_list(), solution.to_list())), key=_sort_key)
== sorted(c, key=_sort_key)
for c in candidates
):
return True
return False
# candidates. Each candidate is a list of tuples ((i, j), v) tuples,
# compare with the solution in sorted order.
solution_list = sorted(solution.items())
return any(solution_list == sorted(candidate) for candidate in candidates)


def check_solution_scipy(solution, candidates):
# Checks whether the solution (`sp.sparray`) matches any of the list of
# candidates (containing `np.ndarray`)
arr = solution.toarray()
if any(np.array_equal(arr, c) for c in candidates):
return True
return False
return any(np.array_equal(arr, c) for c in candidates)


def check_solution_networkx(solution, candidates):
# Checks whether the solution (`nx.DiGraph`) matches any of the list of
# candidates (containing tuples dict `{(i, j): data}`)
sol_list = sorted(
solution_list = sorted(
[((i, j), data["flow"]) for i, j, data in solution.edges(data=True)],
key=_sort_key,
)
if any(sol_list == sorted(c, key=_sort_key) for c in candidates):
return True
return False
return any(solution_list == sorted(candidate) for candidate in candidates)

0 comments on commit 2536cbf

Please sign in to comment.