Skip to content

Commit

Permalink
Merge pull request #75 from e10v/dev
Browse files Browse the repository at this point in the history
Auto check and default config for n_obs
  • Loading branch information
e10v authored Jul 13, 2024
2 parents 4fad1d3 + 405eb87 commit e833c56
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 47 deletions.
14 changes: 11 additions & 3 deletions src/tea_tasting/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Global configuration."""
# ruff: noqa: PLR0913

from __future__ import annotations

Expand All @@ -9,7 +10,7 @@


if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Iterator, Sequence
from typing import Any, Literal


Expand All @@ -18,6 +19,7 @@
"alternative": "two-sided",
"confidence_level": 0.95,
"equal_var": False,
"n_obs": None,
"n_resamples": 10_000,
"power": 0.8,
"ratio": 1,
Expand Down Expand Up @@ -55,6 +57,7 @@ def set_config(
alternative: Literal["two-sided", "greater", "less"] | None = None,
confidence_level: float | None = None,
equal_var: bool | None = None,
n_obs: int | Sequence[int] | None = None,
n_resamples: int | None = None,
power: float | None = None,
ratio: float | int | None = None,
Expand All @@ -71,6 +74,8 @@ def set_config(
equal_var: Defines whether equal variance is assumed. If `True`,
pooled variance is used for the calculation of the standard error
of the difference between two means. Default is `False`.
n_obs: Number of observations in the control and in the treatment together.
Default is `None`.
n_resamples: The number of resamples performed to form the bootstrap
distribution of a statistic. Default is `10_000`.
power: Statistical power. Default is 0.8.
Expand Down Expand Up @@ -119,12 +124,13 @@ def config_context(
alternative: Literal["two-sided", "greater", "less"] | None = None,
confidence_level: float | None = None,
equal_var: bool | None = None,
n_obs: int | Sequence[int] | None = None,
n_resamples: int | None = None,
power: float | None = None,
ratio: float | int | None = None,
use_t: bool | None = None,
**kwargs: Any,
) -> Generator[None, Any, None]:
) -> Iterator[Any]:
"""A context manager that temporarily modifies the global configuration.
Args:
Expand All @@ -135,6 +141,8 @@ def config_context(
equal_var: Defines whether equal variance is assumed. If `True`,
pooled variance is used for the calculation of the standard error
of the difference between two means. Default is `False`.
n_obs: Number of observations in the control and in the treatment together.
Default is `None`.
n_resamples: The number of resamples performed to form the bootstrap
distribution of a statistic. Default is `10_000`.
power: Statistical power. Default is 0.8.
Expand Down Expand Up @@ -176,4 +184,4 @@ def config_context(
try:
yield
finally:
set_config(**old_config)
_global_config.update(**old_config)
78 changes: 38 additions & 40 deletions src/tea_tasting/metrics/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal
from typing import Literal, TypeVar

from tea_tasting.metrics.base import PowerParameter


N = TypeVar("N", bound=float | int | None)


MAX_ITER = 100


Expand Down Expand Up @@ -227,39 +230,35 @@ def __init__( # noqa: PLR0913
"Both `effect_size` and `rel_effect_size` are not `None`. "
"Only one of them should be defined.",
)
if effect_size is None:
self.effect_size = effect_size
else:
if not isinstance(effect_size, Sequence):
effect_size = (effect_size,)
self.effect_size = tuple(
if isinstance(effect_size, Sequence):
for x in effect_size:
tea_tasting.utils.check_scalar(
x, "effect_size", typ=float | int,
gt=float("-inf"), lt=float("inf"), ne=0,
)
for x in effect_size
elif effect_size is not None:
tea_tasting.utils.check_scalar(
effect_size, "effect_size", typ=float | int,
gt=float("-inf"), lt=float("inf"), ne=0,
)
if rel_effect_size is None:
self.rel_effect_size = None
else:
if not isinstance(rel_effect_size, Sequence):
rel_effect_size = (rel_effect_size,)
self.rel_effect_size = tuple(
self.effect_size = effect_size
if isinstance(rel_effect_size, Sequence):
for x in rel_effect_size:
tea_tasting.utils.check_scalar(
x, "rel_effect_size", typ=float | int,
gt=float("-inf"), lt=float("inf"), ne=0,
)
for x in rel_effect_size
)
if n_obs is None:
self.n_obs = None
else:
if not isinstance(n_obs, Sequence):
n_obs = (n_obs,)
self.n_obs = tuple(
tea_tasting.utils.check_scalar(x, "n_obs", typ=int, gt=1)
for x in n_obs
elif rel_effect_size is not None:
tea_tasting.utils.check_scalar(
rel_effect_size, "rel_effect_size", typ=float | int,
gt=float("-inf"), lt=float("inf"), ne=0,
)
self.rel_effect_size = rel_effect_size
self.n_obs = (
tea_tasting.utils.auto_check(n_obs, "n_obs")
if n_obs is not None
else tea_tasting.config.get_config("n_obs")
)


@property
Expand Down Expand Up @@ -391,9 +390,9 @@ def _validate_power_parameters(
parameter: PowerParameter,
) -> tuple[
float | None, # power
tuple[float | int | None, ...], # effect_size
tuple[float | None, ...], # rel_effect_size
tuple[int | None, ...], # n_obs
Sequence[float | int | None], # effect_size
Sequence[float | None], # rel_effect_size
Sequence[int | None], # n_obs
]:
n_obs = None
effect_size = None
Expand All @@ -410,14 +409,14 @@ def _validate_power_parameters(
self.effect_size if self.rel_effect_size is None
else tuple(
rel_effect_size * metric_mean
for rel_effect_size in self.rel_effect_size
for rel_effect_size in _to_seq(self.rel_effect_size)
)
)
rel_effect_size = (
self.rel_effect_size if self.effect_size is None
else tuple(
effect_size / metric_mean
for effect_size in self.effect_size
for effect_size in _to_seq(self.effect_size)
)
)

Expand All @@ -427,14 +426,7 @@ def _validate_power_parameters(
if parameter in {"effect_size", "rel_effect_size", "n_obs"}:
power = self.power

if effect_size is None:
effect_size = (None,)
if rel_effect_size is None:
rel_effect_size = (None,)
if n_obs is None:
n_obs = (None,)

return power, effect_size, rel_effect_size, n_obs
return power, _to_seq(effect_size), _to_seq(rel_effect_size), _to_seq(n_obs)


def _covariate_coef(self, aggr: tea_tasting.aggr.Aggregates) -> float:
Expand Down Expand Up @@ -685,6 +677,12 @@ def _find_boundary(
return b


def _to_seq(x: N | Sequence[N]) -> Sequence[N]:
if isinstance(x, Sequence):
return x
return (x,)


class Mean(RatioOfMeans): # noqa: D101
def __init__( # noqa: PLR0913
self,
Expand All @@ -698,9 +696,9 @@ def __init__( # noqa: PLR0913
alpha: float | None = None,
ratio: float | int | None = None,
power: float | None = None,
effect_size: float | int | None = None,
rel_effect_size: float | None = None,
n_obs: int | None = None,
effect_size: float | int | Sequence[float | int] | None = None,
rel_effect_size: float | Sequence[float] | None = None,
n_obs: int | Sequence[int] | None = None,
) -> None:
"""Metric for the analysis of means.
Expand Down
14 changes: 11 additions & 3 deletions src/tea_tasting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import abc
from collections.abc import Sequence
import inspect
import locale
import math
Expand All @@ -13,7 +14,7 @@


if TYPE_CHECKING:
from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Iterator
from typing import Any, Literal, TypeVar

R = TypeVar("R")
Expand Down Expand Up @@ -65,7 +66,7 @@ def check_scalar( # noqa: PLR0913
return value


def auto_check(value: R, name: str) -> R:
def auto_check(value: R, name: str) -> R: # noqa: C901, PLR0912
"""Automatically check a parameter's type and value based on its name.
Args:
Expand All @@ -85,9 +86,16 @@ def auto_check(value: R, name: str) -> R:
check_scalar(value, name, typ=bool)
elif name == "equal_var":
check_scalar(value, name, typ=bool)
elif name == "n_obs":
check_scalar(value, name, typ=int | Sequence | None)
if isinstance(value, int):
check_scalar(value, name, gt=1)
if isinstance(value, Sequence):
for val in value:
check_scalar(val, name, typ=int, gt=1)
elif name == "n_resamples":
check_scalar(value, name, typ=int, gt=0)
if name == "power":
elif name == "power":
check_scalar(value, name, typ=float, gt=0, lt=1)
elif name == "ratio":
check_scalar(value, name, typ=float | int, gt=0)
Expand Down
6 changes: 5 additions & 1 deletion tests/metrics/test_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,10 @@ def test_ratio_of_means_init_custom():
assert metric.ratio == 0.5
assert metric.power == 0.75
assert metric.effect_size is None
assert metric.rel_effect_size == (0.08,)
assert metric.rel_effect_size == 0.08
assert metric.n_obs == (5_000, 10_000)
metric = tea_tasting.metrics.mean.RatioOfMeans("a", effect_size=(1, 0.2))
assert metric.effect_size == (1, 0.2)

def test_ratio_of_means_init_config():
with tea_tasting.config.config_context(
Expand All @@ -137,6 +139,7 @@ def test_ratio_of_means_init_config():
alpha=0.1,
ratio=0.5,
power=0.75,
n_obs=(5_000, 10_000),
):
metric = tea_tasting.metrics.mean.RatioOfMeans("a")
assert metric.alternative == "greater"
Expand All @@ -146,6 +149,7 @@ def test_ratio_of_means_init_config():
assert metric.alpha == 0.1
assert metric.ratio == 0.5
assert metric.power == 0.75
assert metric.n_obs == (5_000, 10_000)


def test_ratio_of_means_aggr_cols():
Expand Down
13 changes: 13 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,19 @@ def test_auto_check_equal_var():
with pytest.raises(TypeError):
tea_tasting.utils.auto_check(0, "equal_var")

def test_auto_check_n_obs():
assert tea_tasting.utils.auto_check(2, "n_obs") == 2
assert tea_tasting.utils.auto_check((2, 3), "n_obs") == (2, 3)
assert tea_tasting.utils.auto_check(None, "n_obs") is None
with pytest.raises(TypeError):
tea_tasting.utils.auto_check(0.5, "n_obs")
with pytest.raises(TypeError):
tea_tasting.utils.auto_check((0.5, 2), "n_obs")
with pytest.raises(ValueError, match="must be >"):
tea_tasting.utils.auto_check(1, "n_obs")
with pytest.raises(ValueError, match="must be >"):
tea_tasting.utils.auto_check((1, 2), "n_obs")

def test_auto_check_n_resamples():
assert tea_tasting.utils.auto_check(1, "n_resamples") == 1
with pytest.raises(ValueError, match="must be >"):
Expand Down

0 comments on commit e833c56

Please sign in to comment.