Skip to content

Commit

Permalink
enough tinkering for now
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Mar 25, 2024
1 parent 543a44c commit 45c70ed
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,21 +481,20 @@ def draw_label(x, y, label, ha):
},
)

ha_dict = {"left": "right", "right": "left", "center": "center"}

# first row of labels
lr = 0
if ii == 0 and label_loc[0] != "none":
if label_loc[0] in ("left"):
xx = x_lr[lr] - x_label_gap - x_node_width
ha = "right"
elif label_loc[0] in ("right"):
xx = x_lr[lr] + x_label_gap
ha = "left"
elif label_loc[0] in ("center"):
xx = x_lr[lr] - x_node_width / 2
ha = "center"
for label in nodes_lr[lr]:
yy = node_pos_bot[lr][label] + node_sizes[ii + lr][label] / 2
draw_label(xx, yy, label, ha)
draw_label(xx, yy, label, ha_dict[label_loc[0]])

# inside labels, left
lr = 1
Expand Down Expand Up @@ -523,7 +522,6 @@ def draw_label(x, y, label, ha):
draw_label(xx, yy, label, ha)

# last row of labels
ha_dict = {"left": "right", "right": "left", "center": "center"}
if ii == num_flow - 1 and label_loc[2] != "none":
if label_loc[2] in ("left"):
xx = x_lr[lr] - x_label_gap
Expand All @@ -532,21 +530,17 @@ def draw_label(x, y, label, ha):
elif label_loc[2] in ("center"):
xx = x_lr[lr] + x_node_width / 2
for label in nodes_lr[lr]:
draw_label(
xx,
node_pos_bot[lr][label] + node_sizes[ii + lr][label] / 2,
label,
ha_dict[label_loc[2]],
)
yy = node_pos_bot[lr][label] + node_sizes[ii + lr][label] / 2
draw_label(xx, yy, label, ha_dict[label_loc[2]])

# Plot flows
if flow_edge:
edge_lw = flow_lw
edge_alpha = 1
else:
edge_alpha = flow_alpha
edge_lw = 0

# Plot flows
for lbl_l in nodes_lr[0]:
for lbl_r in nodes_lr[1]:
lind = labels_lr[0] == lbl_l
Expand All @@ -556,16 +550,16 @@ def draw_label(x, y, label, ha):

lbot = node_voffset[0][lbl_l] + node_pos_bot[0][lbl_l]
rbot = node_voffset[1][lbl_r] + node_pos_bot[1][lbl_r]
lnode = nodesize[0][lbl_l][lbl_r]
rnode = nodesize[1][lbl_l][lbl_r]
llen = nodesize[0][lbl_l][lbl_r]
rlen = nodesize[1][lbl_l][lbl_r]

ys_d = create_curve(lbot, rbot)
ys_u = create_curve(lbot + lnode, rbot + rnode)
ys_u = create_curve(lbot + llen, rbot + rlen)

# Update bottom edges at each label
# so next strip starts at the right place
node_pos_bot[0][lbl_l] += lnode
node_pos_bot[1][lbl_r] += rnode
node_pos_bot[0][lbl_l] += llen
node_pos_bot[1][lbl_r] += rlen

xx = np.linspace(x_lr[0], x_lr[1], len(ys_d))
cc = combine_colours(color_dict[lbl_l], color_dict[lbl_r], len(ys_d))
Expand All @@ -581,6 +575,7 @@ def draw_label(x, y, label, ha):
edgecolor="none",
snap=True,
)
# edges:
ax.plot(
xx[[jj, jj + 1]],
ys_d[[jj, jj + 1]],
Expand Down

0 comments on commit 45c70ed

Please sign in to comment.