Skip to content

Commit

Permalink
refactor out one local variable
Browse files Browse the repository at this point in the history
(Want to -- eventually -- precalculate all params before heading into
the plotting routines)
  • Loading branch information
wspr committed Apr 9, 2024
1 parent bec9252 commit 169de5f
Showing 1 changed file with 19 additions and 23 deletions.
42 changes: 19 additions & 23 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def setup(self, data):
# sizes
col_hgt = np.empty(self.num_stages)
self.node_sizes = {}
self.node_list = {}
self.nodes_uniq = {}

# weight and reclassify
Expand All @@ -373,6 +374,7 @@ def setup(self, data):
for ii in range(self.num_stages):
self.node_sizes[ii] = self.sort_node_sizes(self.node_sizes[ii], self.sort)
col_hgt[ii] = self.weight_sum[ii] + (len(self.nodes_uniq[ii]) - 1) * self.node_gap * self.plot_height_nom
self.node_list[ii] = self.sort_nodes(self.data[2 * ii], self.node_sizes[ii])

# overall dimensions
self.plot_height = max(col_hgt)
Expand Down Expand Up @@ -487,12 +489,6 @@ def subplot(self, ii):
weights_lr = [
self.data[2 * ii + 1],
self.data[2 * ii + 1 + 2],
self.data[2 * ii + 1 + lastind],
]

nodes_lr = [
self.sort_nodes(labels_lr[0], self.node_sizes[ii]),
self.sort_nodes(labels_lr[1], self.node_sizes[ii + 1]),
]

# vertical positions
Expand All @@ -515,27 +511,27 @@ def subplot(self, ii):
node_pos_top = [{}, {}]
nodesize = [{}, {}]

for lbl_l in nodes_lr[0]:
for lbl_l in self.node_list[ii]:
nodesize[0][lbl_l] = {}
nodesize[1][lbl_l] = {}
for lbl_r in nodes_lr[1]:
for lbl_r in self.node_list[ii + 1]:
ind = (labels_lr[0] == lbl_l) & (labels_lr[1] == lbl_r)
nodesize[0][lbl_l][lbl_r] = weights_lr[0][ind].sum()
nodesize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum()
nodesize[0][lbl_l][lbl_r] = self.data[2 * ii + 1][ind].sum()
nodesize[1][lbl_l][lbl_r] = self.data[2 * ii + 3][ind].sum()

for lr in [0, 1]:
for i, label in enumerate(nodes_lr[lr]):
for i, label in enumerate(self.node_list[ii + lr]):
node_height = self.node_sizes[ii + lr][label]
this_side_height = weights_lr[lr][labels_lr[lr] == label].sum()
node_voffset[lr][label] = self.vscale * (node_height - this_side_height)
next_bot = node_pos_top[lr][nodes_lr[lr][i - 1]] + y_node_gap if i > 0 else 0
next_bot = node_pos_top[lr][self.node_list[ii + lr][i - 1]] + y_node_gap if i > 0 else 0
node_pos_bot[lr][label] = self.voffset[ii + lr] if i == 0 else next_bot
node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height

# Draw nodes

for lr in [0, 1] if ii == 0 else [1]:
for label in nodes_lr[lr]:
for label in self.node_list[ii + lr]:
self.draw_node(
x_lr[lr] - x_node_width * (1 - lr),
x_node_width,
Expand All @@ -557,7 +553,7 @@ def subplot(self, ii):
xx = x_lr[lr] + x_label_gap
elif self.label_loc[0] in ("center"):
xx = x_lr[lr] - x_node_width / 2
for label in nodes_lr[lr]:
for label in self.node_list[ii + lr]:
yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2
self.draw_label(xx, yy, label, ha_dict[self.label_loc[0]])

Expand All @@ -566,26 +562,26 @@ def subplot(self, ii):
if ii < self.num_flow - 1 and self.label_loc[1] in ("left", "both"):
xx = x_lr[lr] - x_label_gap
ha = "right"
for label in nodes_lr[lr]:
if (label not in nodes_lr[lr - 1]) or self.label_duplicate:
for label in self.node_list[ii + lr]:
if (label not in self.node_list[ii]) or self.label_duplicate:
yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2
self.draw_label(xx, yy, label, ha)

# inside labels, center
if ii < self.num_flow - 1 and self.label_loc[1] in ("center"):
xx = x_lr[lr] + x_node_width / 2
ha = "center"
for label in nodes_lr[lr]:
if (label not in nodes_lr[lr - 1]) or self.label_duplicate:
for label in self.node_list[ii + lr]:
if (label not in self.node_list[ii]) or self.label_duplicate:
yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2
self.draw_label(xx, yy, label, ha)

# inside labels, right
if ii < self.num_flow - 1 and self.label_loc[1] in ("right", "both"):
xx = x_lr[lr] + x_label_gap + x_node_width
ha = "left"
for label in nodes_lr[lr]:
if (label not in nodes_lr[lr - 1]) or self.label_duplicate:
for label in self.node_list[ii + lr]:
if (label not in self.node_list[ii]) or self.label_duplicate:
yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2
self.draw_label(xx, yy, label, ha)

Expand All @@ -597,14 +593,14 @@ def subplot(self, ii):
xx = x_lr[lr] + x_label_gap + x_node_width
elif self.label_loc[2] in ("center"):
xx = x_lr[lr] + x_node_width / 2
for label in nodes_lr[lr]:
for label in self.node_list[ii + lr]:
yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2
self.draw_label(xx, yy, label, ha_dict[self.label_loc[2]])

# Plot flows

for lbl_l in nodes_lr[0]:
for lbl_r in nodes_lr[1]:
for lbl_l in self.node_list[ii]:
for lbl_r in self.node_list[ii + 1]:
lind = labels_lr[0] == lbl_l
rind = labels_lr[1] == lbl_r
if not any(lind & rind):
Expand Down

0 comments on commit 169de5f

Please sign in to comment.