Skip to content

Commit

Permalink
add value labels
Browse files Browse the repository at this point in the history
Not quite working yet! Unsure how to avoid overlaps with defaults while
providing flexibility
  • Loading branch information
wspr committed Apr 1, 2024
1 parent 8cf0569 commit d322875
Showing 1 changed file with 65 additions and 28 deletions.
93 changes: 65 additions & 28 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def __init__(
title_font=None,
sort="bottom", # "top", "bottom", "none"
valign="bottom", # "top","center"
value_format = ".0f",
value_gap = None,
value_font = None,
value_loc=("right", "left", "left"),
):
self.ax = ax
self.node_width = node_width
Expand Down Expand Up @@ -273,35 +277,11 @@ def __init__(
self.title_side = title_side
self.sort = sort
self.valign = valign
self.value_format = value_format
self.value_gap = label_gap if value_gap is None else value_gap
self.value_font = value_font or {}
self.value_loc = value_loc

def weight_labels(self):
self.weight_sum = np.empty(self.num_stages)

for ii in range(self.num_stages):
self.nodes_uniq[ii] = pd.Series(self.data[2 * ii]).unique()

for ii in range(self.num_stages):
self.node_sizes[ii] = {}
for lbl in self.nodes_uniq[ii]:
if ii == 0:
ind_prev = self.data[2 * ii - 0] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 2] == lbl
elif ii == self.num_flow:
ind_prev = self.data[2 * ii - 2] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 0] == lbl
else:
ind_prev = self.data[2 * ii - 2] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 2] == lbl
weight_cont = self.data[2 * ii + 1][ind_this & ind_prev & ind_next].sum()
weight_only = self.data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum()
weight_stop = self.data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum()
weight_strt = self.data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum()
self.node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt)

self.weight_sum[ii] = pd.Series(self.node_sizes[ii].values()).sum()

def setup(self, data):
self.data = data
Expand Down Expand Up @@ -368,6 +348,37 @@ def setup(self, data):
self.ax = self.ax or plt.gca()
self.ax.axis("off")


def weight_labels(self):
self.weight_sum = np.empty(self.num_stages)

for ii in range(self.num_stages):
self.nodes_uniq[ii] = pd.Series(self.data[2 * ii]).unique()

for ii in range(self.num_stages):
self.node_sizes[ii] = {}
for lbl in self.nodes_uniq[ii]:
if ii == 0:
ind_prev = self.data[2 * ii - 0] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 2] == lbl
elif ii == self.num_flow:
ind_prev = self.data[2 * ii - 2] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 0] == lbl
else:
ind_prev = self.data[2 * ii - 2] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 2] == lbl
weight_cont = self.data[2 * ii + 1][ind_this & ind_prev & ind_next].sum()
weight_only = self.data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum()
weight_stop = self.data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum()
weight_strt = self.data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum()
self.node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt)

self.weight_sum[ii] = pd.Series(self.node_sizes[ii].values()).sum()


def plot_frame(self):
"""Plot frame on top/bottom edges"""

Expand All @@ -394,6 +405,7 @@ def plot_frame(self):
lw=self.frame_lw,
)


def subplot(self, ii):
"""Subroutine for plotting horizontal sections of the Sankey plot
Expand Down Expand Up @@ -451,6 +463,7 @@ def subplot(self, ii):
x_node_width = self.node_width * self.sub_width
x_label_width = self.label_width * self.sub_width
x_label_gap = self.label_gap * self.sub_width
x_value_gap = self.value_gap * self.sub_width
x_left = x_node_width + x_label_gap + x_label_width + ii * (self.sub_width + x_node_width)
x_lr = [x_left, x_left + self.sub_width]

Expand Down Expand Up @@ -579,6 +592,8 @@ def draw_label(x, y, label, ha):
rbot = node_voffset[1][lbl_r] + node_pos_bot[1][lbl_r]
llen = nodesize[0][lbl_l][lbl_r]
rlen = nodesize[1][lbl_l][lbl_r]
bot_lr = [lbot, rbot]
len_lr = [llen, rlen]

ys_d = create_curve(lbot, rbot)
ys_u = create_curve(lbot + llen, rbot + rlen)
Expand Down Expand Up @@ -620,6 +635,28 @@ def draw_label(x, y, label, ha):
snap=True,
)

ha = ["left", "right"]
sides = []
if ii == 0:
ind = 0
elif ii == self.num_flow - 1:
ind = 2
else:
ind = 1
if self.value_loc[ind] in ("left","both"):
sides.append(0)
if self.value_loc[ind] in ("right","both"):
sides.append(1)
for lr in sides:
self.ax.text(
x_lr[lr] + (1 - 2 * lr) * x_value_gap,
bot_lr[lr] + len_lr[lr] / 2,
f"{format(len_lr[lr],self.value_format)}",
ha=ha[lr],
va="center",
**self.value_font,
)

# Place "titles"
if self.titles is not None:
last_label = [lbl_l, lbl_r]
Expand Down

0 comments on commit d322875

Please sign in to comment.