From 636ff9bf0b17c99dff3ad1179315a5d6537b67d1 Mon Sep 17 00:00:00 2001 From: Will Robertson Date: Fri, 29 Mar 2024 09:38:31 +1030 Subject: [PATCH] another step with class methods --- ausankey/ausankey.py | 580 +++++++++++++++++++++---------------------- 1 file changed, 288 insertions(+), 292 deletions(-) diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 779218d..939f16a 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -12,6 +12,28 @@ import pandas as pd +def sankey(data,**kwargs): + """Make Sankey Diagram + + Parameters + ---------- + **kwargs : function arguments + See the Sankey class for complete list of arguments. + + Returns + ------- + + None (yet) + """ + + opt = Sankey(**kwargs) + opt = opt.setup(data,opt) + + # draw each segment + for ii in range(opt.num_flow): + opt._sankey(ii, data, opt) + + class SankeyError(Exception): pass @@ -320,320 +342,294 @@ def setup(self, data, opt): return opt -def sankey(data,**kwargs): - """Make Sankey Diagram - - Parameters - ---------- - **kwargs : function arguments - See the Sankey class for complete list of arguments. - - Returns - ------- - - None (yet) - """ - - opt = Sankey(**kwargs) - opt = opt.setup(data,opt) - - # draw each segment - for ii in range(opt.num_flow): - _sankey(ii, data, opt) - - - -########################################### - - -def _sankey(ii, data, opt): - """Subroutine for plotting horizontal sections of the Sankey plot - - Some special-casing is used for plotting/labelling differently - for the first and last cases. - """ - - labelind = 2 * ii - weightind = 2 * ii + 1 - - if ii < opt.num_flow - 1: - labels_lr = [ - pd.Series(data[labelind]), - pd.Series(data[labelind + 2]), - pd.Series(data[labelind + 4]), - ] - weights_lr = [ - pd.Series(data[weightind]), - pd.Series(data[weightind + 2]), - pd.Series(data[weightind + 4]), - ] - else: - labels_lr = [ - pd.Series(data[labelind]), - pd.Series(data[labelind + 2]), - pd.Series(data[labelind + 2]), - ] - weights_lr = [ - pd.Series(data[weightind]), - pd.Series(data[weightind + 2]), - pd.Series(data[weightind + 2]), + def _sankey(self,ii, data, opt): + """Subroutine for plotting horizontal sections of the Sankey plot + + Some special-casing is used for plotting/labelling differently + for the first and last cases. + """ + + labelind = 2 * ii + weightind = 2 * ii + 1 + + if ii < opt.num_flow - 1: + labels_lr = [ + pd.Series(data[labelind]), + pd.Series(data[labelind + 2]), + pd.Series(data[labelind + 4]), + ] + weights_lr = [ + pd.Series(data[weightind]), + pd.Series(data[weightind + 2]), + pd.Series(data[weightind + 4]), + ] + else: + labels_lr = [ + pd.Series(data[labelind]), + pd.Series(data[labelind + 2]), + pd.Series(data[labelind + 2]), + ] + weights_lr = [ + pd.Series(data[weightind]), + pd.Series(data[weightind + 2]), + pd.Series(data[weightind + 2]), + ] + + nodes_lr = [ + sort_nodes(labels_lr[0], opt.node_sizes[ii]), + sort_nodes(labels_lr[1], opt.node_sizes[ii + 1]), ] - - nodes_lr = [ - sort_nodes(labels_lr[0], opt.node_sizes[ii]), - sort_nodes(labels_lr[1], opt.node_sizes[ii + 1]), - ] - - # check colours - check_colors_match_labels(labels_lr, opt.color_dict) - - # Determine sizes of individual subflows - nodesize = [{}, {}] - for lbl_l in nodes_lr[0]: - nodesize[0][lbl_l] = {} - nodesize[1][lbl_l] = {} - for lbl_r in nodes_lr[1]: - ind = (labels_lr[0] == lbl_l) & (labels_lr[1] == lbl_r) - nodesize[0][lbl_l][lbl_r] = weights_lr[0][ind].sum() - nodesize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum() - - # Determine vertical positions of nodes - y_node_gap = opt.node_gap * opt.plot_height - - if opt.valign == "top": - vscale = 1 - elif opt.valign == "center": - vscale = 0.5 - else: # bottom, or undefined - vscale = 0 - - node_voffset = [{}, {}] - node_pos_bot = [{}, {}] - node_pos_top = [{}, {}] - - for lr in [0, 1]: - for i, label in enumerate(nodes_lr[lr]): - node_height = opt.node_sizes[ii + lr][label] - this_side_height = weights_lr[lr][labels_lr[lr] == label].sum() - node_voffset[lr][label] = vscale * (node_height - this_side_height) - next_bot = node_pos_top[lr][nodes_lr[lr][i - 1]] + y_node_gap if i > 0 else 0 - node_pos_bot[lr][label] = opt.voffset[ii + lr] if i == 0 else next_bot - node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height - - # horizontal positions of nodes - x_node_width = opt.node_width * opt.sub_width - x_label_width = opt.label_width * opt.sub_width - x_label_gap = opt.label_gap * opt.sub_width - x_left = x_node_width + x_label_gap + x_label_width + ii * (opt.sub_width + x_node_width) - x_lr = [x_left, x_left + opt.sub_width] - - # Draw nodes - - def draw_node(x, dx, y, dy, label): - edge_lw = opt.node_lw if opt.node_edge else 0 - opt.ax.fill_between( - [x, x + dx], - y, - y + dy, - facecolor=opt.color_dict[label], - alpha=opt.node_alpha, - lw=edge_lw, - snap=True, - ) - if opt.node_edge: + + # check colours + check_colors_match_labels(labels_lr, opt.color_dict) + + # Determine sizes of individual subflows + nodesize = [{}, {}] + for lbl_l in nodes_lr[0]: + nodesize[0][lbl_l] = {} + nodesize[1][lbl_l] = {} + for lbl_r in nodes_lr[1]: + ind = (labels_lr[0] == lbl_l) & (labels_lr[1] == lbl_r) + nodesize[0][lbl_l][lbl_r] = weights_lr[0][ind].sum() + nodesize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum() + + # Determine vertical positions of nodes + y_node_gap = opt.node_gap * opt.plot_height + + if opt.valign == "top": + vscale = 1 + elif opt.valign == "center": + vscale = 0.5 + else: # bottom, or undefined + vscale = 0 + + node_voffset = [{}, {}] + node_pos_bot = [{}, {}] + node_pos_top = [{}, {}] + + for lr in [0, 1]: + for i, label in enumerate(nodes_lr[lr]): + node_height = opt.node_sizes[ii + lr][label] + this_side_height = weights_lr[lr][labels_lr[lr] == label].sum() + node_voffset[lr][label] = vscale * (node_height - this_side_height) + next_bot = node_pos_top[lr][nodes_lr[lr][i - 1]] + y_node_gap if i > 0 else 0 + node_pos_bot[lr][label] = opt.voffset[ii + lr] if i == 0 else next_bot + node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height + + # horizontal positions of nodes + x_node_width = opt.node_width * opt.sub_width + x_label_width = opt.label_width * opt.sub_width + x_label_gap = opt.label_gap * opt.sub_width + x_left = x_node_width + x_label_gap + x_label_width + ii * (opt.sub_width + x_node_width) + x_lr = [x_left, x_left + opt.sub_width] + + # Draw nodes + + def draw_node(x, dx, y, dy, label): + edge_lw = opt.node_lw if opt.node_edge else 0 opt.ax.fill_between( [x, x + dx], y, y + dy, - edgecolor=opt.color_dict[label], - facecolor="none", + facecolor=opt.color_dict[label], + alpha=opt.node_alpha, lw=edge_lw, snap=True, ) - - for lr in [0, 1] if ii == 0 else [1]: - for label in nodes_lr[lr]: - draw_node( - x_lr[lr] - x_node_width * (1 - lr), - x_node_width, - node_pos_bot[lr][label], - opt.node_sizes[ii + lr][label], - label, - ) - - # Draw node labels - - def draw_label(x, y, label, ha): - opt.ax.text( - x, - y, - opt.label_dict.get(label, label), - { - "ha": ha, - "va": "center", - "fontfamily": opt.fontfamily, - "fontsize": opt.fontsize, - "color": opt.fontcolor, - **opt.label_font, - }, - ) - - ha_dict = {"left": "right", "right": "left", "center": "center"} - - # first row of labels - lr = 0 - if ii == 0 and opt.label_loc[0] != "none": - if opt.label_loc[0] in ("left"): - xx = x_lr[lr] - x_label_gap - x_node_width - elif opt.label_loc[0] in ("right"): - xx = x_lr[lr] + x_label_gap - elif opt.label_loc[0] in ("center"): - xx = x_lr[lr] - x_node_width / 2 - for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha_dict[opt.label_loc[0]]) - - # inside labels, left - lr = 1 - if ii < opt.num_flow - 1 and opt.label_loc[1] in ("left", "both"): - xx = x_lr[lr] - x_label_gap - ha = "right" - for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha) - - # inside labels, center - if ii < opt.num_flow - 1 and opt.label_loc[1] in ("center"): - xx = x_lr[lr] + x_node_width / 2 - ha = "center" - for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha) - - # inside labels, right - if ii < opt.num_flow - 1 and opt.label_loc[1] in ("right", "both"): - xx = x_lr[lr] + x_label_gap + x_node_width - ha = "left" - for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha) - - # last row of labels - if ii == opt.num_flow - 1 and opt.label_loc[2] != "none": - if opt.label_loc[2] in ("left"): - xx = x_lr[lr] - x_label_gap - elif opt.label_loc[2] in ("right"): - xx = x_lr[lr] + x_label_gap + x_node_width - elif opt.label_loc[2] in ("center"): - xx = x_lr[lr] + x_node_width / 2 - for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha_dict[opt.label_loc[2]]) - - # Plot flows - if opt.flow_edge: - edge_lw = opt.flow_lw - edge_alpha = 1 - else: - edge_alpha = opt.flow_alpha - edge_lw = 0 - - for lbl_l in nodes_lr[0]: - for lbl_r in nodes_lr[1]: - lind = labels_lr[0] == lbl_l - rind = labels_lr[1] == lbl_r - if not any(lind & rind): - continue - - lbot = node_voffset[0][lbl_l] + node_pos_bot[0][lbl_l] - rbot = node_voffset[1][lbl_r] + node_pos_bot[1][lbl_r] - llen = nodesize[0][lbl_l][lbl_r] - rlen = nodesize[1][lbl_l][lbl_r] - - ys_d = create_curve(lbot, rbot) - ys_u = create_curve(lbot + llen, rbot + rlen) - - # Update bottom edges at each label - # so next strip starts at the right place - node_pos_bot[0][lbl_l] += llen - node_pos_bot[1][lbl_r] += rlen - - xx = np.linspace(x_lr[0], x_lr[1], len(ys_d)) - cc = combine_colours(opt.color_dict[lbl_l], opt.color_dict[lbl_r], len(ys_d)) - - for jj in range(len(ys_d) - 1): + if opt.node_edge: opt.ax.fill_between( - xx[[jj, jj + 1]], - ys_d[[jj, jj + 1]], - ys_u[[jj, jj + 1]], - color=cc[:, jj], - alpha=opt.flow_alpha, - lw=0, - edgecolor="none", - snap=True, - ) - # edges: - opt.ax.plot( - xx[[jj, jj + 1]], - ys_d[[jj, jj + 1]], - color=cc[:, jj], - alpha=edge_alpha, + [x, x + dx], + y, + y + dy, + edgecolor=opt.color_dict[label], + facecolor="none", lw=edge_lw, snap=True, ) - opt.ax.plot( - xx[[jj, jj + 1]], - ys_u[[jj, jj + 1]], - color=cc[:, jj], - alpha=edge_alpha, - lw=edge_lw, - snap=True, + + for lr in [0, 1] if ii == 0 else [1]: + for label in nodes_lr[lr]: + draw_node( + x_lr[lr] - x_node_width * (1 - lr), + x_node_width, + node_pos_bot[lr][label], + opt.node_sizes[ii + lr][label], + label, ) - - # Place "titles" - if opt.titles is not None: - last_label = [lbl_l, lbl_r] - - y_title_gap = opt.title_gap * opt.plot_height - y_frame_gap = opt.frame_gap * opt.plot_height - - title_x = [ - x_lr[0] - x_node_width / 2, - x_lr[1] + x_node_width / 2, - ] - - def draw_title(x, y, label, va): + + # Draw node labels + + def draw_label(x, y, label, ha): opt.ax.text( x, y, - label, + opt.label_dict.get(label, label), { - "ha": "center", - "va": va, + "ha": ha, + "va": "center", "fontfamily": opt.fontfamily, "fontsize": opt.fontsize, "color": opt.fontcolor, - **opt.title_font, + **opt.label_font, }, ) - - # leftmost title - title_lr = [0, 1] if ii == 0 else [1] - - for lr in title_lr: - if opt.title_side in ("top", "both"): - if opt.title_loc == "outer": - yt = min(opt.voffset) + y_title_gap + y_frame_gap + opt.plot_height - elif opt.title_loc == "inner": - yt = y_title_gap + node_pos_top[lr][last_label[lr]] - draw_title(title_x[lr], yt, opt.titles[ii + lr], "bottom") - - if opt.title_side in ("bottom", "both"): - if opt.title_loc == "outer": - yt = min(opt.voffset) - y_title_gap - y_frame_gap - elif opt.title_loc == "inner": - yt = opt.voffset[ii + lr] - y_title_gap - draw_title(title_x[lr], yt, opt.titles[ii + lr], "top") + + ha_dict = {"left": "right", "right": "left", "center": "center"} + + # first row of labels + lr = 0 + if ii == 0 and opt.label_loc[0] != "none": + if opt.label_loc[0] in ("left"): + xx = x_lr[lr] - x_label_gap - x_node_width + elif opt.label_loc[0] in ("right"): + xx = x_lr[lr] + x_label_gap + elif opt.label_loc[0] in ("center"): + xx = x_lr[lr] - x_node_width / 2 + for label in nodes_lr[lr]: + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha_dict[opt.label_loc[0]]) + + # inside labels, left + lr = 1 + if ii < opt.num_flow - 1 and opt.label_loc[1] in ("left", "both"): + xx = x_lr[lr] - x_label_gap + ha = "right" + for label in nodes_lr[lr]: + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha) + + # inside labels, center + if ii < opt.num_flow - 1 and opt.label_loc[1] in ("center"): + xx = x_lr[lr] + x_node_width / 2 + ha = "center" + for label in nodes_lr[lr]: + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha) + + # inside labels, right + if ii < opt.num_flow - 1 and opt.label_loc[1] in ("right", "both"): + xx = x_lr[lr] + x_label_gap + x_node_width + ha = "left" + for label in nodes_lr[lr]: + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha) + + # last row of labels + if ii == opt.num_flow - 1 and opt.label_loc[2] != "none": + if opt.label_loc[2] in ("left"): + xx = x_lr[lr] - x_label_gap + elif opt.label_loc[2] in ("right"): + xx = x_lr[lr] + x_label_gap + x_node_width + elif opt.label_loc[2] in ("center"): + xx = x_lr[lr] + x_node_width / 2 + for label in nodes_lr[lr]: + yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha_dict[opt.label_loc[2]]) + + # Plot flows + if opt.flow_edge: + edge_lw = opt.flow_lw + edge_alpha = 1 + else: + edge_alpha = opt.flow_alpha + edge_lw = 0 + + for lbl_l in nodes_lr[0]: + for lbl_r in nodes_lr[1]: + lind = labels_lr[0] == lbl_l + rind = labels_lr[1] == lbl_r + if not any(lind & rind): + continue + + lbot = node_voffset[0][lbl_l] + node_pos_bot[0][lbl_l] + rbot = node_voffset[1][lbl_r] + node_pos_bot[1][lbl_r] + llen = nodesize[0][lbl_l][lbl_r] + rlen = nodesize[1][lbl_l][lbl_r] + + ys_d = create_curve(lbot, rbot) + ys_u = create_curve(lbot + llen, rbot + rlen) + + # Update bottom edges at each label + # so next strip starts at the right place + node_pos_bot[0][lbl_l] += llen + node_pos_bot[1][lbl_r] += rlen + + xx = np.linspace(x_lr[0], x_lr[1], len(ys_d)) + cc = combine_colours(opt.color_dict[lbl_l], opt.color_dict[lbl_r], len(ys_d)) + + for jj in range(len(ys_d) - 1): + opt.ax.fill_between( + xx[[jj, jj + 1]], + ys_d[[jj, jj + 1]], + ys_u[[jj, jj + 1]], + color=cc[:, jj], + alpha=opt.flow_alpha, + lw=0, + edgecolor="none", + snap=True, + ) + # edges: + opt.ax.plot( + xx[[jj, jj + 1]], + ys_d[[jj, jj + 1]], + color=cc[:, jj], + alpha=edge_alpha, + lw=edge_lw, + snap=True, + ) + opt.ax.plot( + xx[[jj, jj + 1]], + ys_u[[jj, jj + 1]], + color=cc[:, jj], + alpha=edge_alpha, + lw=edge_lw, + snap=True, + ) + + # Place "titles" + if opt.titles is not None: + last_label = [lbl_l, lbl_r] + + y_title_gap = opt.title_gap * opt.plot_height + y_frame_gap = opt.frame_gap * opt.plot_height + + title_x = [ + x_lr[0] - x_node_width / 2, + x_lr[1] + x_node_width / 2, + ] + + def draw_title(x, y, label, va): + opt.ax.text( + x, + y, + label, + { + "ha": "center", + "va": va, + "fontfamily": opt.fontfamily, + "fontsize": opt.fontsize, + "color": opt.fontcolor, + **opt.title_font, + }, + ) + + # leftmost title + title_lr = [0, 1] if ii == 0 else [1] + + for lr in title_lr: + if opt.title_side in ("top", "both"): + if opt.title_loc == "outer": + yt = min(opt.voffset) + y_title_gap + y_frame_gap + opt.plot_height + elif opt.title_loc == "inner": + yt = y_title_gap + node_pos_top[lr][last_label[lr]] + draw_title(title_x[lr], yt, opt.titles[ii + lr], "bottom") + + if opt.title_side in ("bottom", "both"): + if opt.title_loc == "outer": + yt = min(opt.voffset) - y_title_gap - y_frame_gap + elif opt.title_loc == "inner": + yt = opt.voffset[ii + lr] - y_title_gap + draw_title(title_x[lr], yt, opt.titles[ii + lr], "top") ###########################################