Skip to content

Commit

Permalink
precalculate node sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Mar 22, 2024
1 parent 7f64f9b commit 92efe90
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,28 @@ def sankey(
nodes_uniq[ii] = pd.Series(data[2 * ii]).unique()
num_uniq[ii] = len(nodes_uniq[ii])

# for ii in range(num_side):
# node_sizes[ii] = {}
# for lbl in nodes_uniq[ii]:
# if ii == 0:
# ind_prev = data[2 * ii + 0] == lbl
# ind_this = data[2 * ii + 0] == lbl
# ind_next = data[2 * ii + 2] == lbl
# weight_cont = data[2 * ii + 1][ind_this & ind_prev & ind_next].sum()
# weight_only = data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum()
# weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum()
# weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum()
# node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt)

# print(node_sizes)
for ii in range(num_side):
node_sizes[ii] = {}
for lbl in nodes_uniq[ii]:
if ii == 0:
ind_prev = data[2 * ii + 0] == lbl
ind_this = data[2 * ii + 0] == lbl
ind_next = data[2 * ii + 2] == lbl
elif ii == num_flow:
ind_prev = data[2 * ii - 2] == lbl
ind_this = data[2 * ii + 0] == lbl
ind_next = data[2 * ii + 0] == lbl
else:
ind_prev = data[2 * ii - 2] == lbl
ind_this = data[2 * ii + 0] == lbl
ind_next = data[2 * ii + 2] == lbl
weight_cont = data[2 * ii + 1][ind_this & ind_prev & ind_next].sum()
weight_only = data[2 * ii + 1][ind_this & ~ind_prev & ~ind_next].sum()
weight_stop = data[2 * ii + 1][ind_this & ind_prev & ~ind_next].sum()
weight_strt = data[2 * ii + 1][ind_this & ~ind_prev & ind_next].sum()
node_sizes[ii][lbl] = weight_cont + weight_only + max(weight_stop, weight_strt)
node_sizes[ii] = sort_dict(node_sizes[ii], sort)

Check failure on line 186 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F821)

ausankey/ausankey.py:186:26: F821 Undefined name `sort_dict`

for ii in range(num_side):
if ii == 0:
ind_prev = data[2 * ii + 1].notnull()
Expand Down Expand Up @@ -265,6 +273,7 @@ def sankey(
ii,
num_flow,
data,
node_sizes=node_sizes,
titles=titles,
title_gap=title_gap,
title_side=title_side,
Expand Down Expand Up @@ -303,6 +312,7 @@ def _sankey(
flow_edge=None,
fontsize=None,
frame_gap=None,
node_sizes=None,

Check failure on line 315 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (ARG001)

ausankey/ausankey.py:315:5: ARG001 Unused function argument: `node_sizes`
titles=None,
title_gap=None,
title_side=None,
Expand Down Expand Up @@ -595,12 +605,10 @@ def weighted_sort(lbl, wgt, sorting):
else:
s = 0

arr = {}
for uniq in lbl.unique():
arr[uniq] = wgt[lbl == uniq].sum()

Check failure on line 609 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F821)

ausankey/ausankey.py:609:9: F821 Undefined name `arr`

sort_arr = sorted(
arr.items(),
key=lambda item: s * item[1],
# sorting = 0,1,-1 affects this
)
Expand Down

0 comments on commit 92efe90

Please sign in to comment.