Skip to content

Commit

Permalink
Move some type coercions to the ridgeplot function and drop support…
Browse files Browse the repository at this point in the history
… for some deprecations
  • Loading branch information
tpvasconcelos committed Oct 14, 2024
1 parent 0587745 commit 0b7ebc0
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 71 deletions.
32 changes: 9 additions & 23 deletions src/ridgeplot/_figure_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from plotly import graph_objects as go

from ridgeplot._colors import apply_alpha, get_color, get_colorscale, validate_colorscale
from ridgeplot._colors import apply_alpha, get_color
from ridgeplot._types import CollectionL1, CollectionL2
from ridgeplot._utils import normalise_min_max

Expand Down Expand Up @@ -151,10 +151,10 @@ class RidgePlotFigureFactory:
def __init__(
self,
densities: Densities,
colorscale: str | ColorScale,
colorscale: ColorScale,
coloralpha: float | None,
colormode: Colormode,
labels: LabelsArray | None,
labels: LabelsArray,
linewidth: float,
spacing: float,
show_yticklabels: bool,
Expand All @@ -163,31 +163,17 @@ def __init__(
# ==============================================================
# --- Get clean and validated input arguments
# ==============================================================
n_rows = len(densities)
n_traces = sum(len(row) for row in densities)

if isinstance(colorscale, str):
colorscale = get_colorscale(name=colorscale)
validate_colorscale(colorscale)

if colormode not in self.colormode_maps:
raise ValueError(
f"The colormode argument should be one of "
f"{tuple(self.colormode_maps.keys())}, got {colormode} instead."
)

if coloralpha is not None:
coloralpha = float(coloralpha)

if labels is None:
ids = iter(range(1, n_traces + 1))
labels = [[f"Trace {next(ids)}" for _ in row] for row in densities]

self.densities: Densities = densities
self.colorscale: ColorScale = colorscale
self.coloralpha: float | None = coloralpha
self.densities = densities
self.colorscale = colorscale
self.coloralpha = float(coloralpha) if coloralpha is not None else None
self.colormode = colormode
self.labels: LabelsArray = labels
self.labels = labels
self.linewidth: float = float(linewidth)
self.spacing: float = float(spacing)
self.show_yticklabels: bool = bool(show_yticklabels)
Expand All @@ -196,8 +182,8 @@ def __init__(
# ==============================================================
# --- Other instance variables
# ==============================================================
self.n_rows: int = n_rows
self.n_traces: int = n_traces
self.n_rows: int = len(densities)
self.n_traces: int = sum(len(row) for row in densities)
self.x_min, self.x_max, _, self.y_max = get_xy_extrema(densities=self.densities)
self.fig: go.Figure = go.Figure()
self.colors: ColorsArray = self.pre_compute_colors()
Expand Down
101 changes: 59 additions & 42 deletions src/ridgeplot/_ridgeplot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, cast

from ridgeplot._colors import get_colorscale, validate_colorscale
from ridgeplot._figure_factory import LabelsArray, RidgePlotFigureFactory, ShallowLabelsArray
from ridgeplot._kde import estimate_densities
from ridgeplot._missing import MISSING, MissingType
Expand All @@ -26,6 +26,39 @@
from ridgeplot._kde import KDEBandwidth, KDEPoints


def _normalise_densities(
samples: Samples | ShallowSamples | None,
densities: Densities | ShallowDensities | None,
kernel: str,
bandwidth: KDEBandwidth,
kde_points: KDEPoints,
) -> Densities:
has_samples = samples is not None
has_densities = densities is not None
if has_samples and has_densities:
raise ValueError("You may not specify both `samples` and `densities` arguments!")
if not has_samples and not has_densities:
raise ValueError("You must specify either `samples` or `densities`")
if has_densities:
if is_shallow_densities(densities):
densities = cast(ShallowDensities, densities)
densities = nest_shallow_collection(densities)
densities = cast(Densities, densities)
else:
if is_shallow_samples(samples):
samples = cast(ShallowSamples, samples)
samples = nest_shallow_collection(samples)
samples = cast(Samples, samples)
# Convert samples to densities
densities = estimate_densities(
samples=samples,
points=kde_points,
kernel=kernel,
bandwidth=bandwidth,
)
return densities


def ridgeplot(
samples: Samples | ShallowSamples | None = None,
densities: Densities | ShallowDensities | None = None,
Expand Down Expand Up @@ -204,56 +237,40 @@ def ridgeplot(
if neither of them is specified. i.e. you may only specify one of them.
"""
has_samples = samples is not None
has_densities = densities is not None
if has_samples and has_densities:
raise ValueError("You may not specify both `samples` and `densities` arguments!")
if not has_samples and not has_densities:
raise ValueError("You must specify either `samples` or `densities`")
densities = _normalise_densities(
samples=samples,
densities=densities,
kernel=kernel,
bandwidth=bandwidth,
kde_points=kde_points,
)
del samples, kernel, bandwidth, kde_points

if has_densities:
if is_shallow_densities(densities):
densities = cast(ShallowDensities, densities)
densities = nest_shallow_collection(densities)
densities = cast(Densities, densities)
else:
if is_shallow_samples(samples):
samples = cast(ShallowSamples, samples)
samples = nest_shallow_collection(samples)
samples = cast(Samples, samples)
# Convert samples to densities
densities = estimate_densities(
samples=samples,
points=kde_points,
kernel=kernel,
bandwidth=bandwidth,
# n_rows = len(densities)
n_traces = sum(len(row) for row in densities)

if isinstance(colorscale, str):
colorscale = get_colorscale(name=colorscale)
validate_colorscale(colorscale)

if colormode == "index": # type: ignore[comparison-overlap]
raise ValueError(
"HINT: The colormode='index' value has been deprecated "
"in the past in favor of colormode='row-index'."
)

if is_flat_str_collection(labels):
labels = cast(ShallowLabelsArray, labels)
labels = cast(LabelsArray, nest_shallow_collection(labels))

if colormode == "index": # type: ignore[comparison-overlap]
warnings.warn( # type: ignore[unreachable]
"The colormode='index' value has been deprecated in favor of "
"colormode='row-index', which provides the same functionality but "
"is more explicit and allows to distinguish between the "
"'row-index' and 'trace-index' modes. Support for the "
"deprecated value will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
colormode = "row-index"
if labels is None:
ids = iter(range(1, n_traces + 1))
labels = [[f"Trace {next(ids)}" for _ in row] for row in densities]

if show_annotations is not MISSING:
warnings.warn(
"The show_annotations argument has been deprecated in favor of "
"show_yticklabels. Support for the deprecated argument will be "
"removed in a future version.",
DeprecationWarning,
stacklevel=2,
raise TypeError(
"HINT: The show_annotations argument has been deprecated "
"in the past in favor of show_yticklabels.",
)
show_yticklabels = show_annotations

ridgeplot_figure_factory = RidgePlotFigureFactory(
densities=densities,
Expand Down
18 changes: 12 additions & 6 deletions tests/unit/test_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ def test_shallow_labels() -> None:


def test_deprecated_colormode_index() -> None:
with pytest.warns(
DeprecationWarning,
match="The colormode='index' value has been deprecated in favor of colormode='row-index'",
with pytest.raises(
ValueError,
match=(
"The colormode='index' value has been deprecated "
"in the past in favor of colormode='row-index'"
),
):
ridgeplot(
samples=[[1, 2, 3], [1, 2, 3]],
Expand All @@ -59,8 +62,11 @@ def test_deprecated_colormode_index() -> None:


def test_deprecated_show_annotations_is_not_missing() -> None:
with pytest.warns(
DeprecationWarning,
match="The show_annotations argument has been deprecated in favor of show_yticklabels",
with pytest.raises(
TypeError,
match=(
"The show_annotations argument has been deprecated "
"in the past in favor of show_yticklabels"
),
):
ridgeplot(samples=[[1, 2, 3], [1, 2, 3]], show_annotations=True)

0 comments on commit 0b7ebc0

Please sign in to comment.