diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 37dbfaa..0e295be 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -353,6 +353,7 @@ def setup(self, data): # sizes col_hgt = np.empty(self.num_stages) self.node_sizes = {} + self.node_list = {} self.nodes_uniq = {} # weight and reclassify @@ -373,6 +374,7 @@ def setup(self, data): for ii in range(self.num_stages): self.node_sizes[ii] = self.sort_node_sizes(self.node_sizes[ii], self.sort) col_hgt[ii] = self.weight_sum[ii] + (len(self.nodes_uniq[ii]) - 1) * self.node_gap * self.plot_height_nom + self.node_list[ii] = self.sort_nodes(self.data[2 * ii], self.node_sizes[ii]) # overall dimensions self.plot_height = max(col_hgt) @@ -487,12 +489,6 @@ def subplot(self, ii): weights_lr = [ self.data[2 * ii + 1], self.data[2 * ii + 1 + 2], - self.data[2 * ii + 1 + lastind], - ] - - nodes_lr = [ - self.sort_nodes(labels_lr[0], self.node_sizes[ii]), - self.sort_nodes(labels_lr[1], self.node_sizes[ii + 1]), ] # vertical positions @@ -515,27 +511,27 @@ def subplot(self, ii): node_pos_top = [{}, {}] nodesize = [{}, {}] - for lbl_l in nodes_lr[0]: + for lbl_l in self.node_list[ii]: nodesize[0][lbl_l] = {} nodesize[1][lbl_l] = {} - for lbl_r in nodes_lr[1]: + for lbl_r in self.node_list[ii + 1]: ind = (labels_lr[0] == lbl_l) & (labels_lr[1] == lbl_r) - nodesize[0][lbl_l][lbl_r] = weights_lr[0][ind].sum() - nodesize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum() + nodesize[0][lbl_l][lbl_r] = self.data[2 * ii + 1][ind].sum() + nodesize[1][lbl_l][lbl_r] = self.data[2 * ii + 3][ind].sum() for lr in [0, 1]: - for i, label in enumerate(nodes_lr[lr]): + for i, label in enumerate(self.node_list[ii + lr]): node_height = self.node_sizes[ii + lr][label] this_side_height = weights_lr[lr][labels_lr[lr] == label].sum() node_voffset[lr][label] = self.vscale * (node_height - this_side_height) - next_bot = node_pos_top[lr][nodes_lr[lr][i - 1]] + y_node_gap if i > 0 else 0 + next_bot = node_pos_top[lr][self.node_list[ii + lr][i - 1]] + y_node_gap if i > 0 else 0 node_pos_bot[lr][label] = self.voffset[ii + lr] if i == 0 else next_bot node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height # Draw nodes for lr in [0, 1] if ii == 0 else [1]: - for label in nodes_lr[lr]: + for label in self.node_list[ii + lr]: self.draw_node( x_lr[lr] - x_node_width * (1 - lr), x_node_width, @@ -557,7 +553,7 @@ def subplot(self, ii): xx = x_lr[lr] + x_label_gap elif self.label_loc[0] in ("center"): xx = x_lr[lr] - x_node_width / 2 - for label in nodes_lr[lr]: + for label in self.node_list[ii + lr]: yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 self.draw_label(xx, yy, label, ha_dict[self.label_loc[0]]) @@ -566,8 +562,8 @@ def subplot(self, ii): if ii < self.num_flow - 1 and self.label_loc[1] in ("left", "both"): xx = x_lr[lr] - x_label_gap ha = "right" - for label in nodes_lr[lr]: - if (label not in nodes_lr[lr - 1]) or self.label_duplicate: + for label in self.node_list[ii + lr]: + if (label not in self.node_list[ii]) or self.label_duplicate: yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 self.draw_label(xx, yy, label, ha) @@ -575,8 +571,8 @@ def subplot(self, ii): if ii < self.num_flow - 1 and self.label_loc[1] in ("center"): xx = x_lr[lr] + x_node_width / 2 ha = "center" - for label in nodes_lr[lr]: - if (label not in nodes_lr[lr - 1]) or self.label_duplicate: + for label in self.node_list[ii + lr]: + if (label not in self.node_list[ii]) or self.label_duplicate: yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 self.draw_label(xx, yy, label, ha) @@ -584,8 +580,8 @@ def subplot(self, ii): if ii < self.num_flow - 1 and self.label_loc[1] in ("right", "both"): xx = x_lr[lr] + x_label_gap + x_node_width ha = "left" - for label in nodes_lr[lr]: - if (label not in nodes_lr[lr - 1]) or self.label_duplicate: + for label in self.node_list[ii + lr]: + if (label not in self.node_list[ii]) or self.label_duplicate: yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 self.draw_label(xx, yy, label, ha) @@ -597,14 +593,14 @@ def subplot(self, ii): xx = x_lr[lr] + x_label_gap + x_node_width elif self.label_loc[2] in ("center"): xx = x_lr[lr] + x_node_width / 2 - for label in nodes_lr[lr]: + for label in self.node_list[ii + lr]: yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 self.draw_label(xx, yy, label, ha_dict[self.label_loc[2]]) # Plot flows - for lbl_l in nodes_lr[0]: - for lbl_r in nodes_lr[1]: + for lbl_l in self.node_list[ii]: + for lbl_r in self.node_list[ii + 1]: lind = labels_lr[0] == lbl_l rind = labels_lr[1] == lbl_r if not any(lind & rind):