Skip to content

Commit

Permalink
add flow_edge=True option
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Mar 20, 2024
1 parent b42ff45 commit 0377a46
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def sankey(
bar_gap=0.05,
color_dict=None,
colormap="viridis",
flow_edge=True,

Check failure on line 38 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (FBT002)

ausankey/ausankey.py:38:5: FBT002 Boolean default positional argument in function definition
fontsize=14,
frame_side="none",
frame_gap=0.1,
Expand Down Expand Up @@ -84,6 +85,10 @@ def sankey(
fontsize : int
Font size of labels
flow_edge : bool
Whether to draw an edge to the flows.
Doesn't always look great when there is lots of branching and overlap.
frame_side : str
Whether to place a frame (horizontal rule) above or below the plot.
Allowed values: `"none"`, `"top"`, `"bottom"`, or `"both"`
Expand Down Expand Up @@ -252,6 +257,7 @@ def sankey(
label_order=label_order,
color_dict=color_dict_new,
fontsize=fontsize,
flow_edge=flow_edge,
frame_gap=frame_gap,
label_dict=label_dict or {},
label_width=label_width,
Expand Down Expand Up @@ -279,6 +285,7 @@ def _sankey(
data,
color_dict=None,
label_order=None,
flow_edge=None,
fontsize=None,
frame_gap=None,
titles=None,
Expand All @@ -302,6 +309,14 @@ def _sankey(
Some special-casing is used for plotting/labelling differently
for the first and last cases.
"""

if flow_edge:
edge_alpha = 1
edge_lw = 1
else:
edge_alpha = alpha
edge_lw = 0

labelind = 2 * ii
weightind = 2 * ii + 1

Expand All @@ -328,7 +343,7 @@ def _sankey(
pd.Series(data[weightind + 2]),
]

notnull = labels_lr[0].notnull() & labels_lr[1].notnull()
notnull = labels_lr[0].notnull() & labels_lr[1].notnull() & labels_lr[2].notnull()
labels_lr[0] = labels_lr[0][notnull]
labels_lr[1] = labels_lr[1][notnull]
weights_lr[0] = weights_lr[0][notnull]
Expand Down Expand Up @@ -399,7 +414,7 @@ def draw_bar(x, dx, y, dy, label):
y + dy,
color=color_dict[label],
alpha=1,
lw=0,
lw=edge_lw,
snap=True,
)

Expand Down Expand Up @@ -496,7 +511,7 @@ def draw_bar(x, dx, y, dy, label):
fontsize=fontsize,
)

# Plot strips
# Plot flows
for lbl_l in bar_lr[0]:
for lbl_r in bar_lr[1]:
lind = labels_lr[0] == lbl_l
Expand Down Expand Up @@ -532,6 +547,22 @@ def draw_bar(x, dx, y, dy, label):
edgecolor="none",
snap=True,
)
ax.plot(
xx[[jj, jj + 1]],
ys_d[[jj, jj + 1]],
color=cc[:, jj],
alpha=edge_alpha,
lw=edge_lw,
snap=True,
)
ax.plot(
xx[[jj, jj + 1]],
ys_u[[jj, jj + 1]],
color=cc[:, jj],
alpha=edge_alpha,
lw=edge_lw,
snap=True,
)


###########################################
Expand Down

0 comments on commit 0377a46

Please sign in to comment.