From 3496b8450d0051e0adbc0efc428a9ef3fd9c214f Mon Sep 17 00:00:00 2001 From: Will Robertson Date: Fri, 29 Mar 2024 10:05:14 +1030 Subject: [PATCH] use self properly --- ausankey/ausankey.py | 267 ++++++++++++++++++++++--------------------- 1 file changed, 136 insertions(+), 131 deletions(-) diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 8f20bca..fb44803 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -26,12 +26,13 @@ def sankey(data, **kwargs): None (yet) """ - opt = Sankey(**kwargs) - opt = opt.setup(data, opt) + sky = Sankey(**kwargs) + sky.setup(data) + sky.plot_frame() # draw each segment - for ii in range(opt.num_flow): - opt.subplot(ii, data, opt) + for ii in range(sky.num_flow): + sky.subplot(ii, data) class SankeyError(Exception): @@ -223,30 +224,32 @@ def __init__( self.sort = sort self.valign = valign - def setup(self, data, opt): + + def setup(self, data): + num_col = len(data.columns) data.columns = range(num_col) # force numeric column headings num_side = int(num_col / 2) # number of stages - opt.num_flow = num_side - 1 + self.num_flow = num_side - 1 # sizes weight_sum = np.empty(num_side) num_uniq = np.empty(num_side) col_hgt = np.empty(num_side) - opt.node_sizes = {} + self.node_sizes = {} nodes_uniq = {} for ii in range(num_side): nodes_uniq[ii] = pd.Series(data[2 * ii]).unique() num_uniq[ii] = len(nodes_uniq[ii]) for ii in range(num_side): - opt.node_sizes[ii] = {} + self.node_sizes[ii] = {} for lbl in nodes_uniq[ii]: if ii == 0: ind_prev = data[2 * ii + 0] == lbl ind_this = data[2 * ii + 0] == lbl ind_next = data[2 * ii + 2] == lbl - elif ii == opt.num_flow: + elif ii == self.num_flow: ind_prev = data[2 * ii - 2] == lbl ind_this = data[2 * ii + 0] == lbl ind_next = data[2 * ii + 0] == lbl @@ -258,33 +261,33 @@ def setup(self, data, opt): weight_only = data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum() weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum() weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum() - opt.node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt) - opt.node_sizes[ii] = sort_dict(opt.node_sizes[ii], opt.sort) - weight_sum[ii] = pd.Series(opt.node_sizes[ii].values()).sum() + self.node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt) + self.node_sizes[ii] = sort_dict(self.node_sizes[ii], self.sort) + weight_sum[ii] = pd.Series(self.node_sizes[ii].values()).sum() for ii in range(num_side): - col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * opt.node_gap * max(weight_sum) + col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * self.node_gap * max(weight_sum) # overall dimensions - opt.plot_height = max(col_hgt) - opt.sub_width = opt.plot_height / opt.aspect - plot_width = ( - (num_side - 1) * opt.sub_width - + 2 * opt.sub_width * (opt.label_gap + opt.label_width) - + num_side * opt.sub_width * opt.node_width + self.plot_height = max(col_hgt) + self.sub_width = self.plot_height / self.aspect + self.plot_width = ( + (num_side - 1) * self.sub_width + + 2 * self.sub_width * (self.label_gap + self.label_width) + + num_side * self.sub_width * self.node_width ) # offsets for alignment - opt.voffset = np.empty(num_side) - if opt.valign == "top": + self.voffset = np.empty(num_side) + if self.valign == "top": vscale = 1 - elif opt.valign == "center": + elif self.valign == "center": vscale = 0.5 else: # bottom, or undefined vscale = 0 for ii in range(num_side): - opt.voffset[ii] = vscale * (col_hgt[1] - col_hgt[ii]) + self.voffset[ii] = vscale * (col_hgt[1] - col_hgt[ii]) # labels label_record = data[range(0, 2 * num_side, 2)].to_records(index=False) @@ -292,58 +295,60 @@ def setup(self, data, opt): flatcat = pd.Series(flattened).unique() # If no color_dict given, make one - color_dict_orig = opt.color_dict or {} + color_dict_orig = self.color_dict or {} color_dict_new = {} - cmap = plt.cm.get_cmap(opt.colormap) + cmap = plt.cm.get_cmap(self.colormap) color_palette = cmap(np.linspace(0, 1, len(flatcat))) for i, label in enumerate(flatcat): color_dict_new[label] = color_dict_orig.get(label, color_palette[i]) - opt.color_dict = color_dict_new + self.color_dict = color_dict_new # initialise plot - opt.ax = opt.ax or plt.gca() - opt.ax.axis("off") + self.ax = self.ax or plt.gca() + self.ax.axis("off") + + if self.label_loc is None: + self.label_loc = ["left", "none", "right"] + if self.node_edge is None: + self.node_edge = False + if self.flow_edge is None: + self.flow_edge = False + if self.title_font is None: + self.title_font = {"fontweight": "bold"} + if self.label_dict is None: + self.label_dict = {} + if self.label_font is None: + self.label_font = {} - # frame on top/bottom edge - frame_top = opt.frame_side in ("top", "both") - frame_bot = opt.frame_side in ("bottom", "both") - frame_color = opt.frame_color or [0, 0, 0, 1] + def plot_frame(self): + """Plot frame on top/bottom edges""" - y_frame_gap = opt.frame_gap * opt.plot_height + frame_top = self.frame_side in ("top", "both") + frame_bot = self.frame_side in ("bottom", "both") + + frame_color = self.frame_color or [0, 0, 0, 1] + + y_frame_gap = self.frame_gap * self.plot_height col = frame_color if frame_top else [1, 1, 1, 0] - opt.ax.plot( - [0, plot_width], - min(opt.voffset) + (opt.plot_height) + y_frame_gap + [0, 0], + self.ax.plot( + [0, self.plot_width], + min(self.voffset) + (self.plot_height) + y_frame_gap + [0, 0], color=col, - lw=opt.frame_lw, + lw=self.frame_lw, ) col = frame_color if frame_bot else [1, 1, 1, 0] - opt.ax.plot( - [0, plot_width], - min(opt.voffset) - y_frame_gap + [0, 0], + self.ax.plot( + [0, self.plot_width], + min(self.voffset) - y_frame_gap + [0, 0], color=col, - lw=opt.frame_lw, + lw=self.frame_lw, ) - if opt.label_loc is None: - opt.label_loc = ["left", "none", "right"] - if opt.node_edge is None: - opt.node_edge = False - if opt.flow_edge is None: - opt.flow_edge = False - if opt.title_font is None: - opt.title_font = {"fontweight": "bold"} - if opt.label_dict is None: - opt.label_dict = {} - if opt.label_font is None: - opt.label_font = {} - - return opt - - def subplot(self, ii, data, opt): + + def subplot(self, ii, data): """Subroutine for plotting horizontal sections of the Sankey plot Some special-casing is used for plotting/labelling differently @@ -353,7 +358,7 @@ def subplot(self, ii, data, opt): labelind = 2 * ii weightind = 2 * ii + 1 - if ii < opt.num_flow - 1: + if ii < self.num_flow - 1: labels_lr = [ pd.Series(data[labelind]), pd.Series(data[labelind + 2]), @@ -377,12 +382,12 @@ def subplot(self, ii, data, opt): ] nodes_lr = [ - sort_nodes(labels_lr[0], opt.node_sizes[ii]), - sort_nodes(labels_lr[1], opt.node_sizes[ii + 1]), + sort_nodes(labels_lr[0], self.node_sizes[ii]), + sort_nodes(labels_lr[1], self.node_sizes[ii + 1]), ] # check colours - check_colors_match_labels(labels_lr, opt.color_dict) + check_colors_match_labels(labels_lr, self.color_dict) # Determine sizes of individual subflows nodesize = [{}, {}] @@ -395,11 +400,11 @@ def subplot(self, ii, data, opt): nodesize[1][lbl_l][lbl_r] = weights_lr[1][ind].sum() # Determine vertical positions of nodes - y_node_gap = opt.node_gap * opt.plot_height + y_node_gap = self.node_gap * self.plot_height - if opt.valign == "top": + if self.valign == "top": vscale = 1 - elif opt.valign == "center": + elif self.valign == "center": vscale = 0.5 else: # bottom, or undefined vscale = 0 @@ -410,39 +415,39 @@ def subplot(self, ii, data, opt): for lr in [0, 1]: for i, label in enumerate(nodes_lr[lr]): - node_height = opt.node_sizes[ii + lr][label] + node_height = self.node_sizes[ii + lr][label] this_side_height = weights_lr[lr][labels_lr[lr] == label].sum() node_voffset[lr][label] = vscale * (node_height - this_side_height) next_bot = node_pos_top[lr][nodes_lr[lr][i - 1]] + y_node_gap if i > 0 else 0 - node_pos_bot[lr][label] = opt.voffset[ii + lr] if i == 0 else next_bot + node_pos_bot[lr][label] = self.voffset[ii + lr] if i == 0 else next_bot node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height # horizontal positions of nodes - x_node_width = opt.node_width * opt.sub_width - x_label_width = opt.label_width * opt.sub_width - x_label_gap = opt.label_gap * opt.sub_width - x_left = x_node_width + x_label_gap + x_label_width + ii * (opt.sub_width + x_node_width) - x_lr = [x_left, x_left + opt.sub_width] + x_node_width = self.node_width * self.sub_width + x_label_width = self.label_width * self.sub_width + x_label_gap = self.label_gap * self.sub_width + x_left = x_node_width + x_label_gap + x_label_width + ii * (self.sub_width + x_node_width) + x_lr = [x_left, x_left + self.sub_width] # Draw nodes def draw_node(x, dx, y, dy, label): - edge_lw = opt.node_lw if opt.node_edge else 0 - opt.ax.fill_between( + edge_lw = self.node_lw if self.node_edge else 0 + self.ax.fill_between( [x, x + dx], y, y + dy, - facecolor=opt.color_dict[label], - alpha=opt.node_alpha, + facecolor=self.color_dict[label], + alpha=self.node_alpha, lw=edge_lw, snap=True, ) - if opt.node_edge: - opt.ax.fill_between( + if self.node_edge: + self.ax.fill_between( [x, x + dx], y, y + dy, - edgecolor=opt.color_dict[label], + edgecolor=self.color_dict[label], facecolor="none", lw=edge_lw, snap=True, @@ -454,24 +459,24 @@ def draw_node(x, dx, y, dy, label): x_lr[lr] - x_node_width * (1 - lr), x_node_width, node_pos_bot[lr][label], - opt.node_sizes[ii + lr][label], + self.node_sizes[ii + lr][label], label, ) # Draw node labels def draw_label(x, y, label, ha): - opt.ax.text( + self.ax.text( x, y, - opt.label_dict.get(label, label), + self.label_dict.get(label, label), { "ha": ha, "va": "center", - "fontfamily": opt.fontfamily, - "fontsize": opt.fontsize, - "color": opt.fontcolor, - **opt.label_font, + "fontfamily": self.fontfamily, + "fontsize": self.fontsize, + "color": self.fontcolor, + **self.label_font, }, ) @@ -479,60 +484,60 @@ def draw_label(x, y, label, ha): # first row of labels lr = 0 - if ii == 0 and opt.label_loc[0] != "none": - if opt.label_loc[0] in ("left"): + if ii == 0 and self.label_loc[0] != "none": + if self.label_loc[0] in ("left"): xx = x_lr[lr] - x_label_gap - x_node_width - elif opt.label_loc[0] in ("right"): + elif self.label_loc[0] in ("right"): xx = x_lr[lr] + x_label_gap - elif opt.label_loc[0] in ("center"): + elif self.label_loc[0] in ("center"): xx = x_lr[lr] - x_node_width / 2 for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha_dict[opt.label_loc[0]]) + yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha_dict[self.label_loc[0]]) # inside labels, left lr = 1 - if ii < opt.num_flow - 1 and opt.label_loc[1] in ("left", "both"): + if ii < self.num_flow - 1 and self.label_loc[1] in ("left", "both"): xx = x_lr[lr] - x_label_gap ha = "right" for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 draw_label(xx, yy, label, ha) # inside labels, center - if ii < opt.num_flow - 1 and opt.label_loc[1] in ("center"): + if ii < self.num_flow - 1 and self.label_loc[1] in ("center"): xx = x_lr[lr] + x_node_width / 2 ha = "center" for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 draw_label(xx, yy, label, ha) # inside labels, right - if ii < opt.num_flow - 1 and opt.label_loc[1] in ("right", "both"): + if ii < self.num_flow - 1 and self.label_loc[1] in ("right", "both"): xx = x_lr[lr] + x_label_gap + x_node_width ha = "left" for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 + yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 draw_label(xx, yy, label, ha) # last row of labels - if ii == opt.num_flow - 1 and opt.label_loc[2] != "none": - if opt.label_loc[2] in ("left"): + if ii == self.num_flow - 1 and self.label_loc[2] != "none": + if self.label_loc[2] in ("left"): xx = x_lr[lr] - x_label_gap - elif opt.label_loc[2] in ("right"): + elif self.label_loc[2] in ("right"): xx = x_lr[lr] + x_label_gap + x_node_width - elif opt.label_loc[2] in ("center"): + elif self.label_loc[2] in ("center"): xx = x_lr[lr] + x_node_width / 2 for label in nodes_lr[lr]: - yy = node_pos_bot[lr][label] + opt.node_sizes[ii + lr][label] / 2 - draw_label(xx, yy, label, ha_dict[opt.label_loc[2]]) + yy = node_pos_bot[lr][label] + self.node_sizes[ii + lr][label] / 2 + draw_label(xx, yy, label, ha_dict[self.label_loc[2]]) # Plot flows - if opt.flow_edge: - edge_lw = opt.flow_lw + if self.flow_edge: + edge_lw = self.flow_lw edge_alpha = 1 else: - edge_alpha = opt.flow_alpha + edge_alpha = self.flow_alpha edge_lw = 0 for lbl_l in nodes_lr[0]: @@ -556,21 +561,21 @@ def draw_label(x, y, label, ha): node_pos_bot[1][lbl_r] += rlen xx = np.linspace(x_lr[0], x_lr[1], len(ys_d)) - cc = combine_colours(opt.color_dict[lbl_l], opt.color_dict[lbl_r], len(ys_d)) + cc = combine_colours(self.color_dict[lbl_l], self.color_dict[lbl_r], len(ys_d)) for jj in range(len(ys_d) - 1): - opt.ax.fill_between( + self.ax.fill_between( xx[[jj, jj + 1]], ys_d[[jj, jj + 1]], ys_u[[jj, jj + 1]], color=cc[:, jj], - alpha=opt.flow_alpha, + alpha=self.flow_alpha, lw=0, edgecolor="none", snap=True, ) # edges: - opt.ax.plot( + self.ax.plot( xx[[jj, jj + 1]], ys_d[[jj, jj + 1]], color=cc[:, jj], @@ -578,7 +583,7 @@ def draw_label(x, y, label, ha): lw=edge_lw, snap=True, ) - opt.ax.plot( + self.ax.plot( xx[[jj, jj + 1]], ys_u[[jj, jj + 1]], color=cc[:, jj], @@ -588,11 +593,11 @@ def draw_label(x, y, label, ha): ) # Place "titles" - if opt.titles is not None: + if self.titles is not None: last_label = [lbl_l, lbl_r] - y_title_gap = opt.title_gap * opt.plot_height - y_frame_gap = opt.frame_gap * opt.plot_height + y_title_gap = self.title_gap * self.plot_height + y_frame_gap = self.frame_gap * self.plot_height title_x = [ x_lr[0] - x_node_width / 2, @@ -600,17 +605,17 @@ def draw_label(x, y, label, ha): ] def draw_title(x, y, label, va): - opt.ax.text( + self.ax.text( x, y, label, { "ha": "center", "va": va, - "fontfamily": opt.fontfamily, - "fontsize": opt.fontsize, - "color": opt.fontcolor, - **opt.title_font, + "fontfamily": self.fontfamily, + "fontsize": self.fontsize, + "color": self.fontcolor, + **self.title_font, }, ) @@ -618,19 +623,19 @@ def draw_title(x, y, label, va): title_lr = [0, 1] if ii == 0 else [1] for lr in title_lr: - if opt.title_side in ("top", "both"): - if opt.title_loc == "outer": - yt = min(opt.voffset) + y_title_gap + y_frame_gap + opt.plot_height - elif opt.title_loc == "inner": + if self.title_side in ("top", "both"): + if self.title_loc == "outer": + yt = min(self.voffset) + y_title_gap + y_frame_gap + self.plot_height + elif self.title_loc == "inner": yt = y_title_gap + node_pos_top[lr][last_label[lr]] - draw_title(title_x[lr], yt, opt.titles[ii + lr], "bottom") - - if opt.title_side in ("bottom", "both"): - if opt.title_loc == "outer": - yt = min(opt.voffset) - y_title_gap - y_frame_gap - elif opt.title_loc == "inner": - yt = opt.voffset[ii + lr] - y_title_gap - draw_title(title_x[lr], yt, opt.titles[ii + lr], "top") + draw_title(title_x[lr], yt, self.titles[ii + lr], "bottom") + + if self.title_side in ("bottom", "both"): + if self.title_loc == "outer": + yt = min(self.voffset) - y_title_gap - y_frame_gap + elif self.title_loc == "inner": + yt = self.voffset[ii + lr] - y_title_gap + draw_title(title_x[lr], yt, self.titles[ii + lr], "top") ###########################################