diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 6309acf..8faf73f 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -480,13 +480,6 @@ def subplot(self, ii): for the first and last cases. """ - lastind = 4 if ii < self.num_flow - 1 else 2 - labels_lr = [ - self.data[2 * ii], - self.data[2 * ii + 2], - self.data[2 * ii + lastind], - ] - # vertical positions y_node_gap = self.node_gap * self.plot_height_nom y_title_gap = self.title_gap * self.plot_height_nom @@ -511,14 +504,14 @@ def subplot(self, ii): nodesize[0][lbl_l] = {} nodesize[1][lbl_l] = {} for lbl_r in self.node_list[ii + 1]: - ind = (labels_lr[0] == lbl_l) & (labels_lr[1] == lbl_r) + ind = (self.data[2 * ii] == lbl_l) & (self.data[2 * ii + 2] == lbl_r) nodesize[0][lbl_l][lbl_r] = self.data[2 * ii + 1][ind].sum() nodesize[1][lbl_l][lbl_r] = self.data[2 * ii + 3][ind].sum() for lr in [0, 1]: for i, label in enumerate(self.node_list[ii + lr]): node_height = self.node_sizes[ii + lr][label] - this_side_height = self.data[2 * (ii + lr) + 1][labels_lr[lr] == label].sum() + this_side_height = self.data[2 * (ii + lr) + 1][self.data[2 * (ii + lr)] == label].sum() node_voffset[lr][label] = self.vscale * (node_height - this_side_height) next_bot = node_pos_top[lr][self.node_list[ii + lr][i - 1]] + y_node_gap if i > 0 else 0 node_pos_bot[lr][label] = self.voffset[ii + lr] if i == 0 else next_bot @@ -597,8 +590,8 @@ def subplot(self, ii): for lbl_l in self.node_list[ii]: for lbl_r in self.node_list[ii + 1]: - lind = labels_lr[0] == lbl_l - rind = labels_lr[1] == lbl_r + lind = self.data[2 * ii] == lbl_l + rind = self.data[2 * ii + 2] == lbl_r if not any(lind & rind): continue