diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 751b7f3..34d718e 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -411,6 +411,7 @@ def setup(self, data): self.x_lr = {} self.nodesize_l = {} self.nodesize_r = {} + self.node_pairs = {} for ii in range(self.num_flow): x_left = ( self.x_node_width + self.x_label_gap + self.x_label_width + ii * (self.sub_width + self.x_node_width) @@ -418,11 +419,15 @@ def setup(self, data): self.x_lr[ii] = (x_left, x_left + self.sub_width) self.nodesize_l[ii] = {} self.nodesize_r[ii] = {} + self.node_pairs[ii] = [] for lbl_l in self.node_list[ii]: self.nodesize_l[ii][lbl_l] = {} self.nodesize_r[ii][lbl_l] = {} for lbl_r in self.node_list[ii + 1]: ind = (self.data[2 * ii] == lbl_l) & (self.data[2 * ii + 2] == lbl_r) + if not any(ind): + continue + self.node_pairs[ii].append((lbl_l,lbl_r)) self.nodesize_l[ii][lbl_l][lbl_r] = self.data[2 * ii + 1][ind].sum() self.nodesize_r[ii][lbl_l][lbl_r] = self.data[2 * ii + 3][ind].sum() @@ -605,68 +610,63 @@ def subplot(self, ii): # Plot flows - for lbl_l in self.node_list[ii]: - for lbl_r in self.node_list[ii + 1]: - lind = self.data[2 * ii] == lbl_l - rind = self.data[2 * ii + 2] == lbl_r - if not any(lind & rind): - continue - - lbot = self.node_pos_voffset[ii][0][lbl_l] + self.node_pos_bot[ii][0][lbl_l] - rbot = self.node_pos_voffset[ii][1][lbl_r] + self.node_pos_bot[ii][1][lbl_r] - llen = self.nodesize_l[ii][lbl_l][lbl_r] - rlen = self.nodesize_r[ii][lbl_l][lbl_r] - bot_lr = [lbot, rbot] - len_lr = [llen, rlen] - - ys_d = self.create_curve(lbot, rbot) - ys_u = self.create_curve(lbot + llen, rbot + rlen) - - # Update bottom edges at each label - # so next strip starts at the right place - self.node_pos_bot[ii][0][lbl_l] += llen - self.node_pos_bot[ii][1][lbl_r] += rlen - - xx = np.linspace(x_lr[0], x_lr[1], len(ys_d)) - cc = self.combine_colours(self.color_dict[lbl_l], self.color_dict[lbl_r], len(ys_d)) - - for jj in range(len(ys_d) - 1): - self.draw_flow( - xx[[jj, jj + 1]], - ys_d[[jj, jj + 1]], - ys_u[[jj, jj + 1]], - cc[:, jj], - ) + for lbl_l, lbl_r in self.node_pairs[ii]: - ha = ["left", "right"] - sides = [] - if ii == 0: - ind = 0 - elif ii == self.num_flow - 1: - ind = 2 - else: - ind = 1 - if self.value_loc[ind] in ("left", "both"): - sides.append(0) - if self.value_loc[ind] in ("right", "both"): - sides.append(1) - for lr in sides: - val = len_lr[lr] - if not ( - val < self.value_thresh_val - or val < self.value_thresh_sum * self.weight_sum[ii + lr] - or val < self.value_thresh_max * max(self.data[2 * ii + 1]) - ): - self.draw_value( - x_lr[lr] + (1 - 2 * lr) * self.x_value_gap, - bot_lr[lr] + len_lr[lr] / 2, - val, - ha[lr], - ) + lbot = self.node_pos_voffset[ii][0][lbl_l] + self.node_pos_bot[ii][0][lbl_l] + rbot = self.node_pos_voffset[ii][1][lbl_r] + self.node_pos_bot[ii][1][lbl_r] + llen = self.nodesize_l[ii][lbl_l][lbl_r] + rlen = self.nodesize_r[ii][lbl_l][lbl_r] + bot_lr = [lbot, rbot] + len_lr = [llen, rlen] + + ys_d = self.create_curve(lbot, rbot) + ys_u = self.create_curve(lbot + llen, rbot + rlen) + + # Update bottom edges at each label + # so next strip starts at the right place + self.node_pos_bot[ii][0][lbl_l] += llen + self.node_pos_bot[ii][1][lbl_r] += rlen + + xx = np.linspace(x_lr[0], x_lr[1], len(ys_d)) + cc = self.combine_colours(self.color_dict[lbl_l], self.color_dict[lbl_r], len(ys_d)) + + for jj in range(len(ys_d) - 1): + self.draw_flow( + xx[[jj, jj + 1]], + ys_d[[jj, jj + 1]], + ys_u[[jj, jj + 1]], + cc[:, jj], + ) + + ha = ["left", "right"] + sides = [] + if ii == 0: + ind = 0 + elif ii == self.num_flow - 1: + ind = 2 + else: + ind = 1 + if self.value_loc[ind] in ("left", "both"): + sides.append(0) + if self.value_loc[ind] in ("right", "both"): + sides.append(1) + for lr in sides: + val = len_lr[lr] + if not ( + val < self.value_thresh_val + or val < self.value_thresh_sum * self.weight_sum[ii + lr] + or val < self.value_thresh_max * max(self.data[2 * ii + 1]) + ): + self.draw_value( + x_lr[lr] + (1 - 2 * lr) * self.x_value_gap, + bot_lr[lr] + len_lr[lr] / 2, + val, + ha[lr], + ) # Place "titles" if self.titles is not None: - last_label = [lbl_l, lbl_r] + last_label = self.node_pairs[ii][-1] title_x = [x_lr[0] - self.x_node_width / 2, x_lr[1] + self.x_node_width / 2] for lr in [0, 1] if ii == 0 else [1]: