From 0eb0bfaa3a68c9e9db0478a550063c856fa70886 Mon Sep 17 00:00:00 2001 From: Will Robertson Date: Wed, 20 Mar 2024 19:57:44 +1030 Subject: [PATCH] a baby step --- ausankey/ausankey.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 5d3fa3e..441f261 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -305,14 +305,28 @@ def _sankey( labelind = 2 * ii weightind = 2 * ii + 1 - labels_lr = [ - pd.Series(data[labelind]), - pd.Series(data[labelind + 2]), - ] - weights_lr = [ - pd.Series(data[weightind]), - pd.Series(data[weightind + 2]), - ] + if ii < num_flow-1: + labels_lr = [ + pd.Series(data[labelind]), + pd.Series(data[labelind + 2]), + pd.Series(data[labelind + 4]), + ] + weights_lr = [ + pd.Series(data[weightind]), + pd.Series(data[weightind + 2]), + pd.Series(data[weightind + 4]), + ] + else: + labels_lr = [ + pd.Series(data[labelind]), + pd.Series(data[labelind + 2]), + pd.Series(data[labelind + 2]), + ] + weights_lr = [ + pd.Series(data[weightind]), + pd.Series(data[weightind + 2]), + pd.Series(data[weightind + 2]), + ] notnull = labels_lr[0].notnull() & labels_lr[1].notnull() labels_lr[0] = labels_lr[0][notnull] @@ -348,7 +362,7 @@ def _sankey( msg += "{}".format(", ".join(missing)) raise ValueError(msg) - # Determine sizes of individual strips + # Determine sizes of individual subflows barsize = [{}, {}] for lbl_l in bar_lr[0]: barsize[0][lbl_l] = {} @@ -358,7 +372,7 @@ def _sankey( barsize[0][lbl_l][lbl_r] = weights_lr[0][ind].sum() barsize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum() - # Determine positions of label patches and total widths + # Determine vertical positions of nodes y_bar_gap = bar_gap * plot_height barpos = [{}, {}] @@ -369,7 +383,7 @@ def _sankey( barpos[lr][label]["bot"] = voffset[ii + lr] if i == 0 else barpos[lr][bar_lr[lr][i - 1]]["top"] + y_bar_gap barpos[lr][label]["top"] = barpos[lr][label]["bot"] + barpos[lr][label]["tot"] - # horizontal extents of flows in each subdiagram + # horizontal positions of nodes x_bar_width = bar_width * sub_width x_label_width = label_width * sub_width x_label_gap = label_gap * sub_width