From f8b7867d6fbff21776a4ced9517afa10006f5118 Mon Sep 17 00:00:00 2001 From: Will Robertson Date: Sun, 31 Mar 2024 00:00:40 +1030 Subject: [PATCH] add other_dict option I don't love this interface, it might change --- ausankey/ausankey.py | 121 +++++++++++++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 39 deletions(-) diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 3939abf..fcb62a6 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -32,7 +32,7 @@ def sankey(data, **kwargs): # draw each segment for ii in range(sky.num_flow): - sky.subplot(ii, data) + sky.subplot(ii) class SankeyError(Exception): @@ -124,6 +124,7 @@ class Sankey: label_loc : [str1, str2, str3] Position to place labels next to the nodes. + * `str1`: position of first labels (`"left"`, `"right"`, `"center"`, or `"none"`) * `str2`: position of middle labels (`"left"`, `"right"`, `"both"`, `"center"`, or `"none"`) * `str3`: position of last labels (`"left"`, `"right"`, `"center"`, or `"none"`) @@ -133,6 +134,19 @@ class Sankey: appear in the previous stage. This minimises chart clutter but might be confusing in cases, hence defaulting to True. + other_dict : dict + Sets thresholds to recategorise nodes that are below a certain value. + Up to three dictionary keys can be set: + + * `"val": v` — set node to other if it is less than `v` + * `"sum": s` — set node to other if it is less than `s` fraction of the summed total of all nodes in the current stage + * `"max": m` — set node to other if is is less than `m` fraction of the maximum node in the current stage + + If any of these criteria are met the reclassification will occur. + + other_name : str + The string used to rename nodes to if they are classified as “other”. + sort : int Sorting routine to use for the data. * `"top"`: data is sorted with largest entries on top @@ -186,6 +200,8 @@ def __init__( flow_lw=1, node_lw=1, frame_lw=1, + other_dict=None, + other_name="Other", titles=None, title_gap=0.05, title_side="top", # "bottom", "both" @@ -218,6 +234,8 @@ def __init__( self.flow_lw = flow_lw self.node_lw = node_lw self.frame_lw = frame_lw + self.other_name = other_name + self.other_dict = other_dict or {} self.titles = titles self.title_font = title_font or {"fontweight": "bold"} self.title_gap = title_gap @@ -226,47 +244,71 @@ def __init__( self.sort = sort self.valign = valign - def setup(self, data): - num_col = len(data.columns) - data.columns = range(num_col) # force numeric column headings - self.num_stages = int(num_col / 2) # number of stages - self.num_flow = self.num_stages - 1 - # sizes - weight_sum = np.empty(self.num_stages) - num_uniq = np.empty(self.num_stages) - col_hgt = np.empty(self.num_stages) - self.node_sizes = {} - nodes_uniq = {} + def weight_labels(self): + + self.weight_sum = np.empty(self.num_stages) + for ii in range(self.num_stages): - nodes_uniq[ii] = pd.Series(data[2 * ii]).unique() - num_uniq[ii] = len(nodes_uniq[ii]) + self.nodes_uniq[ii] = pd.Series(self.data[2 * ii]).unique() for ii in range(self.num_stages): self.node_sizes[ii] = {} - for lbl in nodes_uniq[ii]: + for lbl in self.nodes_uniq[ii]: if ii == 0: - ind_prev = data[2 * ii - 0] == lbl - ind_this = data[2 * ii + 0] == lbl - ind_next = data[2 * ii + 2] == lbl + ind_prev = self.data[2 * ii - 0] == lbl + ind_this = self.data[2 * ii + 0] == lbl + ind_next = self.data[2 * ii + 2] == lbl elif ii == self.num_flow: - ind_prev = data[2 * ii - 2] == lbl - ind_this = data[2 * ii + 0] == lbl - ind_next = data[2 * ii + 0] == lbl + ind_prev = self.data[2 * ii - 2] == lbl + ind_this = self.data[2 * ii + 0] == lbl + ind_next = self.data[2 * ii + 0] == lbl else: - ind_prev = data[2 * ii - 2] == lbl - ind_this = data[2 * ii + 0] == lbl - ind_next = data[2 * ii + 2] == lbl - weight_cont = data[2 * ii + 1][ind_this & ind_prev & ind_next].sum() - weight_only = data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum() - weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum() - weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum() + ind_prev = self.data[2 * ii - 2] == lbl + ind_this = self.data[2 * ii + 0] == lbl + ind_next = self.data[2 * ii + 2] == lbl + weight_cont = self.data[2 * ii + 1][ind_this & ind_prev & ind_next].sum() + weight_only = self.data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum() + weight_stop = self.data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum() + weight_strt = self.data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum() self.node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt) - self.node_sizes[ii] = sort_dict(self.node_sizes[ii], self.sort) - weight_sum[ii] = pd.Series(self.node_sizes[ii].values()).sum() + self.weight_sum[ii] = pd.Series(self.node_sizes[ii].values()).sum() + + + def setup(self, data): + + self.data = data + + num_col = len(self.data.columns) + self.data.columns = range(num_col) # force numeric column headings + self.num_stages = int(num_col / 2) # number of stages + self.num_flow = self.num_stages - 1 + + # sizes + col_hgt = np.empty(self.num_stages) + self.node_sizes = {} + self.nodes_uniq = {} + + self.weight_labels() + + # reclassify + thresh_val = self.other_dict.get("val",0) + thresh_max = self.other_dict.get("max",0) + thresh_sum = self.other_dict.get("sum",0) for ii in range(self.num_stages): - col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * self.node_gap * max(weight_sum) + for nn, lbl in enumerate(self.data[2 * ii]): + val = self.node_sizes[ii][lbl] + if lbl is None: + continue + if val < thresh_val or val < thresh_sum * self.weight_sum[ii] or val < thresh_max * max(self.data[2 * ii + 1]): + self.data.iat[nn,2 * ii] = self.other_name + self.weight_labels() + + # sort and calc + for ii in range(self.num_stages): + self.node_sizes[ii] = sort_dict(self.node_sizes[ii], self.sort) + col_hgt[ii] = self.weight_sum[ii] + (len(self.nodes_uniq[ii]) - 1) * self.node_gap * max(self.weight_sum) # overall dimensions self.plot_height = max(col_hgt) @@ -285,7 +327,7 @@ def setup(self, data): self.voffset[ii] = self.vscale * (col_hgt[1] - col_hgt[ii]) # labels - label_record = data[range(0, 2 * self.num_stages, 2)].to_records(index=False) + label_record = self.data[range(0, 2 * self.num_stages, 2)].to_records(index=False) flattened = [item for sublist in label_record for item in sublist] self.all_labels = pd.Series(flattened).unique() @@ -302,6 +344,7 @@ def setup(self, data): self.ax = self.ax or plt.gca() self.ax.axis("off") + def plot_frame(self): """Plot frame on top/bottom edges""" @@ -328,7 +371,7 @@ def plot_frame(self): lw=self.frame_lw, ) - def subplot(self, ii, data): + def subplot(self, ii): """Subroutine for plotting horizontal sections of the Sankey plot Some special-casing is used for plotting/labelling differently @@ -340,14 +383,14 @@ def subplot(self, ii, data): lastind = 4 if ii < self.num_flow - 1 else 2 labels_lr = [ - data[labelind], - data[labelind + 2], - data[labelind + lastind], + self.data[labelind], + self.data[labelind + 2], + self.data[labelind + lastind], ] weights_lr = [ - data[weightind], - data[weightind + 2], - data[weightind + lastind], + self.data[weightind], + self.data[weightind + 2], + self.data[weightind + lastind], ] nodes_lr = [