Skip to content

Commit

Permalink
more lint
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Mar 14, 2024
1 parent fbaaf04 commit 1ae9275
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 42 deletions.
43 changes: 22 additions & 21 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,42 +67,43 @@ def sankey(

num_col = len(data.columns)
data.columns = range(num_col) # force numeric column headings
N = int(num_col/2) # number of labels
num_side = int(num_col/2) # number of labels
num_flow = num_side - 1

# sizes
weight_sum = np.empty(N)
num_uniq = np.empty(N)
col_hgt = np.empty(N)
for ii in range(N):
weight_sum = np.empty(num_side)
num_uniq = np.empty(num_side)
col_hgt = np.empty(num_side)
for ii in range(num_side):
weight_sum[ii] = sum(data[2*ii+1])
num_uniq[ii] = len(pd.Series(data[2*ii]).unique())

for ii in range(N):
for ii in range(num_side):
col_hgt[ii] = weight_sum[ii] + (num_uniq[ii]-1)*barGap*max(weight_sum)

# overall dimensions
plot_height = max(col_hgt)
sub_width = plot_height/aspect
plotWidth = (
(N-1)*sub_width
(num_side-1)*sub_width
+ 2*sub_width*labelWidth
+ N*sub_width*barWidth
+ num_side*sub_width*barWidth
)

# offsets for alignment
voffset = np.empty(N)
voffset = np.empty(num_side)
if valign == "top":
vscale = 1
elif valign == "center":
vscale = 0.5
else: # bottom, or undefined
vscale = 0

for ii in range(N):
for ii in range(num_side):
voffset[ii] = vscale*(col_hgt[1] - col_hgt[ii])

# labels
label_record = data[range(0, 2*N, 2)].to_records(index=False)
label_record = data[range(0, 2*num_side, 2)].to_records(index=False)
flattened = [item for sublist in label_record for item in sublist]
flatcat = pd.Series(flattened).unique()

Expand All @@ -118,10 +119,10 @@ def sankey(
if ax is None:
ax = plt.gca()

for ii in range(N-1):
for ii in range(num_flow):

_sankey(
ii, N-1, data,
ii, num_flow, data,
titles=titles,
titleGap=titleGap,
titleSide=titleSide,
Expand Down Expand Up @@ -167,7 +168,7 @@ def sankey(


def _sankey(
ii, N, data,
ii, num_flow, data,
colorDict=None,
labelOrder=None,
fontsize=None,
Expand Down Expand Up @@ -318,15 +319,15 @@ def _sankey(
lw=0,
snap=True,
)
if ii < N-1: # inside labels
if ii < num_flow-1: # inside labels
ax.text(
xRight + (labelGap+barWidth)*xMax,
rbot + 0.5*rrr,
labelDict.get(rightLabel, rightLabel),
{'ha': 'left', 'va': 'center'},
fontsize=fontsize
)
if ii == N-1: # last time
if ii == num_flow-1: # last time
ax.text(
xRight + (labelGap+barWidth)*xMax,
rbot + 0.5*rrr,
Expand Down Expand Up @@ -446,7 +447,7 @@ def check_data_matches_labels(labels, data, side):
raise LabelMismatchError(side, msg)


def combineColours(c1, c2, N):
def combineColours(c1, c2, num_col):

colorArrayLen = 4
# if not [r,g,b,a] assume a hex string like "#rrggbb":
Expand All @@ -463,10 +464,10 @@ def combineColours(c1, c2, N):
b2 = int(c2[5:7], 16)/255
c2 = [r2, g2, b2, 1]

rr = np.linspace(c1[0], c2[0], N)
gg = np.linspace(c1[1], c2[1], N)
bb = np.linspace(c1[2], c2[2], N)
aa = np.linspace(c1[3], c2[3], N)
rr = np.linspace(c1[0], c2[0], num_col)
gg = np.linspace(c1[1], c2[1], num_col)
bb = np.linspace(c1[2], c2[2], num_col)
aa = np.linspace(c1[3], c2[3], num_col)

return np.array([rr, gg, bb, aa])

13 changes: 6 additions & 7 deletions sankey_doc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import ausankey as sky

data = pd.read_csv('tests/fruit.csv')
print(data)

plt.figure()
sky.sankey(data)
Expand All @@ -25,7 +24,7 @@
plt.show()
plt.savefig("doc/fruits_jet.png")

colorDict = {
color_dict = {
'apple': '#f71b1b',
'blueberry': '#1b7ef7',
'banana': '#f3f71b',
Expand All @@ -34,11 +33,11 @@
}

plt.figure()
sky.sankey(data,colorDict=colorDict)
sky.sankey(data,colorDict=color_dict)
plt.show()
plt.savefig("doc/fruits_colordict.png")

labelDict = {
label_dict = {
'apple': 'Apple',
'blueberry': "B'berry",
'banana': 'Banana',
Expand All @@ -47,7 +46,7 @@
}

plt.figure()
sky.sankey(data,labelDict=labelDict)
sky.sankey(data,labelDict=label_dict)
plt.show()
plt.savefig("doc/fruits_labeldict.png")

Expand Down Expand Up @@ -127,7 +126,7 @@
plt.savefig("doc/frame2_sort_n1.png")


colorDict = {
color_dict = {
'a':'#f71b1b',
'b':'#1b7ef7',
'ab':'#8821aa',
Expand All @@ -153,7 +152,7 @@
sky.sankey(
data,
sorting = -1,
colorDict = colorDict,
colorDict = color_dict,
labelWidth = 0.1,
labelGap = 0.02,
barWidth = 0.05,
Expand Down
22 changes: 11 additions & 11 deletions tests/test_fruit_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,36 @@ def test_fruits_default(self):
sky.sankey(self.data)

def test_fruits_sorting(self):

plt.figure(dpi=150)
sky.sankey(self.data, sorting=1)

plt.figure(dpi=150)
sky.sankey(self.data, sorting=-1)

def test_fruits_colormap(self):

plt.figure(dpi=150)
sky.sankey(self.data, colormap="jet")

def test_fruits_colordict(self):

plt.figure(dpi=150)
sky.sankey(self.data, colorDict=self.colorDict)
sky.sankey(self.data, colorDict=self.color_dict)

def test_fruits_titles(self):

plt.figure(dpi=150)
sky.sankey(self.data, titles=["Summer", "Winter"])

def test_fruits_valign(self):

plt.figure(dpi=150)
sky.sankey(self.data, valign="top")

plt.figure(dpi=150)
sky.sankey(self.data, valign="center")

plt.figure(dpi=150)
sky.sankey(self.data, valign="bottom")

7 changes: 4 additions & 3 deletions tests/test_fruit_setup.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import pandas as pd

from .generic_test import GenericTest

import pandas as pd

class TestFruit(GenericTest):
""" Setup sankey test with data in fruit.csv """

def setUp(self):

self.figure_name = "fruit"
self.data = pd.read_csv(
'tests/fruit.csv', sep=','
)
self.colorDict = {
self.color_dict = {
'apple': '#f71b1b',
'blueberry': '#1b7ef7',
'banana': '#f3f71b',
Expand Down

0 comments on commit 1ae9275

Please sign in to comment.