Skip to content

Commit

Permalink
add other_dict option
Browse files Browse the repository at this point in the history
I don't love this interface, it might change
  • Loading branch information
wspr committed Mar 30, 2024
1 parent 85f866b commit f8b7867
Showing 1 changed file with 82 additions and 39 deletions.
121 changes: 82 additions & 39 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def sankey(data, **kwargs):

# draw each segment
for ii in range(sky.num_flow):
sky.subplot(ii, data)
sky.subplot(ii)


class SankeyError(Exception):
Expand Down Expand Up @@ -124,6 +124,7 @@ class Sankey:
label_loc : [str1, str2, str3]
Position to place labels next to the nodes.
* `str1`: position of first labels (`"left"`, `"right"`, `"center"`, or `"none"`)
* `str2`: position of middle labels (`"left"`, `"right"`, `"both"`, `"center"`, or `"none"`)
* `str3`: position of last labels (`"left"`, `"right"`, `"center"`, or `"none"`)
Expand All @@ -133,6 +134,19 @@ class Sankey:
appear in the previous stage. This minimises chart clutter but might
be confusing in cases, hence defaulting to True.
other_dict : dict
Sets thresholds to recategorise nodes that are below a certain value.
Up to three dictionary keys can be set:
* `"val": v` — set node to other if it is less than `v`
* `"sum": s` — set node to other if it is less than `s` fraction of the summed total of all nodes in the current stage

Check failure on line 142 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (E501)

ausankey/ausankey.py:142:121: E501 Line too long (126 > 120)
* `"max": m` — set node to other if is is less than `m` fraction of the maximum node in the current stage
If any of these criteria are met the reclassification will occur.
other_name : str
The string used to rename nodes to if they are classified as “other”.
sort : int
Sorting routine to use for the data.
* `"top"`: data is sorted with largest entries on top
Expand Down Expand Up @@ -186,6 +200,8 @@ def __init__(
flow_lw=1,
node_lw=1,
frame_lw=1,
other_dict=None,
other_name="Other",
titles=None,
title_gap=0.05,
title_side="top", # "bottom", "both"
Expand Down Expand Up @@ -218,6 +234,8 @@ def __init__(
self.flow_lw = flow_lw
self.node_lw = node_lw
self.frame_lw = frame_lw
self.other_name = other_name
self.other_dict = other_dict or {}
self.titles = titles
self.title_font = title_font or {"fontweight": "bold"}
self.title_gap = title_gap
Expand All @@ -226,47 +244,71 @@ def __init__(
self.sort = sort
self.valign = valign

def setup(self, data):
num_col = len(data.columns)
data.columns = range(num_col) # force numeric column headings
self.num_stages = int(num_col / 2) # number of stages
self.num_flow = self.num_stages - 1

# sizes
weight_sum = np.empty(self.num_stages)
num_uniq = np.empty(self.num_stages)
col_hgt = np.empty(self.num_stages)
self.node_sizes = {}
nodes_uniq = {}
def weight_labels(self):

self.weight_sum = np.empty(self.num_stages)

for ii in range(self.num_stages):
nodes_uniq[ii] = pd.Series(data[2 * ii]).unique()
num_uniq[ii] = len(nodes_uniq[ii])
self.nodes_uniq[ii] = pd.Series(self.data[2 * ii]).unique()

for ii in range(self.num_stages):
self.node_sizes[ii] = {}
for lbl in nodes_uniq[ii]:
for lbl in self.nodes_uniq[ii]:
if ii == 0:
ind_prev = data[2 * ii - 0] == lbl
ind_this = data[2 * ii + 0] == lbl
ind_next = data[2 * ii + 2] == lbl
ind_prev = self.data[2 * ii - 0] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 2] == lbl
elif ii == self.num_flow:
ind_prev = data[2 * ii - 2] == lbl
ind_this = data[2 * ii + 0] == lbl
ind_next = data[2 * ii + 0] == lbl
ind_prev = self.data[2 * ii - 2] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 0] == lbl
else:
ind_prev = data[2 * ii - 2] == lbl
ind_this = data[2 * ii + 0] == lbl
ind_next = data[2 * ii + 2] == lbl
weight_cont = data[2 * ii + 1][ind_this & ind_prev & ind_next].sum()
weight_only = data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum()
weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum()
weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum()
ind_prev = self.data[2 * ii - 2] == lbl
ind_this = self.data[2 * ii + 0] == lbl
ind_next = self.data[2 * ii + 2] == lbl
weight_cont = self.data[2 * ii + 1][ind_this & ind_prev & ind_next].sum()
weight_only = self.data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum()
weight_stop = self.data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum()
weight_strt = self.data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum()
self.node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt)
self.node_sizes[ii] = sort_dict(self.node_sizes[ii], self.sort)
weight_sum[ii] = pd.Series(self.node_sizes[ii].values()).sum()

self.weight_sum[ii] = pd.Series(self.node_sizes[ii].values()).sum()


def setup(self, data):

self.data = data

num_col = len(self.data.columns)
self.data.columns = range(num_col) # force numeric column headings
self.num_stages = int(num_col / 2) # number of stages
self.num_flow = self.num_stages - 1

# sizes
col_hgt = np.empty(self.num_stages)
self.node_sizes = {}
self.nodes_uniq = {}

self.weight_labels()

# reclassify
thresh_val = self.other_dict.get("val",0)
thresh_max = self.other_dict.get("max",0)
thresh_sum = self.other_dict.get("sum",0)
for ii in range(self.num_stages):
col_hgt[ii] = weight_sum[ii] + (num_uniq[ii] - 1) * self.node_gap * max(weight_sum)
for nn, lbl in enumerate(self.data[2 * ii]):
val = self.node_sizes[ii][lbl]
if lbl is None:
continue
if val < thresh_val or val < thresh_sum * self.weight_sum[ii] or val < thresh_max * max(self.data[2 * ii + 1]):

Check failure on line 304 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (E501)

ausankey/ausankey.py:304:121: E501 Line too long (127 > 120)
self.data.iat[nn,2 * ii] = self.other_name
self.weight_labels()

# sort and calc
for ii in range(self.num_stages):
self.node_sizes[ii] = sort_dict(self.node_sizes[ii], self.sort)
col_hgt[ii] = self.weight_sum[ii] + (len(self.nodes_uniq[ii]) - 1) * self.node_gap * max(self.weight_sum)

# overall dimensions
self.plot_height = max(col_hgt)
Expand All @@ -285,7 +327,7 @@ def setup(self, data):
self.voffset[ii] = self.vscale * (col_hgt[1] - col_hgt[ii])

# labels
label_record = data[range(0, 2 * self.num_stages, 2)].to_records(index=False)
label_record = self.data[range(0, 2 * self.num_stages, 2)].to_records(index=False)
flattened = [item for sublist in label_record for item in sublist]
self.all_labels = pd.Series(flattened).unique()

Expand All @@ -302,6 +344,7 @@ def setup(self, data):
self.ax = self.ax or plt.gca()
self.ax.axis("off")


def plot_frame(self):
"""Plot frame on top/bottom edges"""

Expand All @@ -328,7 +371,7 @@ def plot_frame(self):
lw=self.frame_lw,
)

def subplot(self, ii, data):
def subplot(self, ii):
"""Subroutine for plotting horizontal sections of the Sankey plot
Some special-casing is used for plotting/labelling differently
Expand All @@ -340,14 +383,14 @@ def subplot(self, ii, data):

lastind = 4 if ii < self.num_flow - 1 else 2
labels_lr = [
data[labelind],
data[labelind + 2],
data[labelind + lastind],
self.data[labelind],
self.data[labelind + 2],
self.data[labelind + lastind],
]
weights_lr = [
data[weightind],
data[weightind + 2],
data[weightind + lastind],
self.data[weightind],
self.data[weightind + 2],
self.data[weightind + lastind],
]

nodes_lr = [
Expand Down

0 comments on commit f8b7867

Please sign in to comment.