diff --git a/tests/test_graph_utils.py b/tests/test_graph_utils.py index 1455a1e0..703afebd 100644 --- a/tests/test_graph_utils.py +++ b/tests/test_graph_utils.py @@ -1,38 +1,25 @@ import numpy as np -def _sort_key(x): - return str(x) - - def check_solution_pandas(solution, candidates): # Checks whether the solution (`pd.Series`) matches any of the list of - # candidates (containing `dict`) - if any( - sorted(list(zip(solution.index.to_list(), solution.to_list())), key=_sort_key) - == sorted(c, key=_sort_key) - for c in candidates - ): - return True - return False + # candidates. Each candidate is a list of tuples ((i, j), v) tuples, + # compare with the solution in sorted order. + solution_list = sorted(solution.items()) + return any(solution_list == sorted(candidate) for candidate in candidates) def check_solution_scipy(solution, candidates): # Checks whether the solution (`sp.sparray`) matches any of the list of # candidates (containing `np.ndarray`) arr = solution.toarray() - if any(np.array_equal(arr, c) for c in candidates): - return True - return False + return any(np.array_equal(arr, c) for c in candidates) def check_solution_networkx(solution, candidates): # Checks whether the solution (`nx.DiGraph`) matches any of the list of # candidates (containing tuples dict `{(i, j): data}`) - sol_list = sorted( + solution_list = sorted( [((i, j), data["flow"]) for i, j, data in solution.edges(data=True)], - key=_sort_key, ) - if any(sol_list == sorted(c, key=_sort_key) for c in candidates): - return True - return False + return any(solution_list == sorted(candidate) for candidate in candidates)