From 92efe900b2d8a0e5d6305dd4b2a4b990f4030b9d Mon Sep 17 00:00:00 2001 From: Will Robertson Date: Fri, 22 Mar 2024 21:36:31 +1030 Subject: [PATCH] precalculate node sizes --- ausankey/ausankey.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 228ee8b..c891c00 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -163,20 +163,28 @@ def sankey( nodes_uniq[ii] = pd.Series(data[2 * ii]).unique() num_uniq[ii] = len(nodes_uniq[ii]) - # for ii in range(num_side): - # node_sizes[ii] = {} - # for lbl in nodes_uniq[ii]: - # if ii == 0: - # ind_prev = data[2 * ii + 0] == lbl - # ind_this = data[2 * ii + 0] == lbl - # ind_next = data[2 * ii + 2] == lbl - # weight_cont = data[2 * ii + 1][ind_this & ind_prev & ind_next].sum() - # weight_only = data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum() - # weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum() - # weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum() - # node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt) - - # print(node_sizes) + for ii in range(num_side): + node_sizes[ii] = {} + for lbl in nodes_uniq[ii]: + if ii == 0: + ind_prev = data[2 * ii + 0] == lbl + ind_this = data[2 * ii + 0] == lbl + ind_next = data[2 * ii + 2] == lbl + elif ii == num_flow: + ind_prev = data[2 * ii - 2] == lbl + ind_this = data[2 * ii + 0] == lbl + ind_next = data[2 * ii + 0] == lbl + else: + ind_prev = data[2 * ii - 2] == lbl + ind_this = data[2 * ii + 0] == lbl + ind_next = data[2 * ii + 2] == lbl + weight_cont = data[2 * ii + 1][ind_this & ind_prev & ind_next].sum() + weight_only = data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum() + weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum() + weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum() + node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt) + node_sizes[ii] = sort_dict(node_sizes[ii], sort) + for ii in range(num_side): if ii == 0: ind_prev = data[2 * ii + 1].notnull() @@ -265,6 +273,7 @@ def sankey( ii, num_flow, data, + node_sizes=node_sizes, titles=titles, title_gap=title_gap, title_side=title_side, @@ -303,6 +312,7 @@ def _sankey( flow_edge=None, fontsize=None, frame_gap=None, + node_sizes=None, titles=None, title_gap=None, title_side=None, @@ -595,12 +605,10 @@ def weighted_sort(lbl, wgt, sorting): else: s = 0 - arr = {} for uniq in lbl.unique(): arr[uniq] = wgt[lbl == uniq].sum() sort_arr = sorted( - arr.items(), key=lambda item: s * item[1], # sorting = 0,1,-1 affects this )