From 161b50a477e08fe52f667b11ecc07c20aaff4c93 Mon Sep 17 00:00:00 2001 From: Will Robertson Date: Tue, 16 Apr 2024 20:08:01 +0930 Subject: [PATCH] tighten up code and bug fix! --- ausankey/ausankey.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 7f7a9fe..ee9e074 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -358,7 +358,6 @@ def setup(self, data): self.num_flow = self.num_stages - 1 # sizes - col_hgt = np.empty(self.num_stages) self.node_sizes = {} self.node_list = {} self.nodes_uniq = {} @@ -367,32 +366,38 @@ def setup(self, data): self.node_pos_bot = {} self.node_pos_top = {} + for ii in range(self.num_stages): + uni = pd.Series(self.data[2 * ii]).unique() + self.nodes_uniq[ii] = pd.Series(uni).dropna() + # weight and reclassify self.weight_labels() for ii in range(self.num_stages): for nn, lbl in enumerate(self.data[2 * ii]): - val = self.node_sizes[ii][lbl] - if lbl is not None and ( - val < self.other_thresh_val - or val < self.other_thresh_sum * self.weight_sum[ii] - or val < self.other_thresh_max * max(self.data[2 * ii + 1]) - ): - self.data.iat[nn, 2 * ii] = self.other_name + if lbl is not None: + val = self.node_sizes[ii][lbl] + if ( + val < self.other_thresh_val + or val < self.other_thresh_sum * self.weight_sum[ii] + or val < self.other_thresh_max * max(self.data[2 * ii + 1]) + ): + self.data.iat[nn, 2 * ii] = self.other_name self.weight_labels() # sort and calc self.plot_height_nom = max(self.weight_sum) for ii in range(self.num_stages): self.node_sizes[ii] = self.sort_node_sizes(self.node_sizes[ii], self.sort) - col_hgt[ii] = self.weight_sum[ii] + (len(self.nodes_uniq[ii]) - 1) * self.node_gap * self.plot_height_nom self.node_list[ii] = self.sort_nodes(self.data[2 * ii], self.node_sizes[ii]) # offsets for alignment vscale_dict = {"top": 1, "center": 0.5, "bottom": 0} self.vscale = vscale_dict.get(self.valign, 0) self.voffset = np.empty(self.num_stages) + col_hgt = np.empty(self.num_stages) for ii in range(self.num_stages): - self.voffset[ii] = self.vscale * (col_hgt[1] - col_hgt[ii]) + col_hgt[ii] = self.weight_sum[ii] + (len(self.nodes_uniq[ii]) - 1) * self.node_gap * self.plot_height_nom + self.voffset[ii] = self.vscale * (col_hgt[0] - col_hgt[ii]) # overall dimensions self.plot_height = max(col_hgt) @@ -482,9 +487,6 @@ def weight_labels(self): """Calculates sizes of each node, taking into account discontinuities""" self.weight_sum = np.empty(self.num_stages) - for ii in range(self.num_stages): - self.nodes_uniq[ii] = pd.Series(self.data[2 * ii]).unique() - for ii in range(self.num_stages): self.node_sizes[ii] = {} for lbl in self.nodes_uniq[ii]: