diff --git a/ausankey/ausankey.py b/ausankey/ausankey.py
index d7f3b6b..92940ea 100644
--- a/ausankey/ausankey.py
+++ b/ausankey/ausankey.py
@@ -266,8 +266,8 @@ def _sankey(
     Some special-casing is used for plotting/labelling differently
     for the first and last cases.
     """
-    labelind = 2 * ii
-    weightind = 2 * ii + 1
+    labelind = 2*ii
+    weightind = 2*ii + 1
 
     left = pd.Series(data[labelind])
     right = pd.Series(data[labelind + 2])
@@ -298,7 +298,6 @@ def _sankey(
 
     # check colours
     all_labels = pd.Series([*left, *right]).unique()
-
     missing = [label for label in all_labels if label not in color_dict]
     if missing:
         msg = "The color_dict parameter is missing " "values for the following labels: "
@@ -306,39 +305,39 @@ def _sankey(
         raise ValueError(msg)
 
     # Determine sizes of individual strips
-    barsize_left = {}
-    barsize_right = {}
+    barsize = [{}, {}]
     for left_label in left_labels:
-        barsize_left[left_label] = {}
-        barsize_right[left_label] = {}
+        barsize[0][left_label] = {}
+        barsize[1][left_label] = {}
         for right_label in right_labels:
             ind = (left == left_label) & (right == right_label)
-            barsize_left[left_label][right_label] = left_weight[ind].sum()
-            barsize_right[left_label][right_label] = right_weight[ind].sum()
+            barsize[0][left_label][right_label] = left_weight[ind].sum()
+            barsize[1][left_label][right_label] = right_weight[ind].sum()
 
     # Determine positions of left label patches and total widths
-    left_widths = {}
-    for i, left_label in enumerate(left_labels):
-        tmp_dict = {}
-        tmp_dict["left"] = left_weight[left == left_label].sum()
+    barpos = [{}, {}]
+    for i, label in enumerate(left_labels):
+        barpos[0][label] = {}
+        barpos[0][label]["total"] = left_weight[left == label].sum()
         if i == 0:
-            tmp_dict["bottom"] = voffset[ii]
+            bot = voffset[ii]
         else:
-            tmp_dict["bottom"] = left_widths[left_labels[i - 1]]["top"] + bar_gap * plot_height
-        tmp_dict["top"] = tmp_dict["bottom"] + tmp_dict["left"]
-        left_widths[left_label] = tmp_dict
+            bot = barpos[0][left_labels[i-1]]["top"] + bar_gap*plot_height
+        
+        barpos[0][label]["bottom"] = bot
+        barpos[0][label]["top"] = barpos[0][label]["bottom"] + barpos[0][label]["total"]
 
     # Determine positions of right label patches and total widths
-    right_widths = {}
-    for i, right_label in enumerate(right_labels):
-        tmp_dict = {}
-        tmp_dict["right"] = right_weight[right == right_label].sum()
+    for i, label in enumerate(right_labels):
+        barpos[1][label] = {}
+        barpos[1][label]["total"] = right_weight[right == label].sum()
         if i == 0:
-            tmp_dict["bottom"] = voffset[ii + 1]
+            bot = voffset[ii+1]
         else:
-            tmp_dict["bottom"] = right_widths[right_labels[i - 1]]["top"] + bar_gap * plot_height
-        tmp_dict["top"] = tmp_dict["bottom"] + tmp_dict["right"]
-        right_widths[right_label] = tmp_dict
+            bot = barpos[1][right_labels[i-1]]["top"] + bar_gap * plot_height
+        
+        barpos[1][label]["bottom"] = bot
+        barpos[1][label]["top"] = barpos[1][label]["bottom"] + barpos[1][label]["total"]
 
     # horizontal extents of flows in each subdiagram
     x_bar_width = bar_width * sub_width
@@ -350,12 +349,11 @@ def _sankey(
     # Draw bars and their labels
     if ii == 0:  # first time
         for left_label in left_labels:
-            lbot = left_widths[left_label]["bottom"]
-            lll = left_widths[left_label]["left"]
+            lbot = barpos[0][left_label]["bottom"]
+            lll = barpos[0][left_label]["total"]
             ax.fill_between(
                 [x_left - x_bar_width, x_left],
-                2 * [lbot],
-                2 * [lbot + lll],
+                lbot, lbot + lll,
                 color=color_dict[left_label],
                 alpha=1,
                 lw=0,
@@ -363,18 +361,17 @@ def _sankey(
             )
             ax.text(
                 x_left - x_label_gap - x_bar_width,
-                lbot + 0.5 * lll,
+                lbot + lll/2,
                 label_dict.get(left_label, left_label),
                 {"ha": "right", "va": "center"},
                 fontsize=fontsize,
             )
     for right_label in right_labels:
-        rbot = right_widths[right_label]["bottom"]
-        rrr = right_widths[right_label]["right"]
+        rbot = barpos[1][right_label]["bottom"]
+        rrr = barpos[1][right_label]["total"]
         ax.fill_between(
             [x_right, x_right + x_bar_width],
-            2 * [rbot],
-            [rbot + rrr],
+            rbot, rbot + rrr,
             color=color_dict[right_label],
             alpha=1,
             lw=0,
@@ -383,7 +380,7 @@ def _sankey(
         if ii < num_flow - 1:  # inside labels
             ax.text(
                 x_right + x_label_gap + x_bar_width,
-                rbot + 0.5 * rrr,
+                rbot + rrr/2,
                 label_dict.get(right_label, right_label),
                 {"ha": "left", "va": "center"},
                 fontsize=fontsize,
@@ -391,7 +388,7 @@ def _sankey(
         if ii == num_flow - 1:  # last time
             ax.text(
                 x_right + x_label_gap + x_bar_width,
-                rbot + 0.5 * rrr,
+                rbot + rrr/2,
                 label_dict.get(right_label, right_label),
                 {"ha": "left", "va": "center"},
                 fontsize=fontsize,
@@ -399,52 +396,44 @@ def _sankey(
 
     # "titles"
     if titles is not None:
+
+        y_title_gap = title_gap*plot_height
+
         # leftmost title
         if ii == 0:
-            xt = x_left - x_bar_width / 2
+            xt = x_left - x_bar_width/2
             if title_side in ("top", "both"):
-                yt = title_gap * plot_height + left_widths[left_label]["top"]
+                yt = y_title_gap + barpos[0][left_label]["top"]
                 va = "bottom"
                 ax.text(
-                    xt,
-                    yt,
-                    titles[ii],
+                    xt, yt, titles[ii],
                     {"ha": "center", "va": va},
                     fontsize=fontsize,
                 )
 
             if title_side in ("bottom", "both"):
-                yt = voffset[ii] - title_gap * plot_height
+                yt = voffset[ii] - y_title_gap
                 va = "top"
-
                 ax.text(
-                    xt,
-                    yt,
-                    titles[ii],
+                    xt, yt, titles[ii],
                     {"ha": "center", "va": va},
                     fontsize=fontsize,
                 )
 
         # all other titles
-        xt = x_right + x_bar_width / 2
+        xt = x_right + x_bar_width/2
         if title_side in ("top", "both"):
-            yt = title_gap * plot_height + right_widths[right_label]["top"]
-
+            yt = y_title_gap + barpos[1][right_label]["top"]
             ax.text(
-                xt,
-                yt,
-                titles[ii + 1],
+                xt, yt, titles[ii+1],
                 {"ha": "center", "va": "bottom"},
                 fontsize=fontsize,
             )
 
         if title_side in ("bottom", "both"):
-            yt = voffset[ii + 1] - title_gap * plot_height
-
+            yt = voffset[ii+1] - y_title_gap
             ax.text(
-                xt,
-                yt,
-                titles[ii + 1],
+                xt, yt, titles[ii+1],
                 {"ha": "center", "va": "top"},
                 fontsize=fontsize,
             )
@@ -458,18 +447,18 @@ def _sankey(
             if not any(lind & rind):
                 continue
 
-            lbot = left_widths[left_label]["bottom"]
-            rbot = right_widths[right_label]["bottom"]
-            lbar = barsize_left[left_label][right_label]
-            rbar = barsize_right[left_label][right_label]
+            lbot = barpos[0][left_label]["bottom"]
+            rbot = barpos[1][right_label]["bottom"]
+            lbar = barsize[0][left_label][right_label]
+            rbar = barsize[1][left_label][right_label]
 
             ys_d = create_curve(lbot, rbot)
             ys_u = create_curve(lbot + lbar, rbot + rbar)
 
             # Update bottom edges at each label
             # so next strip starts at the right place
-            left_widths[left_label]["bottom"] += lbar
-            right_widths[right_label]["bottom"] += rbar
+            barpos[0][left_label]["bottom"] += lbar
+            barpos[1][right_label]["bottom"] += rbar
 
             xx = np.linspace(x_left, x_right, len(ys_d))
             cc = combine_colours(color_dict[left_label], color_dict[right_label], len(ys_d))
@@ -517,21 +506,19 @@ def check_data_matches_labels(labels, data, side):
     that the order of labels is still consistent
     with the labels in the data.
     """
-    if len(labels) > 0:
-        if isinstance(data, list):
-            data = set(data)
-        if isinstance(data, pd.Series):
-            data = set(data.unique().tolist())
-        if isinstance(labels, list):
-            labels = set(labels)
-        if labels != data:
-            msg = "\n"
-            maxlen = 20
-            if len(labels) <= maxlen:
-                msg = "Labels: " + ",".join(labels) + "\n"
-            if len(data) < maxlen:
-                msg += "Data: " + ",".join(data)
-            raise LabelMismatchError(side, msg)
+
+    if len(labels) == 0:
+        msg = "Length of labels equals zero?"
+        raise LabelMismatchError(side, msg)
+
+    if set(labels) != set(data):
+        msg = "\n"
+        maxlen = 20
+        if len(labels) <= maxlen:
+            msg += "Labels: " + ",".join(labels) + "\n"
+        if len(data) < maxlen:
+            msg += "Data: " + ",".join(data)
+        raise LabelMismatchError(side, msg)
 
 
 ###########################################