diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 5bc1d49..ee2c192 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -478,19 +478,16 @@ def subplot(self, ii): for the first and last cases. """ - labelind = 2 * ii - weightind = 2 * ii + 1 - lastind = 4 if ii < self.num_flow - 1 else 2 labels_lr = [ - self.data[labelind], - self.data[labelind + 2], - self.data[labelind + lastind], + self.data[2 * ii], + self.data[2 * ii + 2], + self.data[2 * ii + lastind], ] weights_lr = [ - self.data[weightind], - self.data[weightind + 2], - self.data[weightind + lastind], + self.data[2 * ii + 1], + self.data[2 * ii + 1 + 2], + self.data[2 * ii + 1 + lastind], ] nodes_lr = [ @@ -498,8 +495,26 @@ def subplot(self, ii): self.sort_nodes(labels_lr[1], self.node_sizes[ii + 1]), ] - # Determine sizes of individual subflows + # vertical positions + y_node_gap = self.node_gap * self.plot_height_nom + y_title_gap = self.title_gap * self.plot_height_nom + y_frame_gap = self.frame_gap * self.plot_height_nom + + # horizontal positions + x_node_width = self.node_width * self.plot_width_nom + x_label_width = self.label_width * self.plot_width_nom + x_label_gap = self.label_gap * self.plot_width_nom + x_value_gap = self.value_gap * self.plot_width_nom + x_left = x_node_width + x_label_gap + x_label_width + ii * (self.sub_width + x_node_width) + x_lr = [x_left, x_left + self.sub_width] + + # All node sizes and positions + + node_voffset = [{}, {}] + node_pos_bot = [{}, {}] + node_pos_top = [{}, {}] nodesize = [{}, {}] + for lbl_l in nodes_lr[0]: nodesize[0][lbl_l] = {} nodesize[1][lbl_l] = {} @@ -508,13 +523,6 @@ def subplot(self, ii): nodesize[0][lbl_l][lbl_r] = weights_lr[0][ind].sum() nodesize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum() - # Determine vertical positions of nodes - y_node_gap = self.node_gap * self.plot_height_nom - - node_voffset = [{}, {}] - node_pos_bot = [{}, {}] - node_pos_top = [{}, {}] - for lr in [0, 1]: for i, label in enumerate(nodes_lr[lr]): node_height = self.node_sizes[ii + lr][label] @@ -524,14 +532,6 @@ def subplot(self, ii): node_pos_bot[lr][label] = self.voffset[ii + lr] if i == 0 else next_bot node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height - # horizontal positions of nodes - x_node_width = self.node_width * self.plot_width_nom - x_label_width = self.label_width * self.plot_width_nom - x_label_gap = self.label_gap * self.plot_width_nom - x_value_gap = self.value_gap * self.plot_width_nom - x_left = x_node_width + x_label_gap + x_label_width + ii * (self.sub_width + x_node_width) - x_lr = [x_left, x_left + self.sub_width] - # Draw nodes for lr in [0, 1] if ii == 0 else [1]: @@ -673,8 +673,6 @@ def subplot(self, ii): # Place "titles" if self.titles is not None: last_label = [lbl_l, lbl_r] - y_title_gap = self.title_gap * self.plot_height_nom - y_frame_gap = self.frame_gap * self.plot_height_nom title_x = [x_lr[0] - x_node_width / 2, x_lr[1] + x_node_width / 2] for lr in [0, 1] if ii == 0 else [1]: