Skip to content

Commit

Permalink
Change data set function
Browse files Browse the repository at this point in the history
  • Loading branch information
torressa committed Jul 16, 2024
1 parent 3e878d4 commit edb81af
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 75 deletions.
6 changes: 4 additions & 2 deletions src/gurobi_optimods/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,16 @@ def _convert_pandas_to_digraph(
capacity=True,
cost=True,
demand=True,
digraph=nx.DiGraph,
use_multigraph=False,
):
"""
Convert from a pandas DataFrame to a networkx.MultiDiGraph with the appropriate
attributes. For edges: `capacity`, and `cost`. For nodes: `demand`.
"""
graph_type = nx.MultiDiGraph if use_multigraph else nx.DiGraph

G = nx.from_pandas_edgelist(
edge_data.reset_index(), create_using=digraph(), edge_attr=True
edge_data.reset_index(), create_using=graph_type(), edge_attr=True
)
if demand:
for i, d in node_data.iterrows():
Expand Down
10 changes: 5 additions & 5 deletions src/gurobi_optimods/min_cost_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def min_cost_flow_pandas(

source_label, target_label = arc_data.index.names

multigraph = False
use_multigraph = False
# This is a workaround for duplicate entries being disallowed in gurobipy_pandas
if arc_data.index.has_duplicates:
arc_data = arc_data.reset_index()
multigraph = True
use_multigraph = True

arc_df = arc_data.gppd.add_vars(model, ub="capacity", obj="cost", name="flow")

Expand All @@ -91,7 +91,7 @@ def min_cost_flow_pandas(
if model.Status in [GRB.INFEASIBLE, GRB.INF_OR_UNBD]:
raise ValueError("Unsatisfiable flows")

if multigraph:
if use_multigraph:
# Repair index that was reset above
arc_df = arc_df.set_index([source_label, target_label])
return model.ObjVal, arc_df["flow"].gppd.X
Expand Down Expand Up @@ -182,7 +182,7 @@ def min_cost_flow_networkx(G, *, create_env):
f"Solving min-cost flow with {len(G.nodes)} nodes and {len(G.edges)} edges"
)
with create_env() as env, gp.Model(env=env) as model:
multigraph = isinstance(G, nx.MultiGraph)
use_multigraph = isinstance(G, nx.MultiGraph)

G = nx.MultiDiGraph(G)

Expand Down Expand Up @@ -226,7 +226,7 @@ def min_cost_flow_networkx(G, *, create_env):
raise ValueError("Unsatisfiable flows")

# Create a new Graph with selected edges in the matching
resulting_flow = nx.MultiDiGraph() if multigraph else nx.DiGraph()
resulting_flow = nx.MultiDiGraph() if use_multigraph else nx.DiGraph()
resulting_flow.add_nodes_from(nodes)
resulting_flow.add_edges_from(
[(edge[0], edge[1], {"flow": v.X}) for edge, v in x.items() if v.X > 0.1]
Expand Down
126 changes: 63 additions & 63 deletions tests/test_max_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ def test_pandas(self):
obj, sol = max_flow(edge_data, 0, 5)
sol = sol[sol > 0]
self.assertEqual(obj, self.expected_max_flow)
candidate = {
(0, 1): 1.0,
(0, 2): 2.0,
(1, 3): 1.0,
(2, 3): 1.0,
(2, 4): 1.0,
(3, 5): 2.0,
(4, 5): 1.0,
}
candidate2 = {
(0, 1): 1.0,
(0, 2): 2.0,
(1, 3): 1.0,
(2, 4): 2.0,
(3, 5): 1.0,
(4, 5): 2.0,
}
candidate = [
((0, 1), 1.0),
((0, 2), 2.0),
((1, 3), 1.0),
((2, 3), 1.0),
((2, 4), 1.0),
((3, 5), 2.0),
((4, 5), 1.0),
]
candidate2 = [
((0, 1), 1.0),
((0, 2), 2.0),
((1, 3), 1.0),
((2, 4), 2.0),
((3, 5), 1.0),
((4, 5), 2.0),
]
self.assertTrue(check_solution_pandas(sol, [candidate, candidate2]))

def test_empty_pandas(self):
Expand Down Expand Up @@ -84,23 +84,23 @@ def test_networkx(self):
G = datasets.simple_graph_networkx()
obj, sol = max_flow(G, 0, 5)
self.assertEqual(obj, self.expected_max_flow)
candidate = {
(0, 1): {"flow": 1},
(0, 2): {"flow": 2},
(1, 3): {"flow": 1},
(2, 4): {"flow": 2},
(3, 5): {"flow": 1},
(4, 5): {"flow": 2},
}
candidate2 = {
(0, 1): {"flow": 1.0},
(0, 2): {"flow": 2.0},
(1, 3): {"flow": 1.0},
(2, 3): {"flow": 1.0},
(2, 4): {"flow": 1.0},
(3, 5): {"flow": 2.0},
(4, 5): {"flow": 1.0},
}
candidate = [
((0, 1), 1),
((0, 2), 2),
((1, 3), 1),
((2, 4), 2),
((3, 5), 1),
((4, 5), 2),
]
candidate2 = [
((0, 1), 1.0),
((0, 2), 2.0),
((1, 3), 1.0),
((2, 3), 1.0),
((2, 4), 1.0),
((3, 5), 2.0),
((4, 5), 1.0),
]
self.assertTrue(check_solution_networkx(sol, [candidate, candidate2]))

@unittest.skipIf(nx is None, "networkx is not installed")
Expand All @@ -120,16 +120,16 @@ def test_pandas(self):
obj, sol = max_flow(edge_data, 0, 4)
sol = sol[sol > 0]
self.assertEqual(obj, self.expected_max_flow)
candidate = {
(0, 1): 15.0,
(0, 2): 8.0,
(1, 3): 4.0,
(1, 2): 1.0,
(1, 4): 10.0,
(2, 3): 4.0,
(2, 4): 5.0,
(3, 4): 8.0,
}
candidate = [
((0, 1), 15.0),
((0, 2), 8.0),
((1, 3), 4.0),
((1, 2), 1.0),
((1, 4), 10.0),
((2, 3), 4.0),
((2, 4), 5.0),
((3, 4), 8.0),
]
self.assertTrue(check_solution_pandas(sol, [candidate]))

def test_scipy(self):
Expand All @@ -152,23 +152,23 @@ def test_networkx(self):
G = load_graph2_networkx()
obj, sol = max_flow(G, 0, 4)
self.assertEqual(obj, self.expected_max_flow)
candidate = {
(0, 1): {"flow": 15.0},
(0, 2): {"flow": 8.0},
(1, 3): {"flow": 4.0},
(1, 2): {"flow": 1.0},
(1, 4): {"flow": 10.0},
(2, 3): {"flow": 4.0},
(2, 4): {"flow": 5.0},
(3, 4): {"flow": 8.0},
}
candidate2 = {
(0, 1): {"flow": 15},
(0, 2): {"flow": 8},
(1, 2): {"flow": 1},
(1, 3): {"flow": 4},
(1, 4): {"flow": 10},
(2, 3): {"flow": 9},
(3, 4): {"flow": 13},
}
candidate = [
((0, 1), 15.0),
((0, 2), 8.0),
((1, 3), 4.0),
((1, 2), 1.0),
((1, 4), 10.0),
((2, 3), 4.0),
((2, 4), 5.0),
((3, 4), 8.0),
]
candidate2 = [
((0, 1), 15),
((0, 2), 8),
((1, 2), 1),
((1, 3), 4),
((1, 4), 10),
((2, 3), 9),
((3, 4), 13),
]
self.assertTrue(check_solution_networkx(sol, [candidate, candidate2]))
12 changes: 7 additions & 5 deletions tests/test_min_cost_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def load_graph2_pandas():
)


def load_graph2_networkx(digraph=nx.DiGraph):
def load_graph2_networkx():
edge_data, node_data = load_graph2_pandas()
return datasets._convert_pandas_to_digraph(edge_data, node_data, digraph=digraph)
return datasets._convert_pandas_to_digraph(edge_data, node_data)


def load_graph2_scipy():
Expand All @@ -80,9 +80,11 @@ def load_graph3_pandas():
)


def load_graph3_networkx(digraph=nx.DiGraph):
def load_graph3_networkx(use_multigraph):
edge_data, node_data = load_graph3_pandas()
return datasets._convert_pandas_to_digraph(edge_data, node_data, digraph=digraph)
return datasets._convert_pandas_to_digraph(
edge_data, node_data, use_multigraph=use_multigraph
)


class TestMinCostFlow(unittest.TestCase):
Expand Down Expand Up @@ -269,7 +271,7 @@ def test_pandas(self):

@unittest.skipIf(nx is None, "networkx is not installed")
def test_networkx(self):
G = load_graph3_networkx(digraph=nx.MultiDiGraph)
G = load_graph3_networkx(use_multigraph=True)
cost, sol = mcf.min_cost_flow_networkx(G)
self.assertEqual(cost, 49.0)
candidate = [
Expand Down

0 comments on commit edb81af

Please sign in to comment.