diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py index 8134c8c..778e7de 100644 --- a/ausankey/ausankey.py +++ b/ausankey/ausankey.py @@ -285,6 +285,7 @@ def __init__( value_thresh_sum=0, value_thresh_max=0, ): + """Assigns all input arguments to the class as variables with appropriate defaults""" self.ax = ax self.node_width = node_width self.node_gap = node_gap @@ -330,6 +331,8 @@ def __init__( self.value_thresh_max = value_thresh_max def setup(self, data): + """Calculates all parameters needed to plot the graph""" + self.data = data num_col = len(self.data.columns) @@ -397,6 +400,7 @@ def setup(self, data): self.ax.axis("off") def weight_labels(self): + """Calculates sizes of each node, taking into account discontinuities""" self.weight_sum = np.empty(self.num_stages) for ii in range(self.num_stages): @@ -615,32 +619,11 @@ def subplot(self, ii): cc = self.combine_colours(self.color_dict[lbl_l], self.color_dict[lbl_r], len(ys_d)) for jj in range(len(ys_d) - 1): - self.ax.fill_between( + self.draw_flow( xx[[jj, jj + 1]], ys_d[[jj, jj + 1]], ys_u[[jj, jj + 1]], - color=cc[:, jj], - alpha=self.flow_alpha, - lw=0, - edgecolor="none", - snap=True, - ) - # edges: - self.ax.plot( - xx[[jj, jj + 1]], - ys_d[[jj, jj + 1]], - color=cc[:, jj], - alpha=edge_alpha, - lw=edge_lw, - snap=True, - ) - self.ax.plot( - xx[[jj, jj + 1]], - ys_u[[jj, jj + 1]], - color=cc[:, jj], - alpha=edge_alpha, - lw=edge_lw, - snap=True, + cc[:, jj], ) ha = ["left", "right"] @@ -700,6 +683,7 @@ def subplot(self, ii): self.draw_title(title_x[lr], yt, self.titles[ii + lr], "top") def draw_node(self, x, dx, y, dy, label): + """Draw a single node""" edge_lw = self.node_lw if self.node_edge else 0 self.ax.fill_between( [x, x + dx], @@ -721,7 +705,39 @@ def draw_node(self, x, dx, y, dy, label): snap=True, ) + def draw_flow(self, xx, yd, yu, col): + """Draw a single flow""" + self.ax.fill_between( + xx, + yd, + yu, + color=col, + alpha=self.flow_alpha, + lw=0, + edgecolor="none", + snap=True, + ) + # edges: + if self.flow_edge: + self.ax.plot( + xx, + yd, + color=col, + alpha=self.edge_alpha, + lw=self.flow_lw, + snap=True, + ) + self.ax.plot( + xx, + yu, + color=col, + alpha=self.edge_alpha, + lw=self.flow_lw, + snap=True, + ) + def draw_label(self, x, y, label, ha): + """Place a single label""" self.ax.text( x, y, @@ -737,6 +753,7 @@ def draw_label(self, x, y, label, ha): ) def draw_title(self, x, y, label, va): + """Place a single title""" self.ax.text( x, y, @@ -754,7 +771,7 @@ def draw_title(self, x, y, label, va): ########################################### def sort_nodes(self, lbl, node_sizes): - """creates a sorted list of labels by their summed weights""" + """Creates a sorted list of unique labels into a list""" arr = {} for uniq in lbl.unique(): @@ -771,7 +788,7 @@ def sort_nodes(self, lbl, node_sizes): ########################################### def sort_node_sizes(self, lbl, sorting): - """creates a sorted list of labels by their summed weights""" + """Sorts list of labels and their weights into a dictionary""" if sorting == "top": s = 1