diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 3f057ad..e727887 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -184,7 +184,7 @@ def sankey( weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum() node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt) node_sizes[ii] = sort_dict(node_sizes[ii], sort) - + for ii in range(num_side): if ii == 0: ind_prev = data[2 * ii + 1].notnull() @@ -334,18 +334,18 @@ def _sankey( Some special-casing is used for plotting/labelling differently for the first and last cases. """ - + if flow_edge: edge_alpha = 1 edge_lw = 1 else: edge_alpha = alpha edge_lw = 0 - + labelind = 2 * ii weightind = 2 * ii + 1 - if ii < num_flow-1: + if ii < num_flow - 1: labels_lr = [ pd.Series(data[labelind]), pd.Series(data[labelind + 2]), @@ -386,8 +386,8 @@ def _sankey( ] else: bar_lr = [ - sort_nodes(labels_lr[0],node_sizes[ii]), - sort_nodes(labels_lr[1],node_sizes[ii+1]), + sort_nodes(labels_lr[0], node_sizes[ii]), + sort_nodes(labels_lr[1], node_sizes[ii + 1]), ] # check labels @@ -423,7 +423,7 @@ def _sankey( vscale = 0 barpos = [{}, {}] - node_voffset = [{},{}] + node_voffset = [{}, {}] for lr in [0, 1]: for i, label in enumerate(bar_lr[lr]): barpos[lr][label] = {} @@ -509,7 +509,6 @@ def draw_bar(x, dx, y, dy, label): # Plot flows for lbl_l in bar_lr[0]: for lbl_r in bar_lr[1]: - lind = labels_lr[0] == lbl_l rind = labels_lr[1] == lbl_r if not any(lind & rind): @@ -527,10 +526,10 @@ def draw_bar(x, dx, y, dy, label): # so next strip starts at the right place barpos[0][lbl_l]["bot"] += lbar barpos[1][lbl_r]["bot"] += rbar - + xx = np.linspace(x_left, x_right, len(ys_d)) cc = combine_colours(color_dict[lbl_l], color_dict[lbl_r], len(ys_d)) - + for jj in range(len(ys_d) - 1): ax.fill_between( xx[[jj, jj + 1]], @@ -602,8 +601,10 @@ def draw_bar(x, dx, y, dy, label): fontsize=fontsize, ) + ########################################### + def sort_nodes(lbl, node_sizes): """creates a sorted list of labels by their summed weights""" @@ -618,9 +619,11 @@ def sort_nodes(lbl, node_sizes): ) return list(dict(sort_arr).keys()) - + + ########################################### + def sort_dict(lbl, sorting): """creates a sorted list of labels by their summed weights""" @@ -633,7 +636,6 @@ def sort_dict(lbl, sorting): else: s = 0 - sort_arr = sorted( lbl.items(), key=lambda item: s * item[1],