Skip to content

Commit

Permalink
bunch of little tidy ups
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Mar 29, 2024
1 parent 846ab4b commit 565a3b6
Showing 1 changed file with 23 additions and 60 deletions.
83 changes: 23 additions & 60 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,18 @@ def __init__(
self.node_width = node_width
self.node_gap = node_gap
self.node_alpha = node_alpha
self.node_edge = node_edge
self.color_dict = color_dict
self.node_edge = node_edge or False
self.color_dict = color_dict or {}
self.colormap = colormap
self.flow_edge = flow_edge
self.flow_edge = flow_edge or False
self.flow_alpha = flow_alpha
self.fontcolor = fontcolor
self.fontsize = fontsize
self.fontfamily = fontfamily
self.frame_side = frame_side
self.frame_gap = frame_gap
self.frame_color = frame_color
self.label_dict = label_dict
self.label_dict = label_dict or {}
self.label_width = label_width
self.label_gap = label_gap
self.label_loc = label_loc
Expand All @@ -212,13 +212,14 @@ def __init__(
self.node_lw = node_lw
self.frame_lw = frame_lw
self.titles = titles
self.title_font = title_font
self.title_font = title_font or {"fontweight": "bold"}
self.title_gap = title_gap
self.title_loc = title_loc
self.title_side = title_side
self.sort = sort
self.valign = valign


def setup(self, data):
num_col = len(data.columns)
data.columns = range(num_col) # force numeric column headings
Expand All @@ -239,7 +240,7 @@ def setup(self, data):
self.node_sizes[ii] = {}
for lbl in nodes_uniq[ii]:
if ii == 0:
ind_prev = data[2 * ii + 0] == lbl
ind_prev = data[2 * ii - 0] == lbl
ind_this = data[2 * ii + 0] == lbl
ind_next = data[2 * ii + 2] == lbl
elif ii == self.num_flow:
Expand Down Expand Up @@ -271,46 +272,30 @@ def setup(self, data):
)

# offsets for alignment
vscale_dict = {"top": 1, "center": 0.5, "bottom": 0}
self.vscale = vscale_dict.get(self.valign,0)
self.voffset = np.empty(self.num_stages)
if self.valign == "top":
vscale = 1
elif self.valign == "center":
vscale = 0.5
else: # bottom, or undefined
vscale = 0

for ii in range(self.num_stages):
self.voffset[ii] = vscale * (col_hgt[1] - col_hgt[ii])
self.voffset[ii] = self.vscale * (col_hgt[1] - col_hgt[ii])

# labels
label_record = data[range(0, 2 * self.num_stages, 2)].to_records(index=False)
flattened = [item for sublist in label_record for item in sublist]
flatcat = pd.Series(flattened).unique()
self.all_labels = pd.Series(flattened).unique()

# If no color_dict given, make one
color_dict_orig = self.color_dict or {}
color_dict_new = {}
cmap = plt.cm.get_cmap(self.colormap)
color_palette = cmap(np.linspace(0, 1, len(flatcat)))
for i, label in enumerate(flatcat):
color_dict_new[label] = color_dict_orig.get(label, color_palette[i])
check_colors_match_labels(flatcat, color_dict_new)
color_palette = cmap(np.linspace(0, 1, len(self.all_labels)))
for i, label in enumerate(self.all_labels):
color_dict_new[label] = self.color_dict.get(label, color_palette[i])
check_colors_match_labels(self.all_labels, color_dict_new)
self.color_dict = color_dict_new

# initialise plot
self.ax = self.ax or plt.gca()
self.ax.axis("off")

if self.node_edge is None:
self.node_edge = False
if self.flow_edge is None:
self.flow_edge = False
if self.title_font is None:
self.title_font = {"fontweight": "bold"}
if self.label_dict is None:
self.label_dict = {}
if self.label_font is None:
self.label_font = {}

def plot_frame(self):
"""Plot frame on top/bottom edges"""
Expand Down Expand Up @@ -348,28 +333,13 @@ def subplot(self, ii, data):
labelind = 2 * ii
weightind = 2 * ii + 1

if ii < self.num_flow - 1:
labels_lr = [
pd.Series(data[labelind]),
pd.Series(data[labelind + 2]),
pd.Series(data[labelind + 4]),
]
weights_lr = [
pd.Series(data[weightind]),
pd.Series(data[weightind + 2]),
pd.Series(data[weightind + 4]),
]
else:
labels_lr = [
pd.Series(data[labelind]),
pd.Series(data[labelind + 2]),
pd.Series(data[labelind + 2]),
]
weights_lr = [
pd.Series(data[weightind]),
pd.Series(data[weightind + 2]),
pd.Series(data[weightind + 2]),
]
lastind = 4 if ii < self.num_flow - 1 else 2
labels_lr = [
data[labelind], data[labelind + 2], data[labelind + lastind],
]
weights_lr = [
data[weightind], data[weightind + 2], data[weightind + lastind],
]

nodes_lr = [
sort_nodes(labels_lr[0], self.node_sizes[ii]),
Expand All @@ -389,13 +359,6 @@ def subplot(self, ii, data):
# Determine vertical positions of nodes
y_node_gap = self.node_gap * self.plot_height

if self.valign == "top":
vscale = 1
elif self.valign == "center":
vscale = 0.5
else: # bottom, or undefined
vscale = 0

node_voffset = [{}, {}]
node_pos_bot = [{}, {}]
node_pos_top = [{}, {}]
Expand All @@ -404,7 +367,7 @@ def subplot(self, ii, data):
for i, label in enumerate(nodes_lr[lr]):
node_height = self.node_sizes[ii + lr][label]
this_side_height = weights_lr[lr][labels_lr[lr] == label].sum()
node_voffset[lr][label] = vscale * (node_height - this_side_height)
node_voffset[lr][label] = self.vscale * (node_height - this_side_height)
next_bot = node_pos_top[lr][nodes_lr[lr][i - 1]] + y_node_gap if i > 0 else 0
node_pos_bot[lr][label] = self.voffset[ii + lr] if i == 0 else next_bot
node_pos_top[lr][label] = node_pos_bot[lr][label] + node_height
Expand Down

0 comments on commit 565a3b6

Please sign in to comment.