Skip to content

Commit

Permalink
tighten up code and bug fix!
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Apr 16, 2024
1 parent 44746b9 commit 161b50a
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ def setup(self, data):
self.num_flow = self.num_stages - 1

# sizes
col_hgt = np.empty(self.num_stages)
self.node_sizes = {}
self.node_list = {}
self.nodes_uniq = {}
Expand All @@ -367,32 +366,38 @@ def setup(self, data):
self.node_pos_bot = {}
self.node_pos_top = {}

for ii in range(self.num_stages):
uni = pd.Series(self.data[2 * ii]).unique()
self.nodes_uniq[ii] = pd.Series(uni).dropna()

# weight and reclassify
self.weight_labels()
for ii in range(self.num_stages):
for nn, lbl in enumerate(self.data[2 * ii]):
val = self.node_sizes[ii][lbl]
if lbl is not None and (
val < self.other_thresh_val
or val < self.other_thresh_sum * self.weight_sum[ii]
or val < self.other_thresh_max * max(self.data[2 * ii + 1])
):
self.data.iat[nn, 2 * ii] = self.other_name
if lbl is not None:
val = self.node_sizes[ii][lbl]
if (
val < self.other_thresh_val
or val < self.other_thresh_sum * self.weight_sum[ii]
or val < self.other_thresh_max * max(self.data[2 * ii + 1])
):
self.data.iat[nn, 2 * ii] = self.other_name
self.weight_labels()

# sort and calc
self.plot_height_nom = max(self.weight_sum)
for ii in range(self.num_stages):
self.node_sizes[ii] = self.sort_node_sizes(self.node_sizes[ii], self.sort)
col_hgt[ii] = self.weight_sum[ii] + (len(self.nodes_uniq[ii]) - 1) * self.node_gap * self.plot_height_nom
self.node_list[ii] = self.sort_nodes(self.data[2 * ii], self.node_sizes[ii])

# offsets for alignment
vscale_dict = {"top": 1, "center": 0.5, "bottom": 0}
self.vscale = vscale_dict.get(self.valign, 0)
self.voffset = np.empty(self.num_stages)
col_hgt = np.empty(self.num_stages)
for ii in range(self.num_stages):
self.voffset[ii] = self.vscale * (col_hgt[1] - col_hgt[ii])
col_hgt[ii] = self.weight_sum[ii] + (len(self.nodes_uniq[ii]) - 1) * self.node_gap * self.plot_height_nom
self.voffset[ii] = self.vscale * (col_hgt[0] - col_hgt[ii])

# overall dimensions
self.plot_height = max(col_hgt)
Expand Down Expand Up @@ -482,9 +487,6 @@ def weight_labels(self):
"""Calculates sizes of each node, taking into account discontinuities"""
self.weight_sum = np.empty(self.num_stages)

for ii in range(self.num_stages):
self.nodes_uniq[ii] = pd.Series(self.data[2 * ii]).unique()

for ii in range(self.num_stages):
self.node_sizes[ii] = {}
for lbl in self.nodes_uniq[ii]:
Expand Down

0 comments on commit 161b50a

Please sign in to comment.