diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index ee35fd4..e7d2175 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -314,13 +314,13 @@ def _sankey( # Determine sizes of individual strips barsize = [{}, {}] - for left_label in bar_lr[0]: - barsize[0][left_label] = {} - barsize[1][left_label] = {} - for right_label in bar_lr[1]: - ind = (labels_lr[0] == left_label) & (labels_lr[1] == right_label) - barsize[0][left_label][right_label] = weights_lr[0][ind].sum() - barsize[1][left_label][right_label] = weights_lr[1][ind].sum() + for lbl_l in bar_lr[0]: + barsize[0][lbl_l] = {} + barsize[1][lbl_l] = {} + for lbl_r in bar_lr[1]: + ind = (labels_lr[0] == lbl_l) & (labels_lr[1] == lbl_r) + barsize[0][lbl_l][lbl_r] = weights_lr[0][ind].sum() + barsize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum() # Determine positions of left label patches and total widths y_bar_gap = bar_gap * plot_height @@ -398,7 +398,7 @@ def _sankey( if ii == 0: xt = x_left - x_bar_width / 2 if title_side in ("top", "both"): - yt = y_title_gap + barpos[0][left_label]["top"] + yt = y_title_gap + barpos[0][lbl_l]["top"] va = "bottom" ax.text( xt, @@ -422,7 +422,7 @@ def _sankey( # all other titles xt = x_right + x_bar_width / 2 if title_side in ("top", "both"): - yt = y_title_gap + barpos[1][right_label]["top"] + yt = y_title_gap + barpos[1][lbl_r]["top"] ax.text( xt, yt, @@ -442,29 +442,29 @@ def _sankey( ) # Plot strips - for left_label in bar_lr[0]: - for right_label in bar_lr[1]: - lind = labels_lr[0] == left_label - rind = labels_lr[1] == right_label + for lbl_l in bar_lr[0]: + for lbl_r in bar_lr[1]: + lind = labels_lr[0] == lbl_l + rind = labels_lr[1] == lbl_r if not any(lind & rind): continue - lbot = barpos[0][left_label]["bot"] - rbot = barpos[1][right_label]["bot"] - lbar = barsize[0][left_label][right_label] - rbar = barsize[1][left_label][right_label] + lbot = barpos[0][lbl_l]["bot"] + rbot = barpos[1][lbl_r]["bot"] + lbar = barsize[0][lbl_l][lbl_r] + rbar = barsize[1][lbl_l][lbl_r] ys_d = create_curve(lbot, rbot) ys_u = create_curve(lbot + lbar, rbot + rbar) # Update bottom edges at each label # so next strip starts at the right place - barpos[0][left_label]["bot"] += lbar - barpos[1][right_label]["bot"] += rbar + barpos[0][lbl_l]["bot"] += lbar + barpos[1][lbl_r]["bot"] += rbar xx = np.linspace(x_left, x_right, len(ys_d)) - cc = combine_colours(color_dict[left_label], color_dict[right_label], len(ys_d)) + cc = combine_colours(color_dict[lbl_l], color_dict[lbl_r], len(ys_d)) for jj in range(len(ys_d) - 1): ax.fill_between(