Skip to content

Commit

Permalink
these loops were annoying me for ages
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Apr 12, 2024
1 parent 690f8e9 commit ebc497d
Showing 1 changed file with 58 additions and 58 deletions.
116 changes: 58 additions & 58 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,18 +411,23 @@ def setup(self, data):
self.x_lr = {}
self.nodesize_l = {}
self.nodesize_r = {}
self.node_pairs = {}
for ii in range(self.num_flow):
x_left = (
self.x_node_width + self.x_label_gap + self.x_label_width + ii * (self.sub_width + self.x_node_width)
)
self.x_lr[ii] = (x_left, x_left + self.sub_width)
self.nodesize_l[ii] = {}
self.nodesize_r[ii] = {}
self.node_pairs[ii] = []
for lbl_l in self.node_list[ii]:
self.nodesize_l[ii][lbl_l] = {}
self.nodesize_r[ii][lbl_l] = {}
for lbl_r in self.node_list[ii + 1]:
ind = (self.data[2 * ii] == lbl_l) & (self.data[2 * ii + 2] == lbl_r)
if not any(ind):
continue
self.node_pairs[ii].append((lbl_l,lbl_r))
self.nodesize_l[ii][lbl_l][lbl_r] = self.data[2 * ii + 1][ind].sum()
self.nodesize_r[ii][lbl_l][lbl_r] = self.data[2 * ii + 3][ind].sum()

Expand Down Expand Up @@ -605,68 +610,63 @@ def subplot(self, ii):

# Plot flows

for lbl_l in self.node_list[ii]:
for lbl_r in self.node_list[ii + 1]:
lind = self.data[2 * ii] == lbl_l
rind = self.data[2 * ii + 2] == lbl_r
if not any(lind & rind):
continue

lbot = self.node_pos_voffset[ii][0][lbl_l] + self.node_pos_bot[ii][0][lbl_l]
rbot = self.node_pos_voffset[ii][1][lbl_r] + self.node_pos_bot[ii][1][lbl_r]
llen = self.nodesize_l[ii][lbl_l][lbl_r]
rlen = self.nodesize_r[ii][lbl_l][lbl_r]
bot_lr = [lbot, rbot]
len_lr = [llen, rlen]

ys_d = self.create_curve(lbot, rbot)
ys_u = self.create_curve(lbot + llen, rbot + rlen)

# Update bottom edges at each label
# so next strip starts at the right place
self.node_pos_bot[ii][0][lbl_l] += llen
self.node_pos_bot[ii][1][lbl_r] += rlen

xx = np.linspace(x_lr[0], x_lr[1], len(ys_d))
cc = self.combine_colours(self.color_dict[lbl_l], self.color_dict[lbl_r], len(ys_d))

for jj in range(len(ys_d) - 1):
self.draw_flow(
xx[[jj, jj + 1]],
ys_d[[jj, jj + 1]],
ys_u[[jj, jj + 1]],
cc[:, jj],
)
for lbl_l, lbl_r in self.node_pairs[ii]:

ha = ["left", "right"]
sides = []
if ii == 0:
ind = 0
elif ii == self.num_flow - 1:
ind = 2
else:
ind = 1
if self.value_loc[ind] in ("left", "both"):
sides.append(0)
if self.value_loc[ind] in ("right", "both"):
sides.append(1)
for lr in sides:
val = len_lr[lr]
if not (
val < self.value_thresh_val
or val < self.value_thresh_sum * self.weight_sum[ii + lr]
or val < self.value_thresh_max * max(self.data[2 * ii + 1])
):
self.draw_value(
x_lr[lr] + (1 - 2 * lr) * self.x_value_gap,
bot_lr[lr] + len_lr[lr] / 2,
val,
ha[lr],
)
lbot = self.node_pos_voffset[ii][0][lbl_l] + self.node_pos_bot[ii][0][lbl_l]
rbot = self.node_pos_voffset[ii][1][lbl_r] + self.node_pos_bot[ii][1][lbl_r]
llen = self.nodesize_l[ii][lbl_l][lbl_r]
rlen = self.nodesize_r[ii][lbl_l][lbl_r]
bot_lr = [lbot, rbot]
len_lr = [llen, rlen]

ys_d = self.create_curve(lbot, rbot)
ys_u = self.create_curve(lbot + llen, rbot + rlen)

# Update bottom edges at each label
# so next strip starts at the right place
self.node_pos_bot[ii][0][lbl_l] += llen
self.node_pos_bot[ii][1][lbl_r] += rlen

xx = np.linspace(x_lr[0], x_lr[1], len(ys_d))
cc = self.combine_colours(self.color_dict[lbl_l], self.color_dict[lbl_r], len(ys_d))

for jj in range(len(ys_d) - 1):
self.draw_flow(
xx[[jj, jj + 1]],
ys_d[[jj, jj + 1]],
ys_u[[jj, jj + 1]],
cc[:, jj],
)

ha = ["left", "right"]
sides = []
if ii == 0:
ind = 0
elif ii == self.num_flow - 1:
ind = 2
else:
ind = 1
if self.value_loc[ind] in ("left", "both"):
sides.append(0)
if self.value_loc[ind] in ("right", "both"):
sides.append(1)
for lr in sides:
val = len_lr[lr]
if not (
val < self.value_thresh_val
or val < self.value_thresh_sum * self.weight_sum[ii + lr]
or val < self.value_thresh_max * max(self.data[2 * ii + 1])
):
self.draw_value(
x_lr[lr] + (1 - 2 * lr) * self.x_value_gap,
bot_lr[lr] + len_lr[lr] / 2,
val,
ha[lr],
)

# Place "titles"
if self.titles is not None:
last_label = [lbl_l, lbl_r]
last_label = self.node_pairs[ii][-1]
title_x = [x_lr[0] - self.x_node_width / 2, x_lr[1] + self.x_node_width / 2]

for lr in [0, 1] if ii == 0 else [1]:
Expand Down

0 comments on commit ebc497d

Please sign in to comment.