Skip to content

Commit

Permalink
feat: create seaborn wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
tomjholland committed Jan 2, 2025
1 parent d3d6b9a commit 5b9eb18
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
55 changes: 55 additions & 0 deletions pyprobe/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,61 @@
if TYPE_CHECKING:
from pyprobe.result import Result

from functools import wraps
from typing import Any, Callable

import seaborn as _sns


def _convert_data(result_obj: "Result") -> Any:
return result_obj.data.to_pandas()


def _create_seaborn_wrapper() -> Any:
"""Create wrapped version of seaborn module."""
wrapped_sns = type("SeabornWrapper", (), {})()

def wrap_function(func: Callable[..., Any]) -> Callable[..., Any]:
"""Wrap a seaborn function.
Args:
func (Callable): The function to wrap.
"""

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
"""The wrapper function.
Modifies the 'data' argument to seaborn functions to be compatible with
PyProBE Result objects.
Args:
*args: The positional arguments.
**kwargs: The keyword arguments.
Returns:
The result of the wrapped function.
"""
if "data" in kwargs:
kwargs["data"] = _convert_data(kwargs["data"])
return func(*args, **kwargs)

return wrapper

# Copy all seaborn attributes
for attr_name in dir(_sns):
if not attr_name.startswith("_"):
attr = getattr(_sns, attr_name)
if callable(attr):
setattr(wrapped_sns, attr_name, wrap_function(attr))
else:
setattr(wrapped_sns, attr_name, attr)

return wrapped_sns


seaborn = _create_seaborn_wrapper()


class Plot:
"""A class for plotting result objects with plotly.
Expand Down
56 changes: 56 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,69 @@
import plotly.graph_objects as go
import polars as pl
import pytest
import seaborn as _sns
from plotly.express.colors import sample_colorscale
from sklearn.preprocessing import minmax_scale

from pyprobe import plot
from pyprobe.plot import Plot
from pyprobe.result import Result


def test_seaborn_wrapper_creation():
"""Test basic seaborn wrapper creation."""
wrapper = plot._create_seaborn_wrapper()
assert wrapper is not None
assert isinstance(wrapper, object)


def test_seaborn_wrapper_data_conversion(mocker):
"""Test that wrapped functions convert data correctly."""
result = Result(
base_dataframe=pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}),
info={},
column_definitions={"x": "int", "y": "int"},
)
data = result.data.to_pandas()
pyprobe_seaborn_plot = plot.seaborn.lineplot(data=result, x="x", y="y")
seaborn_lineplot = _sns.lineplot(data=data, x="x", y="y")
assert pyprobe_seaborn_plot == seaborn_lineplot


def test_seaborn_wrapper_function_call():
"""Test that wrapped functions produce same output."""
wrapper = plot._create_seaborn_wrapper()

assert wrapper.set_theme() == _sns.set_theme()

colors1 = wrapper.color_palette()
colors2 = _sns.color_palette()
assert colors1 == colors2

# Test with specific parameters
palette1 = wrapper.color_palette("husl", 8)
palette2 = _sns.color_palette("husl", 8)
assert palette1 == palette2


def test_seaborn_wrapper_function_properties():
"""Test that wrapped functions maintain original properties."""
wrapper = plot._create_seaborn_wrapper()
original_func = _sns.lineplot
wrapped_func = wrapper.lineplot

assert wrapped_func.__name__ == original_func.__name__
assert wrapped_func.__doc__ == original_func.__doc__


def test_seaborn_wrapper_complete_coverage():
"""Test that all public seaborn attributes are wrapped."""
wrapper = plot._create_seaborn_wrapper()
sns_attrs = {attr for attr in dir(_sns) if not attr.startswith("_")}
wrapper_attrs = {attr for attr in dir(wrapper) if not attr.startswith("_")}
assert sns_attrs == wrapper_attrs


@pytest.fixture
def Plot_fixture():
"""Return a Plot instance."""
Expand Down

0 comments on commit 5b9eb18

Please sign in to comment.