Skip to content

Commit

Permalink
Make pandas optional
Browse files Browse the repository at this point in the history
Also:
- Fix cli extra - was missing matlotlib/seaborn
- Add 'all' extra

(cherry picked from commit 3feeed6)
  • Loading branch information
rsundqvist committed May 9, 2024
1 parent e66ae1c commit e61e917
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 38 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ mtimeit = { callable = "rics.performance.cli:main", extras = ["cli"] }
[tool.poetry.dependencies]
python = ">=3.11"

pandas = ">=2.0.3"
pandas = { version = ">=2.0.3", optional = true }

# cli
click = { version = "*", optional = true }
Expand All @@ -44,7 +44,8 @@ matplotlib = { version = "*", optional = true }
seaborn = { version = "*", optional = true }

[tool.poetry.extras]
cli = ["click"]
all = ["pandas", "click", "matplotlib", "seaborn"]
cli = ["pandas", "click", "matplotlib", "seaborn"]
plotting = ["matplotlib", "seaborn"]

[tool.poetry.group.manual-extras.dependencies]
Expand Down
48 changes: 35 additions & 13 deletions src/rics/performance/_util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from collections.abc import Hashable, Iterable
from typing import Any, Literal, TypeGuard, cast, get_args

import pandas as pd
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Literal, TypeGuard, Union, cast, get_args

from .types import ResultsDict

if TYPE_CHECKING:
import pandas


def to_dataframe(run_results: ResultsDict, names: Iterable[str] = ()) -> pd.DataFrame:
def to_dataframe(run_results: ResultsDict, names: Iterable[str] = ()) -> "pandas.DataFrame":
"""Create a DataFrame from performance run output, adding derived values.
Args:
Expand All @@ -20,6 +22,9 @@ def to_dataframe(run_results: ResultsDict, names: Iterable[str] = ()) -> pd.Data
The `run_result` input wrapped in a DataFrame.
"""
with _import_context("pandas"):
from pandas import DataFrame, concat

names = tuple(names)
frames = []
for candidate_label, candidate_results in run_results.items():
Expand All @@ -39,10 +44,10 @@ def to_dataframe(run_results: ResultsDict, names: Iterable[str] = ()) -> pd.Data
raise ValueError(msg)
data[name] = label_part

frame = pd.DataFrame.from_dict(data, orient="columns")
frame = DataFrame.from_dict(data, orient="columns")
frames.append(frame)

df = pd.concat(frames, ignore_index=True)
df = concat(frames, ignore_index=True)
df["Time [ms]"] = df["Time [s]"] * 1000
df["Time [μs]"] = df["Time [ms]"] * 1000
df["Time [ns]"] = df["Time [μs]"] * 1000
Expand All @@ -69,7 +74,7 @@ def _has_names(data_label: Hashable, *, names: tuple[str, ...]) -> TypeGuard[tup
return True


def get_best(run_results: ResultsDict | pd.DataFrame, per_candidate: bool = False) -> pd.DataFrame:
def get_best(run_results: Union[ResultsDict, "pandas.DataFrame"], per_candidate: bool = False) -> "pandas.DataFrame":
"""Get a summarized view of the best run results for each candidate/data pair.
Args:
Expand All @@ -81,15 +86,18 @@ def get_best(run_results: ResultsDict | pd.DataFrame, per_candidate: bool = Fals
The best (lowest) times for each candidate/data pair.
"""
df = run_results if isinstance(run_results, pd.DataFrame) else to_dataframe(run_results)
with _import_context("seaborn"):
from pandas import DataFrame

df = run_results if isinstance(run_results, DataFrame) else to_dataframe(run_results)
return df.sort_values("Time [s]").groupby(["Candidate", "Test data"] if per_candidate else "Test data").head(1)


Unit = Literal["s", "ms", "μs", "us", "ns"]


def plot_run(
run_results: ResultsDict | pd.DataFrame,
run_results: Union[ResultsDict, "pandas.DataFrame"],
x: Literal["candidate", "data"] | None = None,
unit: Unit | None = None,
**kwargs: Any,
Expand All @@ -113,8 +121,9 @@ def plot_run(
"""
import warnings

import matplotlib.pyplot as plt
from seaborn import barplot, move_legend
with _import_context("seaborn"): # Seaborn installs matplotlib as well
import matplotlib.pyplot as plt
from seaborn import barplot, move_legend

data = to_dataframe(run_results) if isinstance(run_results, dict) else run_results.copy()
data[["Test data", "Candidate"]] = data[["Test data", "Candidate"]].astype("category")
Expand Down Expand Up @@ -154,12 +163,12 @@ def plot_run(
left.get_legend().remove()


def _smaller_as_hue(data: pd.DataFrame) -> tuple[str, str]:
def _smaller_as_hue(data: "pandas.DataFrame") -> tuple[str, str]:
unique = data.nunique()
return ("Test data", "Candidate") if unique["Test data"] < unique["Candidate"] else ("Candidate", "Test data")


def _unit_from_data(df: pd.DataFrame) -> Unit:
def _unit_from_data(df: "pandas.DataFrame") -> Unit:
"""Pick the unit with the most "human" scale; whole numbers around one hundred."""
from numpy import log10

Expand All @@ -175,3 +184,16 @@ def _unit_from_data(df: pd.DataFrame) -> Unit:

assert unit in get_args(Unit) # noqa: S101
return cast(Unit, unit)


@contextmanager
def _import_context(name: str): # type: ignore[no-untyped-def] # noqa: ANN202
try:
yield
except ModuleNotFoundError as e:
msg = (
f"Missing optional dependency '{name}'. Install this package manually, or run:\n"
" pip install rics[all]"
"\nto get all optional dependencies."
)
raise ImportError(msg) from e
20 changes: 13 additions & 7 deletions src/rics/performance/_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any

import pandas as pd
from typing import TYPE_CHECKING, Any, Union

from ._multi_case_timer import CandidateMethodArg, MultiCaseTimer, TestDataArg
from ._util import plot_run, to_dataframe
from .types import DataType
from .types import DataType, ResultsDict

if TYPE_CHECKING:
import pandas


def run_multivariate_test(
Expand All @@ -13,9 +14,14 @@ def run_multivariate_test(
time_per_candidate: float = 6.0,
plot: bool = True,
**figure_kwargs: Any,
) -> pd.DataFrame:
) -> Union["pandas.DataFrame", ResultsDict]:
"""Run performance tests for multiple candidate methods on collections of test data.
.. note::
Returns a :attr:`~rics.performance.types.ResultsDict` if :mod:`pandas` is not
installed. Plotting requires :func:`seaborn <seaborn.barplot>`.
This is a convenience method which combines :meth:`MultiCaseTimer.run() <rics.performance.MultiCaseTimer.run>`,
:meth:`~rics.performance.to_dataframe` and -- if plotting is enabled -- :meth:`~rics.performance.plot_run`. For full
functionally these methods should be use directly.
Expand All @@ -25,10 +31,10 @@ def run_multivariate_test(
test_data: A single datum, or a dict ``{label: data}`` to evaluate candidates on.
time_per_candidate: Desired runtime for each repetition per candidate label.
plot: If ``True``, plot a figure using :meth:`~rics.performance.plot_run`.
**figure_kwargs: Keyword arguments for the :seaborn.barplot`. Ignored if ``plot=False``.
**figure_kwargs: Keyword arguments for the :func:`seaborn.barplot`. Ignored if ``plot=False``.
Returns:
A long-format DataFrame of results.
A long-format :class:`pandas.DataFrame` of results.
Raises:
ModuleNotFoundError: If Seaborn isn't installed and ``plot=True``.
Expand Down
9 changes: 6 additions & 3 deletions src/rics/performance/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
import os
import sys
from pathlib import Path as _Path
from typing import TYPE_CHECKING
from typing import Any as _Any

import click
import pandas as pd

from .._just_the_way_i_like_it import configure_stuff as _configure_stuff
from ._util import get_best as _get_best
from ._wrapper import run_multivariate_test as _run

LARGE_RESULT_ROW_LIMIT = 10000

if TYPE_CHECKING:
import pandas


def _get_test_data() -> _Any:
try:
Expand Down Expand Up @@ -166,7 +169,7 @@ def main(time_per_candidate: float, name: str, create: bool, per_candidate: bool
_save_report(name_path, result)


def _print_best_result(per_candidate: bool, result: pd.DataFrame, title: str) -> None:
def _print_best_result(per_candidate: bool, result: "pandas.DataFrame", title: str) -> None:
click.secho("=" * 80, fg="green")
click.secho("| {f' Best Times ':^74} |", fg="green")
click.secho(f"| {f' {title!r} ':^74} |", fg="green")
Expand All @@ -175,7 +178,7 @@ def _print_best_result(per_candidate: bool, result: pd.DataFrame, title: str) ->
click.secho("=" * 80, fg="green")


def _save_report(name_path: _Path, result: pd.DataFrame) -> None:
def _save_report(name_path: _Path, result: "pandas.DataFrame") -> None:
performance_report_path = name_path.with_suffix(".csv")

if len(result) > LARGE_RESULT_ROW_LIMIT:
Expand Down
18 changes: 10 additions & 8 deletions src/rics/performance/plot/_params.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from collections.abc import Hashable, Iterable, Mapping
from dataclasses import dataclass, field, fields
from typing import Any, ClassVar, Literal, Self
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, Union

import numpy as np
import pandas as pd

from ...collections.dicts import compute_if_absent
from ..types import ResultsDict
from .types import Candidate, FuncOrData, Kind, TestData, Unit

if TYPE_CHECKING:
import pandas

FUNC: Candidate = "Candidate"
DATA: TestData = "Test data"


@dataclass(frozen=True, kw_only=True)
class CatplotParams:
data: pd.DataFrame = field(repr=False)
data: "pandas.DataFrame" = field(repr=False)
x: FuncOrData
y: str
hue: FuncOrData
Expand Down Expand Up @@ -43,7 +45,7 @@ def __post_init__(self) -> None:
@classmethod
def make(
cls,
run_results: ResultsDict | pd.DataFrame,
run_results: Union[ResultsDict, "pandas.DataFrame"],
*,
x: Literal["candidate", "data"] | None = None,
unit: Unit | None = None,
Expand Down Expand Up @@ -144,7 +146,7 @@ def format_label(label: tuple[Hashable]) -> str:
kwargs.update(updates)


def _make_df(run_results: ResultsDict | pd.DataFrame) -> pd.DataFrame:
def _make_df(run_results: Union[ResultsDict, "pandas.DataFrame"]) -> "pandas.DataFrame":
from rics.performance import to_dataframe

df = to_dataframe(run_results) if isinstance(run_results, dict) else run_results.copy()
Expand All @@ -153,15 +155,15 @@ def _make_df(run_results: ResultsDict | pd.DataFrame) -> pd.DataFrame:
return df


def _is_data_x(x_col: Literal["candidate", "data"] | None, *, df: pd.DataFrame) -> bool:
def _is_data_x(x_col: Literal["candidate", "data"] | None, *, df: "pandas.DataFrame") -> bool:
if x_col is None:
n_data, n_func = df[[DATA, FUNC]].nunique()
return n_data > n_func # type: ignore[no-any-return]
else:
return x_col.lower().startswith("d")


def _resolve_y(unit: Unit | None, *, df: pd.DataFrame) -> str:
def _resolve_y(unit: Unit | None, *, df: "pandas.DataFrame") -> str:
if unit is None:
return _compute_nice_y(df)

Expand All @@ -175,7 +177,7 @@ def _resolve_y(unit: Unit | None, *, df: pd.DataFrame) -> str:
return y


def _compute_nice_y(df: pd.DataFrame) -> str:
def _compute_nice_y(df: "pandas.DataFrame") -> str:
"""Pick the unit with the most "human" scale; whole numbers around one hundred."""
from numpy import log10

Expand Down
8 changes: 5 additions & 3 deletions src/rics/performance/plot/_plot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging
from pathlib import Path
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal, Union

import pandas as pd
from seaborn import FacetGrid, catplot

from rics.performance.types import ResultsDict
Expand All @@ -11,9 +10,12 @@
from ._postprocessors import make_postprocessors
from .types import Unit

if TYPE_CHECKING:
import pandas


def plot(
run_results: ResultsDict | pd.DataFrame,
run_results: Union[ResultsDict, "pandas.DataFrame"],
x: Literal["candidate", "data"] | None = None,
*,
unit: Unit | None = None,
Expand Down
14 changes: 14 additions & 0 deletions tests/performance/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from rics.performance._util import _import_context


def test_import_missing():
with pytest.raises(ImportError, match=r"rics\[all\]"):
_import_missing()


def _import_missing():
with _import_context("package-name"):
import does_not_exist # type: ignore[import-not-found]

does_not_exist.dont_delete_me()
4 changes: 2 additions & 2 deletions tests/performance/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from pathlib import Path

import pandas as pd
# import pandas as pd
import pytest
from click.testing import CliRunner
from rics.performance import MultiCaseTimer, cli, run_multivariate_test
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_cli(monkeypatch, with_all):
assert Path(tmp).joinpath("unit-test.png").is_file()
csv = Path(tmp).joinpath("unit-test.csv")
assert csv.is_file()
verify(pd.read_csv(csv))
# verify(pd.read_csv(csv))


def test_run_multivariate_test(monkeypatch):
Expand Down

0 comments on commit e61e917

Please sign in to comment.