diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 301df56..be180f4 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -47,9 +47,6 @@ class Sankey: data : DataFrame pandas dataframe of labels and weights in alternating columns - aspect : float - vertical extent of the diagram in units of horizontal extent - ax : Axis Matplotlib plot axis to use @@ -160,7 +157,6 @@ class Sankey: def __init__( self, - aspect=4, ax=None, node_width=0.02, node_gap=0.05, @@ -192,7 +188,6 @@ def __init__( sort="bottom", # "top", "bottom", "none" valign="bottom", # "top","center" ): - self.aspect = aspect self.ax = ax self.node_width = node_width self.node_gap = node_gap @@ -227,20 +222,20 @@ def __init__( def setup(self, data): num_col = len(data.columns) data.columns = range(num_col) # force numeric column headings - num_side = int(num_col / 2) # number of stages - self.num_flow = num_side - 1 + self.num_stages = int(num_col / 2) # number of stages + self.num_flow = self.num_stages - 1 # sizes - weight_sum = np.empty(num_side) - num_uniq = np.empty(num_side) - col_hgt = np.empty(num_side) + weight_sum = np.empty(self.num_stages) + num_uniq = np.empty(self.num_stages) + col_hgt = np.empty(self.num_stages) self.node_sizes = {} nodes_uniq = {} - for ii in range(num_side): + for ii in range(self.num_stages): nodes_uniq[ii] = pd.Series(data[2 * ii]).unique() num_uniq[ii] = len(nodes_uniq[ii]) - for ii in range(num_side): + for ii in range(self.num_stages): self.node_sizes[ii] = {} for lbl in nodes_uniq[ii]: if ii == 0: @@ -263,20 +258,20 @@ def setup(self, data): self.node_sizes[ii] = sort_dict(self.node_sizes[ii], self.sort) weight_sum[ii] = pd.Series(self.node_sizes[ii].values()).sum() - for ii in range(num_side): + for ii in range(self.num_stages): col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * self.node_gap * max(weight_sum) # overall dimensions self.plot_height = max(col_hgt) - self.sub_width = self.plot_height / self.aspect + self.sub_width = self.plot_height self.plot_width = ( - (num_side - 1) * self.sub_width + (self.num_stages - 1) * self.sub_width + 2 * self.sub_width * (self.label_gap + self.label_width) - + num_side * self.sub_width * self.node_width + + self.num_stages * self.sub_width * self.node_width ) # offsets for alignment - self.voffset = np.empty(num_side) + self.voffset = np.empty(self.num_stages) if self.valign == "top": vscale = 1 elif self.valign == "center": @@ -284,11 +279,11 @@ def setup(self, data): else: # bottom, or undefined vscale = 0 - for ii in range(num_side): + for ii in range(self.num_stages): self.voffset[ii] = vscale * (col_hgt[1] - col_hgt[ii]) # labels - label_record = data[range(0, 2 * num_side, 2)].to_records(index=False) + label_record = data[range(0, 2 * self.num_stages, 2)].to_records(index=False) flattened = [item for sublist in label_record for item in sublist] flatcat = pd.Series(flattened).unique() @@ -299,6 +294,7 @@ def setup(self, data): color_palette = cmap(np.linspace(0, 1, len(flatcat))) for i, label in enumerate(flatcat): color_dict_new[label] = color_dict_orig.get(label, color_palette[i]) + check_colors_match_labels(flatcat, color_dict_new) self.color_dict = color_dict_new # initialise plot @@ -380,9 +376,6 @@ def subplot(self, ii, data): sort_nodes(labels_lr[1], self.node_sizes[ii + 1]), ] - # check colours - check_colors_match_labels(labels_lr, self.color_dict) - # Determine sizes of individual subflows nodesize = [{}, {}] for lbl_l in nodes_lr[0]: