Skip to content

Commit

Permalink
another step towards oop
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Mar 28, 2024
1 parent 1546c0d commit 90272f9
Showing 1 changed file with 43 additions and 73 deletions.
116 changes: 43 additions & 73 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ class Sankey:

def __init__(
self,
data,
aspect=4,
ax=None,
node_width=0.02,
Expand Down Expand Up @@ -170,7 +169,6 @@ def __init__(
sort="bottom", # "top", "bottom", "none"
valign="bottom", # "top","center"
):
self.data = data
self.aspect = aspect
self.ax = ax
self.node_width = node_width
Expand Down Expand Up @@ -204,39 +202,7 @@ def __init__(
self.valign = valign


def sankey(
data,
aspect=4,
ax=None,
node_width=0.02,
node_gap=0.05,
node_alpha=1,
node_edge=None,
color_dict=None,
colormap="viridis",
flow_edge=None,
flow_alpha=0.6,
fontcolor="black",
fontfamily="sans-serif",
fontsize=12,
frame_side="none",
frame_gap=0.1,
frame_color=None,
label_dict=None,
label_width=0,
label_gap=0.02,
label_loc=None,
label_font=None,
flow_lw=1,
node_lw=1,
titles=None,
title_gap=0.05,
title_side="top", # "bottom", "both"
title_loc="inner", # "outer"
title_font=None,
sort="bottom", # "top", "bottom", "none"
valign="bottom", # "top","center"
):
def sankey(data,**kwargs):
"""Make Sankey Diagram with left-right flow
Parameters
Expand All @@ -250,6 +216,8 @@ def sankey(
None (yet)
"""

opt = Sankey(**kwargs)

num_col = len(data.columns)
data.columns = range(num_col) # force numeric column headings
num_side = int(num_col / 2) # number of labels
Expand Down Expand Up @@ -285,24 +253,24 @@ def sankey(
weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum()
weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum()
node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt)
node_sizes[ii] = sort_dict(node_sizes[ii], sort)
node_sizes[ii] = sort_dict(node_sizes[ii], opt.sort)
weight_sum[ii] = pd.Series(node_sizes[ii].values()).sum()

for ii in range(num_side):
col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * node_gap * max(weight_sum)
col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * opt.node_gap * max(weight_sum)

# overall dimensions
plot_height = max(col_hgt)
sub_width = plot_height / aspect
sub_width = plot_height / opt.aspect
plot_width = (
(num_side - 1) * sub_width + 2 * sub_width * (label_gap + label_width) + num_side * sub_width * node_width
(num_side - 1) * sub_width + 2 * sub_width * (opt.label_gap + opt.label_width) + num_side * sub_width * opt.node_width

Check failure on line 266 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (E501)

ausankey/ausankey.py:266:121: E501 Line too long (126 > 120)
)

# offsets for alignment
voffset = np.empty(num_side)
if valign == "top":
if opt.valign == "top":
vscale = 1
elif valign == "center":
elif opt.valign == "center":
vscale = 0.5
else: # bottom, or undefined
vscale = 0
Expand All @@ -316,36 +284,38 @@ def sankey(
flatcat = pd.Series(flattened).unique()

# If no color_dict given, make one
color_dict = color_dict or {}
color_dict = opt.color_dict or {}

Check failure on line 287 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F841)

ausankey/ausankey.py:287:5: F841 Local variable `color_dict` is assigned to but never used
color_dict_new = {}
cmap = plt.cm.get_cmap(colormap)
cmap = plt.cm.get_cmap(opt.colormap)
color_palette = cmap(np.linspace(0, 1, len(flatcat)))
for i, label in enumerate(flatcat):
color_dict_new[label] = color_dict.get(label, color_palette[i])
color_dict_new[label] = opt.color_dict.get(label, color_palette[i])

# initialise plot
ax = ax or plt.gca()
ax = opt.ax or plt.gca()

# frame on top/bottom edge
frame_top = frame_side in ("top", "both")
frame_bot = frame_side in ("bottom", "both")
frame_top = opt.frame_side in ("top", "both")
frame_bot = opt.frame_side in ("bottom", "both")

frame_color = frame_color or [0, 0, 0, 1]
frame_color = opt.frame_color or [0, 0, 0, 1]

y_frame_gap = frame_gap * plot_height
y_frame_gap = opt.frame_gap * plot_height

col = frame_color if frame_top else [1, 1, 1, 0]
ax.plot(
[0, plot_width],
min(voffset) + (plot_height) + y_frame_gap + [0, 0],
color=col,
lw=opt.frame_lw,
)

col = frame_color if frame_bot else [1, 1, 1, 0]
ax.plot(
[0, plot_width],
min(voffset) - y_frame_gap + [0, 0],
color=col,
lw=opt.frame_lw,
)

# draw each segment
Expand All @@ -356,33 +326,33 @@ def sankey(
data,
ax=ax,
color_dict=color_dict_new,
flow_edge=flow_edge or False,
flow_alpha=flow_alpha,
frame_gap=frame_gap,
fontcolor=fontcolor,
fontfamily=fontfamily,
fontsize=fontsize,
label_dict=label_dict or {},
label_width=label_width,
label_gap=label_gap,
label_loc=label_loc or ["left", "none", "right"],
label_font=label_font or {},
flow_lw=flow_lw,
node_lw=node_lw,
node_alpha=node_alpha,
node_width=node_width,
flow_edge=opt.flow_edge or False,
flow_alpha=opt.flow_alpha,
frame_gap=opt.frame_gap,
fontcolor=opt.fontcolor,
fontfamily=opt.fontfamily,
fontsize=opt.fontsize,
label_dict=opt.label_dict or {},
label_width=opt.label_width,
label_gap=opt.label_gap,
label_loc=opt.label_loc or ["left", "none", "right"],
label_font=opt.label_font or {},
flow_lw=opt.flow_lw,
node_lw=opt.node_lw,
node_alpha=opt.node_alpha,
node_width=opt.node_width,
node_gap=opt.node_gap,
node_edge=opt.node_edge or False,
titles=opt.titles,
title_gap=opt.title_gap,
title_side=opt.title_side,
title_loc=opt.title_loc,
title_font=opt.title_font or {"fontweight": "bold"},
voffset=voffset,
valign=opt.valign,
node_sizes=node_sizes,
node_gap=node_gap,
node_edge=node_edge or False,
plot_height=plot_height,
sub_width=sub_width,
titles=titles,
title_gap=title_gap,
title_side=title_side,
title_loc=title_loc,
title_font=title_font or {"fontweight": "bold"},
voffset=voffset,
valign=valign,
)

# complete plot
Expand Down

0 comments on commit 90272f9

Please sign in to comment.