Skip to content

Commit

Permalink
one last refactor for now
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Apr 9, 2024
1 parent e059a7d commit 2ce690c
Showing 1 changed file with 44 additions and 34 deletions.
78 changes: 44 additions & 34 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,27 @@ def setup(self, data):
+ self.num_stages * self.plot_width_nom * self.node_width
)

# vertical positions
self.y_node_gap = self.node_gap * self.plot_height_nom
self.y_title_gap = self.title_gap * self.plot_height_nom
self.y_frame_gap = self.frame_gap * self.plot_height_nom

# horizontal positions
self.x_node_width = self.node_width * self.plot_width_nom
self.x_label_width = self.label_width * self.plot_width_nom
self.x_label_gap = self.label_gap * self.plot_width_nom
self.x_value_gap = self.value_gap * self.plot_width_nom

self.x_lr = {}
for ii in range(self.num_flow):
x_left = (
+ self.x_node_width
+ self.x_label_gap
+ self.x_label_width
+ ii * (self.sub_width + self.x_node_width)
)
self.x_lr[ii] = (x_left, x_left + self.sub_width)

# offsets for alignment
vscale_dict = {"top": 1, "center": 0.5, "bottom": 0}
self.vscale = vscale_dict.get(self.valign, 0)
Expand Down Expand Up @@ -453,20 +474,20 @@ def plot_frame(self):

frame_color = self.frame_color or [0, 0, 0, 1]

y_frame_gap = self.frame_gap * self.plot_height
self.y_frame_gap = self.frame_gap * self.plot_height

col = frame_color if frame_top else [1, 1, 1, 0]
self.ax.plot(
[0, self.plot_width],
min(self.voffset) + (self.plot_height) + y_frame_gap + [0, 0],
min(self.voffset) + (self.plot_height) + self.y_frame_gap + [0, 0],
color=col,
lw=self.frame_lw,
)

col = frame_color if frame_bot else [1, 1, 1, 0]
self.ax.plot(
[0, self.plot_width],
min(self.voffset) - y_frame_gap + [0, 0],
min(self.voffset) - self.y_frame_gap + [0, 0],
color=col,
lw=self.frame_lw,
)
Expand All @@ -480,21 +501,10 @@ def subplot(self, ii):
for the first and last cases.
"""

# vertical positions
y_node_gap = self.node_gap * self.plot_height_nom
y_title_gap = self.title_gap * self.plot_height_nom
y_frame_gap = self.frame_gap * self.plot_height_nom

# horizontal positions
x_node_width = self.node_width * self.plot_width_nom
x_label_width = self.label_width * self.plot_width_nom
x_label_gap = self.label_gap * self.plot_width_nom
x_value_gap = self.value_gap * self.plot_width_nom
x_left = x_node_width + x_label_gap + x_label_width + ii * (self.sub_width + x_node_width)
x_lr = [x_left, x_left + self.sub_width]

# All node sizes and positions

x_lr = self.x_lr[ii]

node_voffset = [{}, {}]
node_pos_bot = [{}, {}]
node_pos_top = [{}, {}]
Expand All @@ -513,7 +523,7 @@ def subplot(self, ii):
node_height = self.node_sizes[ii + lr][label]
this_side_height = self.data[2 * (ii + lr) + 1][self.data[2 * (ii + lr)] == label].sum()
node_voffset[lr][label] = self.vscale * (node_height - this_side_height)
next_bot = node_pos_top[lr][self.node_list[ii + lr][i - 1]] + y_node_gap if i > 0 else 0
next_bot = node_pos_top[lr][self.node_list[ii + lr][i - 1]] + self.y_node_gap if i > 0 else 0
node_pos_bot[lr][label] = self.voffset[ii + lr] if i == 0 else next_bot
node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height

Expand All @@ -522,8 +532,8 @@ def subplot(self, ii):
for lr in [0, 1] if ii == 0 else [1]:
for label in self.node_list[ii + lr]:
self.draw_node(
x_lr[lr] - x_node_width * (1 - lr),
x_node_width,
x_lr[lr] - self.x_node_width * (1 - lr),
self.x_node_width,
node_pos_bot[lr][label],
self.node_sizes[ii + lr][label],
label,
Expand All @@ -537,19 +547,19 @@ def subplot(self, ii):
lr = 0
if ii == 0 and self.label_loc[0] != "none":
if self.label_loc[0] in ("left"):
xx = x_lr[lr] - x_label_gap - x_node_width
xx = x_lr[lr] - self.x_label_gap - self.x_node_width
elif self.label_loc[0] in ("right"):
xx = x_lr[lr] + x_label_gap
xx = x_lr[lr] + self.x_label_gap
elif self.label_loc[0] in ("center"):
xx = x_lr[lr] - x_node_width / 2
xx = x_lr[lr] - self.x_node_width / 2
for label in self.node_list[ii + lr]:
yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2
self.draw_label(xx, yy, label, ha_dict[self.label_loc[0]])

# inside labels, left
lr = 1
if ii < self.num_flow - 1 and self.label_loc[1] in ("left", "both"):
xx = x_lr[lr] - x_label_gap
xx = x_lr[lr] - self.x_label_gap
ha = "right"
for label in self.node_list[ii + lr]:
if (label not in self.node_list[ii]) or self.label_duplicate:
Expand All @@ -558,7 +568,7 @@ def subplot(self, ii):

# inside labels, center
if ii < self.num_flow - 1 and self.label_loc[1] in ("center"):
xx = x_lr[lr] + x_node_width / 2
xx = x_lr[lr] + self.x_node_width / 2
ha = "center"
for label in self.node_list[ii + lr]:
if (label not in self.node_list[ii]) or self.label_duplicate:
Expand All @@ -567,7 +577,7 @@ def subplot(self, ii):

# inside labels, right
if ii < self.num_flow - 1 and self.label_loc[1] in ("right", "both"):
xx = x_lr[lr] + x_label_gap + x_node_width
xx = x_lr[lr] + self.x_label_gap + self.x_node_width
ha = "left"
for label in self.node_list[ii + lr]:
if (label not in self.node_list[ii]) or self.label_duplicate:
Expand All @@ -577,11 +587,11 @@ def subplot(self, ii):
# last row of labels
if ii == self.num_flow - 1 and self.label_loc[2] != "none":
if self.label_loc[2] in ("left"):
xx = x_lr[lr] - x_label_gap
xx = x_lr[lr] - self.x_label_gap
elif self.label_loc[2] in ("right"):
xx = x_lr[lr] + x_label_gap + x_node_width
xx = x_lr[lr] + self.x_label_gap + self.x_node_width
elif self.label_loc[2] in ("center"):
xx = x_lr[lr] + x_node_width / 2
xx = x_lr[lr] + self.x_node_width / 2
for label in self.node_list[ii + lr]:
yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2
self.draw_label(xx, yy, label, ha_dict[self.label_loc[2]])
Expand Down Expand Up @@ -642,7 +652,7 @@ def subplot(self, ii):
):
continue
self.ax.text(
x_lr[lr] + (1 - 2 * lr) * x_value_gap,
x_lr[lr] + (1 - 2 * lr) * self.x_value_gap,
bot_lr[lr] + len_lr[lr] / 2,
f"{format(val,self.value_format)}",
{
Expand All @@ -658,21 +668,21 @@ def subplot(self, ii):
# Place "titles"
if self.titles is not None:
last_label = [lbl_l, lbl_r]
title_x = [x_lr[0] - x_node_width / 2, x_lr[1] + x_node_width / 2]
title_x = [x_lr[0] - self.x_node_width / 2, x_lr[1] + self.x_node_width / 2]

for lr in [0, 1] if ii == 0 else [1]:
if self.title_side in ("top", "both"):
if self.title_loc == "outer":
yt = min(self.voffset) + y_title_gap + y_frame_gap + self.plot_height
yt = min(self.voffset) + self.y_title_gap + self.y_frame_gap + self.plot_height
elif self.title_loc == "inner":
yt = y_title_gap + node_pos_top[lr][last_label[lr]]
yt = self.y_title_gap + node_pos_top[lr][last_label[lr]]
self.draw_title(title_x[lr], yt, self.titles[ii + lr], "bottom")

if self.title_side in ("bottom", "both"):
if self.title_loc == "outer":
yt = min(self.voffset) - y_title_gap - y_frame_gap
yt = min(self.voffset) - self.y_title_gap - self.y_frame_gap
elif self.title_loc == "inner":
yt = self.voffset[ii + lr] - y_title_gap
yt = self.voffset[ii + lr] - self.y_title_gap
self.draw_title(title_x[lr], yt, self.titles[ii + lr], "top")

###########################################
Expand Down

0 comments on commit 2ce690c

Please sign in to comment.