From 0377a46928fa1e0b2122e7f3576e43f7310d3a0c Mon Sep 17 00:00:00 2001 From: Will Robertson Date: Wed, 20 Mar 2024 21:23:07 +1030 Subject: [PATCH] add flow_edge=True option --- ausankey/ausankey.py | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index c4b390f..6ea1472 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -35,6 +35,7 @@ def sankey( bar_gap=0.05, color_dict=None, colormap="viridis", + flow_edge=True, fontsize=14, frame_side="none", frame_gap=0.1, @@ -84,6 +85,10 @@ def sankey( fontsize : int Font size of labels + flow_edge : bool + Whether to draw an edge to the flows. + Doesn't always look great when there is lots of branching and overlap. + frame_side : str Whether to place a frame (horizontal rule) above or below the plot. Allowed values: `"none"`, `"top"`, `"bottom"`, or `"both"` @@ -252,6 +257,7 @@ def sankey( label_order=label_order, color_dict=color_dict_new, fontsize=fontsize, + flow_edge=flow_edge, frame_gap=frame_gap, label_dict=label_dict or {}, label_width=label_width, @@ -279,6 +285,7 @@ def _sankey( data, color_dict=None, label_order=None, + flow_edge=None, fontsize=None, frame_gap=None, titles=None, @@ -302,6 +309,14 @@ def _sankey( Some special-casing is used for plotting/labelling differently for the first and last cases. """ + + if flow_edge: + edge_alpha = 1 + edge_lw = 1 + else: + edge_alpha = alpha + edge_lw = 0 + labelind = 2 * ii weightind = 2 * ii + 1 @@ -328,7 +343,7 @@ def _sankey( pd.Series(data[weightind + 2]), ] - notnull = labels_lr[0].notnull() & labels_lr[1].notnull() + notnull = labels_lr[0].notnull() & labels_lr[1].notnull() & labels_lr[2].notnull() labels_lr[0] = labels_lr[0][notnull] labels_lr[1] = labels_lr[1][notnull] weights_lr[0] = weights_lr[0][notnull] @@ -399,7 +414,7 @@ def draw_bar(x, dx, y, dy, label): y + dy, color=color_dict[label], alpha=1, - lw=0, + lw=edge_lw, snap=True, ) @@ -496,7 +511,7 @@ def draw_bar(x, dx, y, dy, label): fontsize=fontsize, ) - # Plot strips + # Plot flows for lbl_l in bar_lr[0]: for lbl_r in bar_lr[1]: lind = labels_lr[0] == lbl_l @@ -532,6 +547,22 @@ def draw_bar(x, dx, y, dy, label): edgecolor="none", snap=True, ) + ax.plot( + xx[[jj, jj + 1]], + ys_d[[jj, jj + 1]], + color=cc[:, jj], + alpha=edge_alpha, + lw=edge_lw, + snap=True, + ) + ax.plot( + xx[[jj, jj + 1]], + ys_u[[jj, jj + 1]], + color=cc[:, jj], + alpha=edge_alpha, + lw=edge_lw, + snap=True, + ) ###########################################