Skip to content

Commit

Permalink
Merge pull request #223 from tpvasconcelos/dedup-y-labels
Browse files Browse the repository at this point in the history
Deduplicate y-axis labels
  • Loading branch information
tpvasconcelos authored Oct 15, 2024
2 parents 0587745 + a9e9252 commit 26c8e5e
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 25 deletions.
4 changes: 4 additions & 0 deletions docs/reference/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ This document outlines the list of changes to ridgeplot between each release. Fo
Unreleased changes
------------------

### Improvements

- Remove duplicated labels when plotting multiple traces on the same y-axis row ({gh-pr}`223`)

### Documentation

- Update and improve the "Contributing" guide ({gh-pr}`218` and {gh-pr}`221`)
Expand Down
33 changes: 14 additions & 19 deletions src/ridgeplot/_figure_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ridgeplot._colors import apply_alpha, get_color, get_colorscale, validate_colorscale
from ridgeplot._types import CollectionL1, CollectionL2
from ridgeplot._utils import normalise_min_max
from ridgeplot._utils import normalise_min_max, normalise_row_attrs, ordered_dedup

if TYPE_CHECKING:
from collections.abc import Collection
Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(
colorscale: str | ColorScale,
coloralpha: float | None,
colormode: Colormode,
labels: LabelsArray | None,
trace_labels: LabelsArray | None,
linewidth: float,
spacing: float,
show_yticklabels: bool,
Expand All @@ -179,15 +179,18 @@ def __init__(
if coloralpha is not None:
coloralpha = float(coloralpha)

if labels is None:
if trace_labels is None:
ids = iter(range(1, n_traces + 1))
labels = [[f"Trace {next(ids)}" for _ in row] for row in densities]
trace_labels = [[f"Trace {next(ids)}" for _ in row] for row in densities]
else:
trace_labels = normalise_row_attrs(trace_labels, densities=densities)

self.densities: Densities = densities
self.colorscale: ColorScale = colorscale
self.coloralpha: float | None = coloralpha
self.colormode = colormode
self.labels: LabelsArray = labels
self.trace_labels: LabelsArray = trace_labels
self.y_labels: LabelsArray = [ordered_dedup(row) for row in trace_labels]
self.linewidth: float = float(linewidth)
self.spacing: float = float(spacing)
self.show_yticklabels: bool = bool(show_yticklabels)
Expand All @@ -203,7 +206,7 @@ def __init__(
self.colors: ColorsArray = self.pre_compute_colors()

@property
def colormode_maps(self) -> dict[str, Callable[[], MidpointsArray]]:
def colormode_maps(self) -> dict[Colormode, Callable[[], MidpointsArray]]:
return {
"row-index": self._compute_midpoints_row_index,
"trace-index": self._compute_midpoints_trace_index,
Expand Down Expand Up @@ -274,7 +277,7 @@ def update_layout(self, y_ticks: list[float]) -> None:
self.fig.update_yaxes(
showticklabels=self.show_yticklabels,
tickvals=y_ticks,
ticktext=self.labels,
ticktext=self.y_labels,
**axes_common,
)
x_padding = self.xpad * (self.x_max - self.x_min)
Expand Down Expand Up @@ -362,18 +365,10 @@ def _get_color(mp: float) -> str:

def make_figure(self) -> go.Figure:
y_ticks = []
for i, (row, labels, colors) in enumerate(zip(self.densities, self.labels, self.colors)):
n_traces = len(row)
n_labels = len(labels)
if n_traces != n_labels:
# TODO: This should be handled upstream
if n_labels == 1:
labels = list(labels) * n_traces # noqa: PLW2901
else:
raise ValueError(
f"Mismatch between number of traces ({n_traces}) and "
f"number of labels ({n_labels}) for row {i}."
)
for i, (row, labels, colors) in enumerate(
# TODO: Use strict=True in Python>=3.10
zip(self.densities, self.trace_labels, self.colors)
):
# y_shifted is the y-origin for the new trace
y_shifted = -i * float(self.y_max * self.spacing)
y_ticks.append(y_shifted)
Expand Down
2 changes: 1 addition & 1 deletion src/ridgeplot/_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def ridgeplot(

ridgeplot_figure_factory = RidgePlotFigureFactory(
densities=densities,
labels=labels,
trace_labels=labels,
colorscale=colorscale,
coloralpha=coloralpha,
colormode=colormode,
Expand Down
4 changes: 2 additions & 2 deletions src/ridgeplot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def _is_numeric(obj: Any) -> bool:
... [(3, 0), (4, 1), (5, 2), (6, 1), (7, 0)], # Trace 3
... ],
... [ # Row 2
... [(-2, 0), (-1, 1), (0, 0)], # Trace 5
... [(0, 0), (1, 1), (2, 1), (3, 0)], # Trace 6
... [(-2, 0), (-1, 1), (0, 0)], # Trace 4
... [(0, 0), (1, 1), (2, 1), (3, 0)], # Trace 5
... ],
... ]
Expand Down
100 changes: 99 additions & 1 deletion src/ridgeplot/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Iterator
from typing import Any, Callable

from ridgeplot._types import Numeric
from ridgeplot._types import CollectionL2, Densities, Numeric


def normalise_min_max(val: Numeric, min_: Numeric, max_: Numeric) -> float:
Expand Down Expand Up @@ -117,6 +117,104 @@ def _get_dim_length(obj: Any) -> int:
return tuple(shape)


_V = TypeVar("_V")


def ordered_dedup(seq: Collection[_V]) -> list[_V]:
    """Drop repeated elements from ``seq``, keeping each first occurrence.

    The relative order in which elements first appear in ``seq`` is
    preserved. Elements must be hashable.

    Parameters
    ----------
    seq
        The collection to de-duplicate.

    Returns
    -------
    list
        A new list holding every distinct element of ``seq`` exactly once,
        ordered by first appearance.

    Examples
    --------
    >>> ordered_dedup([1, 2, 3, 1, 2, 3, 1, 2, 3])
    [1, 2, 3]
    >>> ordered_dedup([1, 2, 3, 4, 5, 6])
    [1, 2, 3, 4, 5, 6]
    >>> ordered_dedup([1, 1, 1, 1, 1, 1, 1, 1, 1])
    [1]
    >>> ordered_dedup([1, 2, 3, 3, 2, 1])
    [1, 2, 3]
    >>> ordered_dedup([3, 1, 2, 4, 2, 4, 5])
    [3, 1, 2, 4, 5]
    """
    already_seen = set()
    unique_items = []
    for item in seq:
        # Only the first occurrence of each (hash-equal) element is kept.
        if item not in already_seen:
            already_seen.add(item)
            unique_items.append(item)
    return unique_items


def normalise_row_attrs(attrs: CollectionL2[_V], densities: Densities) -> CollectionL2[_V]:
    """Normalise the attributes over a Density array such that the number of
    attributes matches the number of traces in each row.

    A row that specifies exactly one attribute has that attribute repeated
    for every trace in the corresponding densities row. A row that already
    specifies one attribute per trace is passed through unchanged.

    Parameters
    ----------
    attrs
        The attributes collection to normalise.
    densities
        The densities array to normalise the attributes over.

    Returns
    -------
    CollectionL2
        The normalised attributes collection.

    Raises
    ------
    ValueError
        If the number of rows in ``attrs`` does not match the number of rows
        in ``densities``, or if the number of traces does not match the
        number of attributes for a row.

    Examples
    --------
    >>> densities = [
    ...     [                                             # Row 1
    ...         [(0, 0), (1, 1), (2, 0)],                 # Trace 1
    ...         [(1, 0), (2, 1), (3, 2), (4, 1)],         # Trace 2
    ...         [(3, 0), (4, 1), (5, 2), (6, 1), (7, 0)], # Trace 3
    ...     ],
    ...     [                                             # Row 2
    ...         [(-2, 0), (-1, 1), (0, 0)],               # Trace 4
    ...         [(0, 0), (1, 1), (2, 1), (3, 0)],         # Trace 5
    ...     ],
    ... ]
    >>> normalise_row_attrs([["A"], ["B"]], densities)
    [['A', 'A', 'A'], ['B', 'B']]
    >>> normalise_row_attrs([["A"], ["B", "C"]], densities)
    [['A', 'A', 'A'], ['B', 'C']]
    >>> normalise_row_attrs([["A", "A", "A"], ["B", "B"]], densities)
    [['A', 'A', 'A'], ['B', 'B']]
    >>> normalise_row_attrs([["A", "B"], ["C"]], densities)
    Traceback (most recent call last):
        ...
    ValueError: Mismatch between number of traces (3) and number of attrs (2) for row 0.
    >>> normalise_row_attrs([["A"]], densities)
    Traceback (most recent call last):
        ...
    ValueError: Mismatch between number of rows in densities (2) and number of rows in attrs (1).
    """
    n_rows = len(densities)
    n_attr_rows = len(attrs)
    if n_rows != n_attr_rows:
        # ``zip`` stops at the shortest input, so without this check any
        # surplus rows (on either side) would be dropped silently.
        raise ValueError(
            f"Mismatch between number of rows in densities ({n_rows}) "
            f"and number of rows in attrs ({n_attr_rows})."
        )

    norm_attrs = []
    for i, (row, row_attr) in enumerate(zip(densities, attrs)):
        n_traces = len(row)
        n_attrs = len(row_attr)
        if n_attrs == n_traces:
            # Already one attribute per trace: keep the row as given.
            norm_attrs.append(row_attr)
        elif n_attrs == 1:
            # A single attribute is broadcast to every trace in the row.
            norm_attrs.append(list(row_attr) * n_traces)
        else:
            raise ValueError(
                f"Mismatch between number of traces ({n_traces}) "
                f"and number of attrs ({n_attrs}) for row {i}."
            )
    return norm_attrs


_KT = TypeVar("_KT") # Mapping key type
_VT = TypeVar("_VT") # Mapping value type

Expand Down
9 changes: 8 additions & 1 deletion tests/unit/test_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,20 @@ def test_shallow_samples() -> None:


def test_shallow_labels() -> None:
shallow_labels = ["trace 1", "trace 2", "trace 3"]
shallow_labels = ["trace 1", "trace 2"]
assert (
ridgeplot(samples=[[1, 2, 3], [1, 2, 3]], labels=shallow_labels) ==
ridgeplot(samples=[[1, 2, 3], [1, 2, 3]], labels=nest_shallow_collection(shallow_labels))
) # fmt: skip


def test_y_labels_dedup() -> None:
    """Check that one label for a two-trace row gives the same figure as
    repeating that label once per trace."""
    fig_single_label = ridgeplot(samples=[[[1, 2, 3], [4, 5, 6]]], labels=["a"])
    fig_repeated_label = ridgeplot(samples=[[[1, 2, 3], [4, 5, 6]]], labels=[["a", "a"]])
    assert fig_single_label == fig_repeated_label


def test_deprecated_colormode_index() -> None:
with pytest.warns(
DeprecationWarning,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ commands =
tests-unit: diff-cover {env:_DIFFCOVER_DFLT_ARGS}

tests-e2e: pytest tests/e2e --cov=cicd_utils/cicd/ridgeplot_examples {env:_COV_REPORT_XML} --cov-fail-under=75 {posargs:}
tests-e2e: diff-cover {env:_DIFFCOVER_DFLT_ARGS}
tests-e2e: diff-cover {env:_DIFFCOVER_DFLT_ARGS} --fail-under=95

tests-cicd_utils: rm -f {env:_COV_FILES_RM}
tests-cicd_utils: pytest tests/cicd_utils --cov=cicd_utils {env:_COV_REPORT_XML} --cov-fail-under=50 {posargs:}
Expand Down

0 comments on commit 26c8e5e

Please sign in to comment.