From 7c0780aa152bb2e3cca3ce637e5d67c156ed88f2 Mon Sep 17 00:00:00 2001 From: Will Robertson Date: Fri, 29 Mar 2024 08:32:30 +1030 Subject: [PATCH] remove keyval from subfunction --- ausankey/ausankey.py | 280 ++++++++++++++++++------------------------- 1 file changed, 114 insertions(+), 166 deletions(-) diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 055e8f1..cec4679 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -16,7 +16,7 @@ class SankeyError(Exception): pass -class Sankey: +class Sankey(): """Sankey Diagram Parameters @@ -220,27 +220,27 @@ def sankey(data,**kwargs): num_col = len(data.columns) data.columns = range(num_col) # force numeric column headings - num_side = int(num_col / 2) # number of labels - num_flow = num_side - 1 + num_side = int(num_col / 2) # number of stages + opt.num_flow = num_side - 1 # sizes weight_sum = np.empty(num_side) num_uniq = np.empty(num_side) col_hgt = np.empty(num_side) - node_sizes = {} + opt.node_sizes = {} nodes_uniq = {} for ii in range(num_side): nodes_uniq[ii] = pd.Series(data[2 * ii]).unique() num_uniq[ii] = len(nodes_uniq[ii]) for ii in range(num_side): - node_sizes[ii] = {} + opt.node_sizes[ii] = {} for lbl in nodes_uniq[ii]: if ii == 0: ind_prev = data[2 * ii + 0] == lbl ind_this = data[2 * ii + 0] == lbl ind_next = data[2 * ii + 2] == lbl - elif ii == num_flow: + elif ii == opt.num_flow: ind_prev = data[2 * ii - 2] == lbl ind_this = data[2 * ii + 0] == lbl ind_next = data[2 * ii + 0] == lbl @@ -252,22 +252,22 @@ def sankey(data,**kwargs): weight_only = data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum() weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum() weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum() - node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt) - node_sizes[ii] = sort_dict(node_sizes[ii], opt.sort) - weight_sum[ii] = pd.Series(node_sizes[ii].values()).sum() + opt.node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt) + opt.node_sizes[ii] = sort_dict(opt.node_sizes[ii], opt.sort) + weight_sum[ii] = pd.Series(opt.node_sizes[ii].values()).sum() for ii in range(num_side): col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * opt.node_gap * max(weight_sum) # overall dimensions - plot_height = max(col_hgt) - sub_width = plot_height / opt.aspect + opt.plot_height = max(col_hgt) + opt.sub_width = opt.plot_height / opt.aspect plot_width = ( - (num_side - 1) * sub_width + 2 * sub_width * (opt.label_gap + opt.label_width) + num_side * sub_width * opt.node_width + (num_side - 1) * opt.sub_width + 2 * opt.sub_width * (opt.label_gap + opt.label_width) + num_side * opt.sub_width * opt.node_width ) # offsets for alignment - voffset = np.empty(num_side) + opt.voffset = np.empty(num_side) if opt.valign == "top": vscale = 1 elif opt.valign == "center": @@ -276,7 +276,7 @@ def sankey(data,**kwargs): vscale = 0 for ii in range(num_side): - voffset[ii] = vscale * (col_hgt[1] - col_hgt[ii]) + opt.voffset[ii] = vscale * (col_hgt[1] - col_hgt[ii]) # labels label_record = data[range(0, 2 * num_side, 2)].to_records(index=False) @@ -284,15 +284,16 @@ def sankey(data,**kwargs): flatcat = pd.Series(flattened).unique() # If no color_dict given, make one - color_dict = opt.color_dict or {} + color_dict_orig = opt.color_dict or {} color_dict_new = {} cmap = plt.cm.get_cmap(opt.colormap) color_palette = cmap(np.linspace(0, 1, len(flatcat))) for i, label in enumerate(flatcat): - color_dict_new[label] = opt.color_dict.get(label, color_palette[i]) + color_dict_new[label] = color_dict_orig.get(label, color_palette[i]) + opt.color_dict = color_dict_new # initialise plot - ax = opt.ax or plt.gca() + opt.ax = opt.ax or plt.gca() # frame on top/bottom edge frame_top = opt.frame_side in ("top", "both") @@ -300,102 +301,49 @@ def sankey(data,**kwargs): frame_color = opt.frame_color or [0, 0, 0, 1] - y_frame_gap = opt.frame_gap * plot_height + y_frame_gap = opt.frame_gap * opt.plot_height col = frame_color if frame_top else [1, 1, 1, 0] - ax.plot( + opt.ax.plot( [0, plot_width], - min(voffset) + (plot_height) + y_frame_gap + [0, 0], + min(opt.voffset) + (opt.plot_height) + y_frame_gap + [0, 0], color=col, lw=opt.frame_lw, ) col = frame_color if frame_bot else [1, 1, 1, 0] - ax.plot( + opt.ax.plot( [0, plot_width], - min(voffset) - y_frame_gap + [0, 0], + min(opt.voffset) - y_frame_gap + [0, 0], color=col, lw=opt.frame_lw, ) + + if opt.label_loc is None: + opt.label_loc = ["left", "none", "right"] + if opt.node_edge is None: + opt.node_edge = False + if opt.flow_edge is None: + opt.flow_edge = False + if opt.title_font is None: + opt.title_font = {"fontweight": "bold"} + if opt.label_dict is None: + opt.label_dict = {} + if opt.label_font is None: + opt.label_font = {} # draw each segment - for ii in range(num_flow): - _sankey( - ii, - num_flow, - data, - ax=ax, - color_dict=color_dict_new, - flow_edge=opt.flow_edge or False, - flow_alpha=opt.flow_alpha, - frame_gap=opt.frame_gap, - fontcolor=opt.fontcolor, - fontfamily=opt.fontfamily, - fontsize=opt.fontsize, - label_dict=opt.label_dict or {}, - label_width=opt.label_width, - label_gap=opt.label_gap, - label_loc=opt.label_loc or ["left", "none", "right"], - label_font=opt.label_font or {}, - flow_lw=opt.flow_lw, - node_lw=opt.node_lw, - node_alpha=opt.node_alpha, - node_width=opt.node_width, - node_gap=opt.node_gap, - node_edge=opt.node_edge or False, - titles=opt.titles, - title_gap=opt.title_gap, - title_side=opt.title_side, - title_loc=opt.title_loc, - title_font=opt.title_font or {"fontweight": "bold"}, - voffset=voffset, - valign=opt.valign, - node_sizes=node_sizes, - plot_height=plot_height, - sub_width=sub_width, - ) + for ii in range(opt.num_flow): + _sankey(ii, data, opt) # complete plot - ax.axis("off") + opt.ax.axis("off") ########################################### -def _sankey( - ii, - num_flow, - data, - ax=None, - color_dict=None, - flow_edge=None, - flow_alpha=None, - fontcolor=None, - fontsize=None, - fontfamily=None, - frame_gap=None, - label_dict=None, - label_width=None, - label_gap=None, - label_loc=None, - label_font=None, - flow_lw=None, - node_lw=None, - node_width=None, - node_sizes=None, - node_gap=None, - node_alpha=None, - node_edge=None, - plot_height=None, - sub_width=None, - titles=None, - title_gap=None, - title_side=None, - title_loc=None, - title_font=None, - voffset=None, - valign=None, -): +def _sankey(ii, data, opt): """Subroutine for plotting horizontal sections of the Sankey plot Some special-casing is used for plotting/labelling differently @@ -405,7 +353,7 @@ def _sankey( labelind = 2 * ii weightind = 2 * ii + 1 - if ii < num_flow - 1: + if ii < opt.num_flow - 1: labels_lr = [ pd.Series(data[labelind]), pd.Series(data[labelind + 2]), @@ -429,12 +377,12 @@ def _sankey( ] nodes_lr = [ - sort_nodes(labels_lr[0], node_sizes[ii]), - sort_nodes(labels_lr[1], node_sizes[ii + 1]), + sort_nodes(labels_lr[0], opt.node_sizes[ii]), + sort_nodes(labels_lr[1], opt.node_sizes[ii + 1]), ] # check colours - check_colors_match_labels(labels_lr, color_dict) + check_colors_match_labels(labels_lr, opt.color_dict) # Determine sizes of individual subflows nodesize = [{}, {}] @@ -447,11 +395,11 @@ def _sankey( nodesize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum() # Determine vertical positions of nodes - y_node_gap = node_gap * plot_height + y_node_gap = opt.node_gap * opt.plot_height - if valign == "top": + if opt.valign == "top": vscale = 1 - elif valign == "center": + elif opt.valign == "center": vscale = 0.5 else: # bottom, or undefined vscale = 0 @@ -462,39 +410,39 @@ def _sankey( for lr in [0, 1]: for i, label in enumerate(nodes_lr[lr]): - node_height = node_sizes[ii + lr][label] + node_height = opt.node_sizes[ii + lr][label] this_side_height = weights_lr[lr][labels_lr[lr] == label].sum() node_voffset[lr][label] = vscale * (node_height - this_side_height) next_bot = node_pos_top[lr][nodes_lr[lr][i - 1]] + y_node_gap if i > 0 else 0 - node_pos_bot[lr][label] = voffset[ii + lr] if i == 0 else next_bot + node_pos_bot[lr][label] = opt.voffset[ii + lr] if i == 0 else next_bot node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height # horizontal positions of nodes - x_node_width = node_width * sub_width - x_label_width = label_width * sub_width - x_label_gap = label_gap * sub_width - x_left = x_node_width + x_label_gap + x_label_width + ii * (sub_width + x_node_width) - x_lr = [x_left, x_left + sub_width] + x_node_width = opt.node_width * opt.sub_width + x_label_width = opt.label_width * opt.sub_width + x_label_gap = opt.label_gap * opt.sub_width + x_left = x_node_width + x_label_gap + x_label_width + ii * (opt.sub_width + x_node_width) + x_lr = [x_left, x_left + opt.sub_width] # Draw nodes def draw_node(x, dx, y, dy, label): - edge_lw = node_lw if node_edge else 0 - ax.fill_between( + edge_lw = opt.node_lw if opt.node_edge else 0 + opt.ax.fill_between( [x, x + dx], y, y + dy, - facecolor=color_dict[label], - alpha=node_alpha, + facecolor=opt.color_dict[label], + alpha=opt.node_alpha, lw=edge_lw, snap=True, ) - if node_edge: - ax.fill_between( + if opt.node_edge: + opt.ax.fill_between( [x, x + dx], y, y + dy, - edgecolor=color_dict[label], + edgecolor=opt.color_dict[label], facecolor="none", lw=edge_lw, snap=True, @@ -506,24 +454,24 @@ def draw_node(x, dx, y, dy, label): x_lr[lr] - x_node_width * (1 - lr), x_node_width, node_pos_bot[lr][label], - node_sizes[ii + lr][label], + opt.node_sizes[ii + lr][label], label, ) # Draw node labels def draw_label(x, y, label, ha): - ax.text( + opt.ax.text( x, y, - label_dict.get(label, label), + opt.label_dict.get(label, label), { "ha": ha, "va": "center", - "fontfamily": fontfamily, - "fontsize": fontsize, - "color": fontcolor, - **label_font, + "fontfamily": opt.fontfamily, + "fontsize": opt.fontsize, + "color": opt.fontcolor, + **opt.label_font, }, ) @@ -531,60 +479,60 @@ def draw_label(x, y, label, ha): # first row of labels lr = 0 - if ii == 0 and label_loc[0] != "none": - if label_loc[0] in ("left"): + if ii == 0 and opt.label_loc[0] != "none": + if opt.label_loc[0] in ("left"): xx = x_lr[lr] - x_label_gap - x_node_width - elif label_loc[0] in ("right"): + elif opt.label_loc[0] in ("right"): xx = x_lr[lr] + x_label_gap - elif label_loc[0] in ("center"): + elif opt.label_loc[0] in ("center"): xx = x_lr[lr] - x_node_width / 2 for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha_dict[label_loc[0]]) + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha_dict[opt.label_loc[0]]) # inside labels, left lr = 1 - if ii < num_flow - 1 and label_loc[1] in ("left", "both"): + if ii < opt.num_flow - 1 and opt.label_loc[1] in ("left", "both"): xx = x_lr[lr] - x_label_gap ha = "right" for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + node_sizes[ii + lr][label] / 2 + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 draw_label(xx, yy, label, ha) # inside labels, center - if ii < num_flow - 1 and label_loc[1] in ("center"): + if ii < opt.num_flow - 1 and opt.label_loc[1] in ("center"): xx = x_lr[lr] + x_node_width / 2 ha = "center" for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + node_sizes[ii + lr][label] / 2 + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 draw_label(xx, yy, label, ha) # inside labels, right - if ii < num_flow - 1 and label_loc[1] in ("right", "both"): + if ii < opt.num_flow - 1 and opt.label_loc[1] in ("right", "both"): xx = x_lr[lr] + x_label_gap + x_node_width ha = "left" for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + node_sizes[ii + lr][label] / 2 + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 draw_label(xx, yy, label, ha) # last row of labels - if ii == num_flow - 1 and label_loc[2] != "none": - if label_loc[2] in ("left"): + if ii == opt.num_flow - 1 and opt.label_loc[2] != "none": + if opt.label_loc[2] in ("left"): xx = x_lr[lr] - x_label_gap - elif label_loc[2] in ("right"): + elif opt.label_loc[2] in ("right"): xx = x_lr[lr] + x_label_gap + x_node_width - elif label_loc[2] in ("center"): + elif opt.label_loc[2] in ("center"): xx = x_lr[lr] + x_node_width / 2 for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha_dict[label_loc[2]]) + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha_dict[opt.label_loc[2]]) # Plot flows - if flow_edge: - edge_lw = flow_lw + if opt.flow_edge: + edge_lw = opt.flow_lw edge_alpha = 1 else: - edge_alpha = flow_alpha + edge_alpha = opt.flow_alpha edge_lw = 0 for lbl_l in nodes_lr[0]: @@ -608,21 +556,21 @@ def draw_label(x, y, label, ha): node_pos_bot[1][lbl_r] += rlen xx = np.linspace(x_lr[0], x_lr[1], len(ys_d)) - cc = combine_colours(color_dict[lbl_l], color_dict[lbl_r], len(ys_d)) + cc = combine_colours(opt.color_dict[lbl_l], opt.color_dict[lbl_r], len(ys_d)) for jj in range(len(ys_d) - 1): - ax.fill_between( + opt.ax.fill_between( xx[[jj, jj + 1]], ys_d[[jj, jj + 1]], ys_u[[jj, jj + 1]], color=cc[:, jj], - alpha=flow_alpha, + alpha=opt.flow_alpha, lw=0, edgecolor="none", snap=True, ) # edges: - ax.plot( + opt.ax.plot( xx[[jj, jj + 1]], ys_d[[jj, jj + 1]], color=cc[:, jj], @@ -630,7 +578,7 @@ def draw_label(x, y, label, ha): lw=edge_lw, snap=True, ) - ax.plot( + opt.ax.plot( xx[[jj, jj + 1]], ys_u[[jj, jj + 1]], color=cc[:, jj], @@ -640,11 +588,11 @@ def draw_label(x, y, label, ha): ) # Place "titles" - if titles is not None: + if opt.titles is not None: last_label = [lbl_l, lbl_r] - y_title_gap = title_gap * plot_height - y_frame_gap = frame_gap * plot_height + y_title_gap = opt.title_gap * opt.plot_height + y_frame_gap = opt.frame_gap * opt.plot_height title_x = [ x_lr[0] - x_node_width / 2, @@ -652,17 +600,17 @@ def draw_label(x, y, label, ha): ] def draw_title(x, y, label, va): - ax.text( + opt.ax.text( x, y, label, { "ha": "center", "va": va, - "fontfamily": fontfamily, - "fontsize": fontsize, - "color": fontcolor, - **title_font, + "fontfamily": opt.fontfamily, + "fontsize": opt.fontsize, + "color": opt.fontcolor, + **opt.title_font, }, ) @@ -670,19 +618,19 @@ def draw_title(x, y, label, va): title_lr = [0, 1] if ii == 0 else [1] for lr in title_lr: - if title_side in ("top", "both"): - if title_loc == "outer": - yt = min(voffset) + y_title_gap + y_frame_gap + plot_height - elif title_loc == "inner": + if opt.title_side in ("top", "both"): + if opt.title_loc == "outer": + yt = min(opt.voffset) + y_title_gap + y_frame_gap + opt.plot_height + elif opt.title_loc == "inner": yt = y_title_gap + node_pos_top[lr][last_label[lr]] - draw_title(title_x[lr], yt, titles[ii + lr], "bottom") - - if title_side in ("bottom", "both"): - if title_loc == "outer": - yt = min(voffset) - y_title_gap - y_frame_gap - elif title_loc == "inner": - yt = voffset[ii + lr] - y_title_gap - draw_title(title_x[lr], yt, titles[ii + lr], "top") + draw_title(title_x[lr], yt, opt.titles[ii + lr], "bottom") + + if opt.title_side in ("bottom", "both"): + if opt.title_loc == "outer": + yt = min(opt.voffset) - y_title_gap - y_frame_gap + elif opt.title_loc == "inner": + yt = opt.voffset[ii + lr] - y_title_gap + draw_title(title_x[lr], yt, opt.titles[ii + lr], "top") ###########################################