Skip to content

Commit e6f89a6

Browse files
imengusCaedenPHpre-commit-ci[bot]tianyizheng02
authored
Simplex algorithm (#8825)
* feat: added simplex.py * added docstrings * Update linear_programming/simplex.py Co-authored-by: Caeden Perelli-Harris <caedenperelliharris@gmail.com> * Update linear_programming/simplex.py Co-authored-by: Caeden Perelli-Harris <caedenperelliharris@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update linear_programming/simplex.py Co-authored-by: Caeden Perelli-Harris <caedenperelliharris@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ruff fix Co-authored by: CaedenPH <caedenperelliharris@gmail.com> * removed README to add in separate PR * Update linear_programming/simplex.py Co-authored-by: Tianyi Zheng <tianyizheng02@gmail.com> * Update linear_programming/simplex.py Co-authored-by: Tianyi Zheng <tianyizheng02@gmail.com> * fix class docstring * add comments --------- Co-authored-by: Caeden Perelli-Harris <caedenperelliharris@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tianyi Zheng <tianyizheng02@gmail.com>
1 parent 4637986 commit e6f89a6

File tree

1 file changed

+311
-0
lines changed

1 file changed

+311
-0
lines changed

Diff for: linear_programming/simplex.py

+311
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
"""
2+
Python implementation of the simplex algorithm for solving linear programs in
3+
tabular form with
4+
- `>=`, `<=`, and `=` constraints and
5+
- each variable `x1, x2, ...>= 0`.
6+
7+
See https://gist.github.com/imengus/f9619a568f7da5bc74eaf20169a24d98 for how to
8+
convert linear programs to simplex tableaus, and the steps taken in the simplex
9+
algorithm.
10+
11+
Resources:
12+
https://en.wikipedia.org/wiki/Simplex_algorithm
13+
https://tinyurl.com/simplex4beginners
14+
"""
15+
from typing import Any
16+
17+
import numpy as np
18+
19+
20+
class Tableau:
21+
"""Operate on simplex tableaus
22+
23+
>>> t = Tableau(np.array([[-1,-1,0,0,-1],[1,3,1,0,4],[3,1,0,1,4.]]), 2)
24+
Traceback (most recent call last):
25+
...
26+
ValueError: RHS must be > 0
27+
"""
28+
29+
def __init__(self, tableau: np.ndarray, n_vars: int) -> None:
30+
# Check if RHS is negative
31+
if np.any(tableau[:, -1], where=tableau[:, -1] < 0):
32+
raise ValueError("RHS must be > 0")
33+
34+
self.tableau = tableau
35+
self.n_rows, _ = tableau.shape
36+
37+
# Number of decision variables x1, x2, x3...
38+
self.n_vars = n_vars
39+
40+
# Number of artificial variables to be minimised
41+
self.n_art_vars = len(np.where(tableau[self.n_vars : -1] == -1)[0])
42+
43+
# 2 if there are >= or == constraints (nonstandard), 1 otherwise (std)
44+
self.n_stages = (self.n_art_vars > 0) + 1
45+
46+
# Number of slack variables added to make inequalities into equalities
47+
self.n_slack = self.n_rows - self.n_stages
48+
49+
# Objectives for each stage
50+
self.objectives = ["max"]
51+
52+
# In two stage simplex, first minimise then maximise
53+
if self.n_art_vars:
54+
self.objectives.append("min")
55+
56+
self.col_titles = [""]
57+
58+
# Index of current pivot row and column
59+
self.row_idx = None
60+
self.col_idx = None
61+
62+
# Does objective row only contain (non)-negative values?
63+
self.stop_iter = False
64+
65+
@staticmethod
66+
def generate_col_titles(*args: int) -> list[str]:
67+
"""Generate column titles for tableau of specific dimensions
68+
69+
>>> Tableau.generate_col_titles(2, 3, 1)
70+
['x1', 'x2', 's1', 's2', 's3', 'a1', 'RHS']
71+
72+
>>> Tableau.generate_col_titles()
73+
Traceback (most recent call last):
74+
...
75+
ValueError: Must provide n_vars, n_slack, and n_art_vars
76+
>>> Tableau.generate_col_titles(-2, 3, 1)
77+
Traceback (most recent call last):
78+
...
79+
ValueError: All arguments must be non-negative integers
80+
"""
81+
if len(args) != 3:
82+
raise ValueError("Must provide n_vars, n_slack, and n_art_vars")
83+
84+
if not all(x >= 0 and isinstance(x, int) for x in args):
85+
raise ValueError("All arguments must be non-negative integers")
86+
87+
# decision | slack | artificial
88+
string_starts = ["x", "s", "a"]
89+
titles = []
90+
for i in range(3):
91+
for j in range(args[i]):
92+
titles.append(string_starts[i] + str(j + 1))
93+
titles.append("RHS")
94+
return titles
95+
96+
def find_pivot(self, tableau: np.ndarray) -> tuple[Any, Any]:
97+
"""Finds the pivot row and column.
98+
>>> t = Tableau(np.array([[-2,1,0,0,0], [3,1,1,0,6], [1,2,0,1,7.]]), 2)
99+
>>> t.find_pivot(t.tableau)
100+
(1, 0)
101+
"""
102+
objective = self.objectives[-1]
103+
104+
# Find entries of highest magnitude in objective rows
105+
sign = (objective == "min") - (objective == "max")
106+
col_idx = np.argmax(sign * tableau[0, : self.n_vars])
107+
108+
# Choice is only valid if below 0 for maximise, and above for minimise
109+
if sign * self.tableau[0, col_idx] <= 0:
110+
self.stop_iter = True
111+
return 0, 0
112+
113+
# Pivot row is chosen as having the lowest quotient when elements of
114+
# the pivot column divide the right-hand side
115+
116+
# Slice excluding the objective rows
117+
s = slice(self.n_stages, self.n_rows)
118+
119+
# RHS
120+
dividend = tableau[s, -1]
121+
122+
# Elements of pivot column within slice
123+
divisor = tableau[s, col_idx]
124+
125+
# Array filled with nans
126+
nans = np.full(self.n_rows - self.n_stages, np.nan)
127+
128+
# If element in pivot column is greater than zeron_stages, return
129+
# quotient or nan otherwise
130+
quotients = np.divide(dividend, divisor, out=nans, where=divisor > 0)
131+
132+
# Arg of minimum quotient excluding the nan values. n_stages is added
133+
# to compensate for earlier exclusion of objective columns
134+
row_idx = np.nanargmin(quotients) + self.n_stages
135+
return row_idx, col_idx
136+
137+
def pivot(self, tableau: np.ndarray, row_idx: int, col_idx: int) -> np.ndarray:
138+
"""Pivots on value on the intersection of pivot row and column.
139+
140+
>>> t = Tableau(np.array([[-2,-3,0,0,0],[1,3,1,0,4],[3,1,0,1,4.]]), 2)
141+
>>> t.pivot(t.tableau, 1, 0).tolist()
142+
... # doctest: +NORMALIZE_WHITESPACE
143+
[[0.0, 3.0, 2.0, 0.0, 8.0],
144+
[1.0, 3.0, 1.0, 0.0, 4.0],
145+
[0.0, -8.0, -3.0, 1.0, -8.0]]
146+
"""
147+
# Avoid changes to original tableau
148+
piv_row = tableau[row_idx].copy()
149+
150+
piv_val = piv_row[col_idx]
151+
152+
# Entry becomes 1
153+
piv_row *= 1 / piv_val
154+
155+
# Variable in pivot column becomes basic, ie the only non-zero entry
156+
for idx, coeff in enumerate(tableau[:, col_idx]):
157+
tableau[idx] += -coeff * piv_row
158+
tableau[row_idx] = piv_row
159+
return tableau
160+
161+
def change_stage(self, tableau: np.ndarray) -> np.ndarray:
162+
"""Exits first phase of the two-stage method by deleting artificial
163+
rows and columns, or completes the algorithm if exiting the standard
164+
case.
165+
166+
>>> t = Tableau(np.array([
167+
... [3, 3, -1, -1, 0, 0, 4],
168+
... [2, 1, 0, 0, 0, 0, 0.],
169+
... [1, 2, -1, 0, 1, 0, 2],
170+
... [2, 1, 0, -1, 0, 1, 2]
171+
... ]), 2)
172+
>>> t.change_stage(t.tableau).tolist()
173+
... # doctest: +NORMALIZE_WHITESPACE
174+
[[2.0, 1.0, 0.0, 0.0, 0.0, 0.0],
175+
[1.0, 2.0, -1.0, 0.0, 1.0, 2.0],
176+
[2.0, 1.0, 0.0, -1.0, 0.0, 2.0]]
177+
"""
178+
# Objective of original objective row remains
179+
self.objectives.pop()
180+
181+
if not self.objectives:
182+
return tableau
183+
184+
# Slice containing ids for artificial columns
185+
s = slice(-self.n_art_vars - 1, -1)
186+
187+
# Delete the artificial variable columns
188+
tableau = np.delete(tableau, s, axis=1)
189+
190+
# Delete the objective row of the first stage
191+
tableau = np.delete(tableau, 0, axis=0)
192+
193+
self.n_stages = 1
194+
self.n_rows -= 1
195+
self.n_art_vars = 0
196+
self.stop_iter = False
197+
return tableau
198+
199+
def run_simplex(self) -> dict[Any, Any]:
200+
"""Operate on tableau until objective function cannot be
201+
improved further.
202+
203+
# Standard linear program:
204+
Max: x1 + x2
205+
ST: x1 + 3x2 <= 4
206+
3x1 + x2 <= 4
207+
>>> Tableau(np.array([[-1,-1,0,0,0],[1,3,1,0,4],[3,1,0,1,4.]]),
208+
... 2).run_simplex()
209+
{'P': 2.0, 'x1': 1.0, 'x2': 1.0}
210+
211+
# Optimal tableau input:
212+
>>> Tableau(np.array([
213+
... [0, 0, 0.25, 0.25, 2],
214+
... [0, 1, 0.375, -0.125, 1],
215+
... [1, 0, -0.125, 0.375, 1]
216+
... ]), 2).run_simplex()
217+
{'P': 2.0, 'x1': 1.0, 'x2': 1.0}
218+
219+
# Non-standard: >= constraints
220+
Max: 2x1 + 3x2 + x3
221+
ST: x1 + x2 + x3 <= 40
222+
2x1 + x2 - x3 >= 10
223+
- x2 + x3 >= 10
224+
>>> Tableau(np.array([
225+
... [2, 0, 0, 0, -1, -1, 0, 0, 20],
226+
... [-2, -3, -1, 0, 0, 0, 0, 0, 0],
227+
... [1, 1, 1, 1, 0, 0, 0, 0, 40],
228+
... [2, 1, -1, 0, -1, 0, 1, 0, 10],
229+
... [0, -1, 1, 0, 0, -1, 0, 1, 10.]
230+
... ]), 3).run_simplex()
231+
{'P': 70.0, 'x1': 10.0, 'x2': 10.0, 'x3': 20.0}
232+
233+
# Non standard: minimisation and equalities
234+
Min: x1 + x2
235+
ST: 2x1 + x2 = 12
236+
6x1 + 5x2 = 40
237+
>>> Tableau(np.array([
238+
... [8, 6, 0, -1, 0, -1, 0, 0, 52],
239+
... [1, 1, 0, 0, 0, 0, 0, 0, 0],
240+
... [2, 1, 1, 0, 0, 0, 0, 0, 12],
241+
... [2, 1, 0, -1, 0, 0, 1, 0, 12],
242+
... [6, 5, 0, 0, 1, 0, 0, 0, 40],
243+
... [6, 5, 0, 0, 0, -1, 0, 1, 40.]
244+
... ]), 2).run_simplex()
245+
{'P': 7.0, 'x1': 5.0, 'x2': 2.0}
246+
"""
247+
# Stop simplex algorithm from cycling.
248+
for _ in range(100):
249+
# Completion of each stage removes an objective. If both stages
250+
# are complete, then no objectives are left
251+
if not self.objectives:
252+
self.col_titles = self.generate_col_titles(
253+
self.n_vars, self.n_slack, self.n_art_vars
254+
)
255+
256+
# Find the values of each variable at optimal solution
257+
return self.interpret_tableau(self.tableau, self.col_titles)
258+
259+
row_idx, col_idx = self.find_pivot(self.tableau)
260+
261+
# If there are no more negative values in objective row
262+
if self.stop_iter:
263+
# Delete artificial variable columns and rows. Update attributes
264+
self.tableau = self.change_stage(self.tableau)
265+
else:
266+
self.tableau = self.pivot(self.tableau, row_idx, col_idx)
267+
return {}
268+
269+
def interpret_tableau(
270+
self, tableau: np.ndarray, col_titles: list[str]
271+
) -> dict[str, float]:
272+
"""Given the final tableau, add the corresponding values of the basic
273+
decision variables to the `output_dict`
274+
>>> tableau = np.array([
275+
... [0,0,0.875,0.375,5],
276+
... [0,1,0.375,-0.125,1],
277+
... [1,0,-0.125,0.375,1]
278+
... ])
279+
>>> t = Tableau(tableau, 2)
280+
>>> t.interpret_tableau(tableau, ["x1", "x2", "s1", "s2", "RHS"])
281+
{'P': 5.0, 'x1': 1.0, 'x2': 1.0}
282+
"""
283+
# P = RHS of final tableau
284+
output_dict = {"P": abs(tableau[0, -1])}
285+
286+
for i in range(self.n_vars):
287+
# Gives ids of nonzero entries in the ith column
288+
nonzero = np.nonzero(tableau[:, i])
289+
n_nonzero = len(nonzero[0])
290+
291+
# First entry in the nonzero ids
292+
nonzero_rowidx = nonzero[0][0]
293+
nonzero_val = tableau[nonzero_rowidx, i]
294+
295+
# If there is only one nonzero value in column, which is one
296+
if n_nonzero == nonzero_val == 1:
297+
rhs_val = tableau[nonzero_rowidx, -1]
298+
output_dict[col_titles[i]] = rhs_val
299+
300+
# Check for basic variables
301+
for title in col_titles:
302+
# Don't add RHS or slack variables to output dict
303+
if title[0] not in "R-s-a":
304+
output_dict.setdefault(title, 0)
305+
return output_dict
306+
307+
308+
if __name__ == "__main__":
309+
import doctest
310+
311+
doctest.testmod()

0 commit comments

Comments
 (0)