diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index d7f3b6b..92940ea 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -266,8 +266,8 @@ def _sankey( Some special-casing is used for plotting/labelling differently for the first and last cases. """ - labelind = 2 * ii - weightind = 2 * ii + 1 + labelind = 2*ii + weightind = 2*ii + 1 left = pd.Series(data[labelind]) right = pd.Series(data[labelind + 2]) @@ -298,7 +298,6 @@ def _sankey( # check colours all_labels = pd.Series([*left, *right]).unique() - missing = [label for label in all_labels if label not in color_dict] if missing: msg = "The color_dict parameter is missing " "values for the following labels: " @@ -306,39 +305,39 @@ def _sankey( raise ValueError(msg) # Determine sizes of individual strips - barsize_left = {} - barsize_right = {} + barsize = [{}, {}] for left_label in left_labels: - barsize_left[left_label] = {} - barsize_right[left_label] = {} + barsize[0][left_label] = {} + barsize[1][left_label] = {} for right_label in right_labels: ind = (left == left_label) & (right == right_label) - barsize_left[left_label][right_label] = left_weight[ind].sum() - barsize_right[left_label][right_label] = right_weight[ind].sum() + barsize[0][left_label][right_label] = left_weight[ind].sum() + barsize[1][left_label][right_label] = right_weight[ind].sum() # Determine positions of left label patches and total widths - left_widths = {} - for i, left_label in enumerate(left_labels): - tmp_dict = {} - tmp_dict["left"] = left_weight[left == left_label].sum() + barpos = [{}, {}] + for i, label in enumerate(left_labels): + barpos[0][label] = {} + barpos[0][label]["total"] = left_weight[left == label].sum() if i == 0: - tmp_dict["bottom"] = voffset[ii] + bot = voffset[ii] else: - tmp_dict["bottom"] = left_widths[left_labels[i - 1]]["top"] + bar_gap * plot_height - tmp_dict["top"] = tmp_dict["bottom"] + tmp_dict["left"] - left_widths[left_label] = tmp_dict + bot = barpos[0][left_labels[i-1]]["top"] + bar_gap*plot_height + + barpos[0][label]["bottom"] = bot + barpos[0][label]["top"] = barpos[0][label]["bottom"] + barpos[0][label]["total"] # Determine positions of right label patches and total widths - right_widths = {} - for i, right_label in enumerate(right_labels): - tmp_dict = {} - tmp_dict["right"] = right_weight[right == right_label].sum() + for i, label in enumerate(right_labels): + barpos[1][label] = {} + barpos[1][label]["total"] = right_weight[right == label].sum() if i == 0: - tmp_dict["bottom"] = voffset[ii + 1] + bot = voffset[ii+1] else: - tmp_dict["bottom"] = right_widths[right_labels[i - 1]]["top"] + bar_gap * plot_height - tmp_dict["top"] = tmp_dict["bottom"] + tmp_dict["right"] - right_widths[right_label] = tmp_dict + bot = barpos[1][right_labels[i-1]]["top"] + bar_gap * plot_height + + barpos[1][label]["bottom"] = bot + barpos[1][label]["top"] = barpos[1][label]["bottom"] + barpos[1][label]["total"] # horizontal extents of flows in each subdiagram x_bar_width = bar_width * sub_width @@ -350,12 +349,11 @@ def _sankey( # Draw bars and their labels if ii == 0: # first time for left_label in left_labels: - lbot = left_widths[left_label]["bottom"] - lll = left_widths[left_label]["left"] + lbot = barpos[0][left_label]["bottom"] + lll = barpos[0][left_label]["total"] ax.fill_between( [x_left - x_bar_width, x_left], - 2 * [lbot], - 2 * [lbot + lll], + lbot, lbot + lll, color=color_dict[left_label], alpha=1, lw=0, @@ -363,18 +361,17 @@ def _sankey( ) ax.text( x_left - x_label_gap - x_bar_width, - lbot + 0.5 * lll, + lbot + lll/2, label_dict.get(left_label, left_label), {"ha": "right", "va": "center"}, fontsize=fontsize, ) for right_label in right_labels: - rbot = right_widths[right_label]["bottom"] - rrr = right_widths[right_label]["right"] + rbot = barpos[1][right_label]["bottom"] + rrr = barpos[1][right_label]["total"] ax.fill_between( [x_right, x_right + x_bar_width], - 2 * [rbot], - [rbot + rrr], + rbot, rbot + rrr, color=color_dict[right_label], alpha=1, lw=0, @@ -383,7 +380,7 @@ def _sankey( if ii < num_flow - 1: # inside labels ax.text( x_right + x_label_gap + x_bar_width, - rbot + 0.5 * rrr, + rbot + rrr/2, label_dict.get(right_label, right_label), {"ha": "left", "va": "center"}, fontsize=fontsize, @@ -391,7 +388,7 @@ def _sankey( if ii == num_flow - 1: # last time ax.text( x_right + x_label_gap + x_bar_width, - rbot + 0.5 * rrr, + rbot + rrr/2, label_dict.get(right_label, right_label), {"ha": "left", "va": "center"}, fontsize=fontsize, @@ -399,52 +396,44 @@ def _sankey( # "titles" if titles is not None: + + y_title_gap = title_gap*plot_height + # leftmost title if ii == 0: - xt = x_left - x_bar_width / 2 + xt = x_left - x_bar_width/2 if title_side in ("top", "both"): - yt = title_gap * plot_height + left_widths[left_label]["top"] + yt = y_title_gap + barpos[0][left_label]["top"] va = "bottom" ax.text( - xt, - yt, - titles[ii], + xt, yt, titles[ii], {"ha": "center", "va": va}, fontsize=fontsize, ) if title_side in ("bottom", "both"): - yt = voffset[ii] - title_gap * plot_height + yt = voffset[ii] - y_title_gap va = "top" - ax.text( - xt, - yt, - titles[ii], + xt, yt, titles[ii], {"ha": "center", "va": va}, fontsize=fontsize, ) # all other titles - xt = x_right + x_bar_width / 2 + xt = x_right + x_bar_width/2 if title_side in ("top", "both"): - yt = title_gap * plot_height + right_widths[right_label]["top"] - + yt = y_title_gap + barpos[1][right_label]["top"] ax.text( - xt, - yt, - titles[ii + 1], + xt, yt, titles[ii+1], {"ha": "center", "va": "bottom"}, fontsize=fontsize, ) if title_side in ("bottom", "both"): - yt = voffset[ii + 1] - title_gap * plot_height - + yt = voffset[ii+1] - y_title_gap ax.text( - xt, - yt, - titles[ii + 1], + xt, yt, titles[ii+1], {"ha": "center", "va": "top"}, fontsize=fontsize, ) @@ -458,18 +447,18 @@ def _sankey( if not any(lind & rind): continue - lbot = left_widths[left_label]["bottom"] - rbot = right_widths[right_label]["bottom"] - lbar = barsize_left[left_label][right_label] - rbar = barsize_right[left_label][right_label] + lbot = barpos[0][left_label]["bottom"] + rbot = barpos[1][right_label]["bottom"] + lbar = barsize[0][left_label][right_label] + rbar = barsize[1][left_label][right_label] ys_d = create_curve(lbot, rbot) ys_u = create_curve(lbot + lbar, rbot + rbar) # Update bottom edges at each label # so next strip starts at the right place - left_widths[left_label]["bottom"] += lbar - right_widths[right_label]["bottom"] += rbar + barpos[0][left_label]["bottom"] += lbar + barpos[1][right_label]["bottom"] += rbar xx = np.linspace(x_left, x_right, len(ys_d)) cc = combine_colours(color_dict[left_label], color_dict[right_label], len(ys_d)) @@ -517,21 +506,19 @@ def check_data_matches_labels(labels, data, side): that the order of labels is still consistent with the labels in the data. """ - if len(labels) > 0: - if isinstance(data, list): - data = set(data) - if isinstance(data, pd.Series): - data = set(data.unique().tolist()) - if isinstance(labels, list): - labels = set(labels) - if labels != data: - msg = "\n" - maxlen = 20 - if len(labels) <= maxlen: - msg = "Labels: " + ",".join(labels) + "\n" - if len(data) < maxlen: - msg += "Data: " + ",".join(data) - raise LabelMismatchError(side, msg) + + if len(labels) == 0: + msg = "Length of labels equals zero?" + raise LabelMismatchError(side, msg) + + if set(labels) != set(data): + msg = "\n" + maxlen = 20 + if len(labels) <= maxlen: + msg += "Labels: " + ",".join(labels) + "\n" + if len(data) < maxlen: + msg += "Data: " + ",".join(data) + raise LabelMismatchError(side, msg) ###########################################