Skip to content

Commit

Permalink
some style changes and prep for refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Mar 17, 2024
1 parent 5e6e187 commit 08f93d3
Showing 1 changed file with 65 additions and 78 deletions.
143 changes: 65 additions & 78 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ def _sankey(
Some special-casing is used for plotting/labelling differently
for the first and last cases.
"""
labelind = 2 * ii
weightind = 2 * ii + 1
labelind = 2*ii
weightind = 2*ii + 1

left = pd.Series(data[labelind])
right = pd.Series(data[labelind + 2])
Expand Down Expand Up @@ -298,47 +298,46 @@ def _sankey(

# check colours
all_labels = pd.Series([*left, *right]).unique()

missing = [label for label in all_labels if label not in color_dict]
if missing:
msg = "The color_dict parameter is missing " "values for the following labels: "
msg += "{}".format(", ".join(missing))
raise ValueError(msg)

# Determine sizes of individual strips
barsize_left = {}
barsize_right = {}
barsize = [{}, {}]
for left_label in left_labels:
barsize_left[left_label] = {}
barsize_right[left_label] = {}
barsize[0][left_label] = {}
barsize[1][left_label] = {}
for right_label in right_labels:
ind = (left == left_label) & (right == right_label)
barsize_left[left_label][right_label] = left_weight[ind].sum()
barsize_right[left_label][right_label] = right_weight[ind].sum()
barsize[0][left_label][right_label] = left_weight[ind].sum()
barsize[1][left_label][right_label] = right_weight[ind].sum()

# Determine positions of left label patches and total widths
left_widths = {}
for i, left_label in enumerate(left_labels):
tmp_dict = {}
tmp_dict["left"] = left_weight[left == left_label].sum()
barpos = [{}, {}]
for i, label in enumerate(left_labels):
barpos[0][label] = {}
barpos[0][label]["total"] = left_weight[left == label].sum()
if i == 0:
tmp_dict["bottom"] = voffset[ii]
bot = voffset[ii]
else:
tmp_dict["bottom"] = left_widths[left_labels[i - 1]]["top"] + bar_gap * plot_height
tmp_dict["top"] = tmp_dict["bottom"] + tmp_dict["left"]
left_widths[left_label] = tmp_dict
bot = barpos[0][left_labels[i-1]]["top"] + bar_gap*plot_height

Check failure on line 325 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (SIM108)

ausankey/ausankey.py:322:9: SIM108 Use ternary operator `bot = voffset[ii] if i == 0 else barpos[0][left_labels[i - 1]]["top"] + bar_gap * plot_height` instead of `if`-`else`-block

barpos[0][label]["bottom"] = bot
barpos[0][label]["top"] = barpos[0][label]["bottom"] + barpos[0][label]["total"]

# Determine positions of right label patches and total widths
right_widths = {}
for i, right_label in enumerate(right_labels):
tmp_dict = {}
tmp_dict["right"] = right_weight[right == right_label].sum()
for i, label in enumerate(right_labels):
barpos[1][label] = {}
barpos[1][label]["total"] = right_weight[right == label].sum()
if i == 0:
tmp_dict["bottom"] = voffset[ii + 1]
bot = voffset[ii+1]
else:
tmp_dict["bottom"] = right_widths[right_labels[i - 1]]["top"] + bar_gap * plot_height
tmp_dict["top"] = tmp_dict["bottom"] + tmp_dict["right"]
right_widths[right_label] = tmp_dict
bot = barpos[1][right_labels[i-1]]["top"] + bar_gap * plot_height

Check failure on line 337 in ausankey/ausankey.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (SIM108)

ausankey/ausankey.py:334:9: SIM108 Use ternary operator `bot = voffset[ii + 1] if i == 0 else barpos[1][right_labels[i - 1]]["top"] + bar_gap * plot_height` instead of `if`-`else`-block

barpos[1][label]["bottom"] = bot
barpos[1][label]["top"] = barpos[1][label]["bottom"] + barpos[1][label]["total"]

# horizontal extents of flows in each subdiagram
x_bar_width = bar_width * sub_width
Expand All @@ -350,31 +349,29 @@ def _sankey(
# Draw bars and their labels
if ii == 0: # first time
for left_label in left_labels:
lbot = left_widths[left_label]["bottom"]
lll = left_widths[left_label]["left"]
lbot = barpos[0][left_label]["bottom"]
lll = barpos[0][left_label]["total"]
ax.fill_between(
[x_left - x_bar_width, x_left],
2 * [lbot],
2 * [lbot + lll],
lbot, lbot + lll,
color=color_dict[left_label],
alpha=1,
lw=0,
snap=True,
)
ax.text(
x_left - x_label_gap - x_bar_width,
lbot + 0.5 * lll,
lbot + lll/2,
label_dict.get(left_label, left_label),
{"ha": "right", "va": "center"},
fontsize=fontsize,
)
for right_label in right_labels:
rbot = right_widths[right_label]["bottom"]
rrr = right_widths[right_label]["right"]
rbot = barpos[1][right_label]["bottom"]
rrr = barpos[1][right_label]["total"]
ax.fill_between(
[x_right, x_right + x_bar_width],
2 * [rbot],
[rbot + rrr],
rbot, rbot + rrr,
color=color_dict[right_label],
alpha=1,
lw=0,
Expand All @@ -383,68 +380,60 @@ def _sankey(
if ii < num_flow - 1: # inside labels
ax.text(
x_right + x_label_gap + x_bar_width,
rbot + 0.5 * rrr,
rbot + rrr/2,
label_dict.get(right_label, right_label),
{"ha": "left", "va": "center"},
fontsize=fontsize,
)
if ii == num_flow - 1: # last time
ax.text(
x_right + x_label_gap + x_bar_width,
rbot + 0.5 * rrr,
rbot + rrr/2,
label_dict.get(right_label, right_label),
{"ha": "left", "va": "center"},
fontsize=fontsize,
)

# "titles"
if titles is not None:

y_title_gap = title_gap*plot_height

# leftmost title
if ii == 0:
xt = x_left - x_bar_width / 2
xt = x_left - x_bar_width/2
if title_side in ("top", "both"):
yt = title_gap * plot_height + left_widths[left_label]["top"]
yt = y_title_gap + barpos[0][left_label]["top"]
va = "bottom"
ax.text(
xt,
yt,
titles[ii],
xt, yt, titles[ii],
{"ha": "center", "va": va},
fontsize=fontsize,
)

if title_side in ("bottom", "both"):
yt = voffset[ii] - title_gap * plot_height
yt = voffset[ii] - y_title_gap
va = "top"

ax.text(
xt,
yt,
titles[ii],
xt, yt, titles[ii],
{"ha": "center", "va": va},
fontsize=fontsize,
)

# all other titles
xt = x_right + x_bar_width / 2
xt = x_right + x_bar_width/2
if title_side in ("top", "both"):
yt = title_gap * plot_height + right_widths[right_label]["top"]

yt = y_title_gap + barpos[1][right_label]["top"]
ax.text(
xt,
yt,
titles[ii + 1],
xt, yt, titles[ii+1],
{"ha": "center", "va": "bottom"},
fontsize=fontsize,
)

if title_side in ("bottom", "both"):
yt = voffset[ii + 1] - title_gap * plot_height

yt = voffset[ii+1] - y_title_gap
ax.text(
xt,
yt,
titles[ii + 1],
xt, yt, titles[ii+1],
{"ha": "center", "va": "top"},
fontsize=fontsize,
)
Expand All @@ -458,18 +447,18 @@ def _sankey(
if not any(lind & rind):
continue

lbot = left_widths[left_label]["bottom"]
rbot = right_widths[right_label]["bottom"]
lbar = barsize_left[left_label][right_label]
rbar = barsize_right[left_label][right_label]
lbot = barpos[0][left_label]["bottom"]
rbot = barpos[1][right_label]["bottom"]
lbar = barsize[0][left_label][right_label]
rbar = barsize[1][left_label][right_label]

ys_d = create_curve(lbot, rbot)
ys_u = create_curve(lbot + lbar, rbot + rbar)

# Update bottom edges at each label
# so next strip starts at the right place
left_widths[left_label]["bottom"] += lbar
right_widths[right_label]["bottom"] += rbar
barpos[0][left_label]["bottom"] += lbar
barpos[1][right_label]["bottom"] += rbar

xx = np.linspace(x_left, x_right, len(ys_d))
cc = combine_colours(color_dict[left_label], color_dict[right_label], len(ys_d))
Expand Down Expand Up @@ -517,21 +506,19 @@ def check_data_matches_labels(labels, data, side):
that the order of labels is still consistent
with the labels in the data.
"""
if len(labels) > 0:
if isinstance(data, list):
data = set(data)
if isinstance(data, pd.Series):
data = set(data.unique().tolist())
if isinstance(labels, list):
labels = set(labels)
if labels != data:
msg = "\n"
maxlen = 20
if len(labels) <= maxlen:
msg = "Labels: " + ",".join(labels) + "\n"
if len(data) < maxlen:
msg += "Data: " + ",".join(data)
raise LabelMismatchError(side, msg)

if len(labels) == 0:
msg = "Length of labels equals zero?"
raise LabelMismatchError(side, msg)

if set(labels) != set(data):
msg = "\n"
maxlen = 20
if len(labels) <= maxlen:
msg += "Labels: " + ",".join(labels) + "\n"
if len(data) < maxlen:
msg += "Data: " + ",".join(data)
raise LabelMismatchError(side, msg)


###########################################
Expand Down

0 comments on commit 08f93d3

Please sign in to comment.