Skip to content

Commit

Permalink
style fixes by ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr authored and github-actions[bot] committed Mar 28, 2024
1 parent e1bcbed commit 742ccdf
Showing 1 changed file with 55 additions and 56 deletions.
111 changes: 55 additions & 56 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pandas as pd


def sankey(data,**kwargs):
def sankey(data, **kwargs):
"""Make Sankey Diagram
Parameters
Expand All @@ -27,7 +27,7 @@ def sankey(data,**kwargs):
"""

opt = Sankey(**kwargs)
opt = opt.setup(data,opt)
opt = opt.setup(data, opt)

# draw each segment
for ii in range(opt.num_flow):
Expand All @@ -38,7 +38,7 @@ class SankeyError(Exception):
pass


class Sankey():
class Sankey:
"""Sankey Diagram
Parameters
Expand Down Expand Up @@ -228,7 +228,7 @@ def setup(self, data, opt):
data.columns = range(num_col) # force numeric column headings
num_side = int(num_col / 2) # number of stages
opt.num_flow = num_side - 1

# sizes
weight_sum = np.empty(num_side)
num_uniq = np.empty(num_side)
Expand All @@ -238,7 +238,7 @@ def setup(self, data, opt):
for ii in range(num_side):
nodes_uniq[ii] = pd.Series(data[2 * ii]).unique()
num_uniq[ii] = len(nodes_uniq[ii])

for ii in range(num_side):
opt.node_sizes[ii] = {}
for lbl in nodes_uniq[ii]:
Expand All @@ -261,10 +261,10 @@ def setup(self, data, opt):
opt.node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt)
opt.node_sizes[ii] = sort_dict(opt.node_sizes[ii], opt.sort)
weight_sum[ii] = pd.Series(opt.node_sizes[ii].values()).sum()

for ii in range(num_side):
col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * opt.node_gap * max(weight_sum)

# overall dimensions
opt.plot_height = max(col_hgt)
opt.sub_width = opt.plot_height / opt.aspect
Expand All @@ -273,7 +273,7 @@ def setup(self, data, opt):
+ 2 * opt.sub_width * (opt.label_gap + opt.label_width)
+ num_side * opt.sub_width * opt.node_width
)

# offsets for alignment
opt.voffset = np.empty(num_side)
if opt.valign == "top":
Expand All @@ -282,15 +282,15 @@ def setup(self, data, opt):
vscale = 0.5
else: # bottom, or undefined
vscale = 0

for ii in range(num_side):
opt.voffset[ii] = vscale * (col_hgt[1] - col_hgt[ii])

# labels
label_record = data[range(0, 2 * num_side, 2)].to_records(index=False)
flattened = [item for sublist in label_record for item in sublist]
flatcat = pd.Series(flattened).unique()

# If no color_dict given, make one
color_dict_orig = opt.color_dict or {}
color_dict_new = {}
Expand All @@ -299,35 +299,35 @@ def setup(self, data, opt):
for i, label in enumerate(flatcat):
color_dict_new[label] = color_dict_orig.get(label, color_palette[i])
opt.color_dict = color_dict_new

# initialise plot
opt.ax = opt.ax or plt.gca()
opt.ax.axis("off")

# frame on top/bottom edge
frame_top = opt.frame_side in ("top", "both")
frame_bot = opt.frame_side in ("bottom", "both")

frame_color = opt.frame_color or [0, 0, 0, 1]

y_frame_gap = opt.frame_gap * opt.plot_height

col = frame_color if frame_top else [1, 1, 1, 0]
opt.ax.plot(
[0, plot_width],
min(opt.voffset) + (opt.plot_height) + y_frame_gap + [0, 0],
color=col,
lw=opt.frame_lw,
)

col = frame_color if frame_bot else [1, 1, 1, 0]
opt.ax.plot(
[0, plot_width],
min(opt.voffset) - y_frame_gap + [0, 0],
color=col,
lw=opt.frame_lw,
)

if opt.label_loc is None:
opt.label_loc = ["left", "none", "right"]
if opt.node_edge is None:
Expand All @@ -340,20 +340,19 @@ def setup(self, data, opt):
opt.label_dict = {}
if opt.label_font is None:
opt.label_font = {}

return opt

return opt

def subplot(self,ii, data, opt):
def subplot(self, ii, data, opt):
"""Subroutine for plotting horizontal sections of the Sankey plot
Some special-casing is used for plotting/labelling differently
for the first and last cases.
"""

labelind = 2 * ii
weightind = 2 * ii + 1

if ii < opt.num_flow - 1:
labels_lr = [
pd.Series(data[labelind]),
Expand All @@ -376,15 +375,15 @@ def subplot(self,ii, data, opt):
pd.Series(data[weightind + 2]),
pd.Series(data[weightind + 2]),
]

nodes_lr = [
sort_nodes(labels_lr[0], opt.node_sizes[ii]),
sort_nodes(labels_lr[1], opt.node_sizes[ii + 1]),
]

# check colours
check_colors_match_labels(labels_lr, opt.color_dict)

# Determine sizes of individual subflows
nodesize = [{}, {}]
for lbl_l in nodes_lr[0]:
Expand All @@ -394,21 +393,21 @@ def subplot(self,ii, data, opt):
ind = (labels_lr[0] == lbl_l) & (labels_lr[1] == lbl_r)
nodesize[0][lbl_l][lbl_r] = weights_lr[0][ind].sum()
nodesize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum()

# Determine vertical positions of nodes
y_node_gap = opt.node_gap * opt.plot_height

if opt.valign == "top":
vscale = 1
elif opt.valign == "center":
vscale = 0.5
else: # bottom, or undefined
vscale = 0

node_voffset = [{}, {}]
node_pos_bot = [{}, {}]
node_pos_top = [{}, {}]

for lr in [0, 1]:
for i, label in enumerate(nodes_lr[lr]):
node_height = opt.node_sizes[ii + lr][label]
Expand All @@ -417,16 +416,16 @@ def subplot(self,ii, data, opt):
next_bot = node_pos_top[lr][nodes_lr[lr][i - 1]] + y_node_gap if i > 0 else 0
node_pos_bot[lr][label] = opt.voffset[ii + lr] if i == 0 else next_bot
node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height

# horizontal positions of nodes
x_node_width = opt.node_width * opt.sub_width
x_label_width = opt.label_width * opt.sub_width
x_label_gap = opt.label_gap * opt.sub_width
x_left = x_node_width + x_label_gap + x_label_width + ii * (opt.sub_width + x_node_width)
x_lr = [x_left, x_left + opt.sub_width]

# Draw nodes

def draw_node(x, dx, y, dy, label):
edge_lw = opt.node_lw if opt.node_edge else 0
opt.ax.fill_between(
Expand All @@ -448,7 +447,7 @@ def draw_node(x, dx, y, dy, label):
lw=edge_lw,
snap=True,
)

for lr in [0, 1] if ii == 0 else [1]:
for label in nodes_lr[lr]:
draw_node(
Expand All @@ -458,9 +457,9 @@ def draw_node(x, dx, y, dy, label):
opt.node_sizes[ii + lr][label],
label,
)

# Draw node labels

def draw_label(x, y, label, ha):
opt.ax.text(
x,
Expand All @@ -475,9 +474,9 @@ def draw_label(x, y, label, ha):
**opt.label_font,
},
)

ha_dict = {"left": "right", "right": "left", "center": "center"}

# first row of labels
lr = 0
if ii == 0 and opt.label_loc[0] != "none":
Expand All @@ -490,7 +489,7 @@ def draw_label(x, y, label, ha):
for label in nodes_lr[lr]:
yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2
draw_label(xx, yy, label, ha_dict[opt.label_loc[0]])

# inside labels, left
lr = 1
if ii < opt.num_flow - 1 and opt.label_loc[1] in ("left", "both"):
Expand All @@ -499,23 +498,23 @@ def draw_label(x, y, label, ha):
for label in nodes_lr[lr]:
yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2
draw_label(xx, yy, label, ha)

# inside labels, center
if ii < opt.num_flow - 1 and opt.label_loc[1] in ("center"):
xx = x_lr[lr] + x_node_width / 2
ha = "center"
for label in nodes_lr[lr]:
yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2
draw_label(xx, yy, label, ha)

# inside labels, right
if ii < opt.num_flow - 1 and opt.label_loc[1] in ("right", "both"):
xx = x_lr[lr] + x_label_gap + x_node_width
ha = "left"
for label in nodes_lr[lr]:
yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2
draw_label(xx, yy, label, ha)

# last row of labels
if ii == opt.num_flow - 1 and opt.label_loc[2] != "none":
if opt.label_loc[2] in ("left"):
Expand All @@ -527,38 +526,38 @@ def draw_label(x, y, label, ha):
for label in nodes_lr[lr]:
yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2
draw_label(xx, yy, label, ha_dict[opt.label_loc[2]])

# Plot flows
if opt.flow_edge:
edge_lw = opt.flow_lw
edge_alpha = 1
else:
edge_alpha = opt.flow_alpha
edge_lw = 0

for lbl_l in nodes_lr[0]:
for lbl_r in nodes_lr[1]:
lind = labels_lr[0] == lbl_l
rind = labels_lr[1] == lbl_r
if not any(lind & rind):
continue

lbot = node_voffset[0][lbl_l] + node_pos_bot[0][lbl_l]
rbot = node_voffset[1][lbl_r] + node_pos_bot[1][lbl_r]
llen = nodesize[0][lbl_l][lbl_r]
rlen = nodesize[1][lbl_l][lbl_r]

ys_d = create_curve(lbot, rbot)
ys_u = create_curve(lbot + llen, rbot + rlen)

# Update bottom edges at each label
# so next strip starts at the right place
node_pos_bot[0][lbl_l] += llen
node_pos_bot[1][lbl_r] += rlen

xx = np.linspace(x_lr[0], x_lr[1], len(ys_d))
cc = combine_colours(opt.color_dict[lbl_l], opt.color_dict[lbl_r], len(ys_d))

for jj in range(len(ys_d) - 1):
opt.ax.fill_between(
xx[[jj, jj + 1]],
Expand Down Expand Up @@ -587,19 +586,19 @@ def draw_label(x, y, label, ha):
lw=edge_lw,
snap=True,
)

# Place "titles"
if opt.titles is not None:
last_label = [lbl_l, lbl_r]

y_title_gap = opt.title_gap * opt.plot_height
y_frame_gap = opt.frame_gap * opt.plot_height

title_x = [
x_lr[0] - x_node_width / 2,
x_lr[1] + x_node_width / 2,
]

def draw_title(x, y, label, va):
opt.ax.text(
x,
Expand All @@ -614,18 +613,18 @@ def draw_title(x, y, label, va):
**opt.title_font,
},
)

# leftmost title
title_lr = [0, 1] if ii == 0 else [1]

for lr in title_lr:
if opt.title_side in ("top", "both"):
if opt.title_loc == "outer":
yt = min(opt.voffset) + y_title_gap + y_frame_gap + opt.plot_height
elif opt.title_loc == "inner":
yt = y_title_gap + node_pos_top[lr][last_label[lr]]
draw_title(title_x[lr], yt, opt.titles[ii + lr], "bottom")

if opt.title_side in ("bottom", "both"):
if opt.title_loc == "outer":
yt = min(opt.voffset) - y_title_gap - y_frame_gap
Expand Down

0 comments on commit 742ccdf

Please sign in to comment.