diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 885f022..197ad2c 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -603,8 +603,8 @@ def subplot(self, ii): bot_lr = [lbot, rbot] len_lr = [llen, rlen] - ys_d = create_curve(lbot, rbot) - ys_u = create_curve(lbot + llen, rbot + rlen) + ys_d = self.create_curve(lbot, rbot) + ys_u = self.create_curve(lbot + llen, rbot + rlen) # Update bottom edges at each label # so next strip starts at the right place @@ -612,7 +612,7 @@ def subplot(self, ii): node_pos_bot[1][lbl_r] += rlen xx = np.linspace(x_lr[0], x_lr[1], len(ys_d)) - cc = combine_colours(self.color_dict[lbl_l], self.color_dict[lbl_r], len(ys_d)) + cc = self.combine_colours(self.color_dict[lbl_l], self.color_dict[lbl_r], len(ys_d)) for jj in range(len(ys_d) - 1): self.ax.fill_between( @@ -798,82 +798,66 @@ def sort_node_sizes(self, lbl, sorting): return sorted_labels -########################################### - - -def check_colors_match_labels(labels_lr, color_dict): - """Check that all labels in labels_lr are in color_dict""" - - all_labels = pd.Series([*labels_lr[0], *labels_lr[1]]).unique() - - missing = [label for label in all_labels if label not in color_dict] - - if missing: - msg = "The color_dict parameter is missing " "values for the following labels: " - msg += "{}".format(", ".join(missing)) - raise ValueError(msg) - - -########################################### - - -def create_curve(lpoint, rpoint): - """Create array of y values for each strip""" - - num_div = 20 - num_arr = 50 - - # half at left value, half at right, convolve - - ys = np.array(num_arr * [lpoint] + num_arr * [rpoint]) - - ys = np.convolve(ys, 1 / num_div * np.ones(num_div), mode="valid") - - return np.convolve(ys, 1 / num_div * np.ones(num_div), mode="valid") - - -########################################### - - -def combine_colours(c1, c2, num_col): - """Creates N colours needed to produce a gradient - - Parameters - ---------- - - c1 : col - First (left) colour. Can be a colour string `"#rrbbgg"` or a colour list `[r, b, g, a]` - - c1 : col - Second (right) colour. As above. - - num_col : int - The number of colours N to create in the array. - - Returns - ------- - - color_array : np.array - 4xN array of numerical colours - """ - color_array_len = 4 - # if not [r,g,b,a] assume a hex string like "#rrggbb": - - if len(c1) != color_array_len: - r1 = int(c1[1:3], 16) / 255 - g1 = int(c1[3:5], 16) / 255 - b1 = int(c1[5:7], 16) / 255 - c1 = [r1, g1, b1, 1] - - if len(c2) != color_array_len: - r2 = int(c2[1:3], 16) / 255 - g2 = int(c2[3:5], 16) / 255 - b2 = int(c2[5:7], 16) / 255 - c2 = [r2, g2, b2, 1] - - rr = np.linspace(c1[0], c2[0], num_col) - gg = np.linspace(c1[1], c2[1], num_col) - bb = np.linspace(c1[2], c2[2], num_col) - aa = np.linspace(c1[3], c2[3], num_col) - - return np.array([rr, gg, bb, aa]) + ########################################### + + + def create_curve(self, lpoint, rpoint): + """Create array of y values for each strip""" + + num_div = 20 + num_arr = 50 + + # half at left value, half at right, convolve + + ys = np.array(num_arr * [lpoint] + num_arr * [rpoint]) + + ys = np.convolve(ys, 1 / num_div * np.ones(num_div), mode="valid") + + return np.convolve(ys, 1 / num_div * np.ones(num_div), mode="valid") + + + ########################################### + + + def combine_colours(self, c1, c2, num_col): + """Creates N colours needed to produce a gradient + + Parameters + ---------- + + c1 : col + First (left) colour. Can be a colour string `"#rrbbgg"` or a colour list `[r, b, g, a]` + + c1 : col + Second (right) colour. As above. + + num_col : int + The number of colours N to create in the array. + + Returns + ------- + + color_array : np.array + 4xN array of numerical colours + """ + color_array_len = 4 + # if not [r,g,b,a] assume a hex string like "#rrggbb": + + if len(c1) != color_array_len: + r1 = int(c1[1:3], 16) / 255 + g1 = int(c1[3:5], 16) / 255 + b1 = int(c1[5:7], 16) / 255 + c1 = [r1, g1, b1, 1] + + if len(c2) != color_array_len: + r2 = int(c2[1:3], 16) / 255 + g2 = int(c2[3:5], 16) / 255 + b2 = int(c2[5:7], 16) / 255 + c2 = [r2, g2, b2, 1] + + rr = np.linspace(c1[0], c2[0], num_col) + gg = np.linspace(c1[1], c2[1], num_col) + bb = np.linspace(c1[2], c2[2], num_col) + aa = np.linspace(c1[3], c2[3], num_col) + + return np.array([rr, gg, bb, aa]) diff --git a/docs/reference.md b/docs/reference.md index dfcb99d..72cb155 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -1,20 +1,8 @@ -# Reference - -## The `Sankey` class - -::: ausankey.Sankey - -## The `sankey` user function +# The `sankey` user function ::: ausankey.sankey -## Internal/private functions +# The `Sankey` class -### create_curve - -::: ausankey.create_curve - -### combine_colours - -::: ausankey.combine_colours +::: ausankey.Sankey