diff --git a/src/tea_tasting/config.py b/src/tea_tasting/config.py
index 4105f53..7415df2 100644
--- a/src/tea_tasting/config.py
+++ b/src/tea_tasting/config.py
@@ -1,4 +1,5 @@
 """Global configuration."""
+# ruff: noqa: PLR0913
 
 from __future__ import annotations
 
@@ -9,7 +10,7 @@
 
 
 if TYPE_CHECKING:
-    from collections.abc import Generator
+    from collections.abc import Iterator, Sequence
     from typing import Any, Literal
 
 
@@ -18,6 +19,7 @@
     "alternative": "two-sided",
     "confidence_level": 0.95,
     "equal_var": False,
+    "n_obs": None,
     "n_resamples": 10_000,
     "power": 0.8,
     "ratio": 1,
@@ -55,6 +57,7 @@ def set_config(
     alternative: Literal["two-sided", "greater", "less"] | None = None,
     confidence_level: float | None = None,
     equal_var: bool | None = None,
+    n_obs: int | Sequence[int] | None = None,
     n_resamples: int | None = None,
     power: float | None = None,
     ratio: float | int | None = None,
@@ -71,6 +74,8 @@
         equal_var: Defines whether equal variance is assumed. If `True`,
             pooled variance is used for the calculation of the standard error
            of the difference between two means. Default is `False`.
+        n_obs: Number of observations in the control and in the treatment together.
+            Default is `None`.
         n_resamples: The number of resamples performed to form the bootstrap
            distribution of a statistic. Default is `10_000`.
         power: Statistical power. Default is 0.8.
@@ -119,12 +124,13 @@ def config_context(
     alternative: Literal["two-sided", "greater", "less"] | None = None,
     confidence_level: float | None = None,
     equal_var: bool | None = None,
+    n_obs: int | Sequence[int] | None = None,
     n_resamples: int | None = None,
     power: float | None = None,
     ratio: float | int | None = None,
     use_t: bool | None = None,
     **kwargs: Any,
-) -> Generator[None, Any, None]:
+) -> Iterator[Any]:
     """A context manager that temporarily modifies the global configuration.
 
     Args:
@@ -135,6 +141,8 @@
         equal_var: Defines whether equal variance is assumed. If `True`,
             pooled variance is used for the calculation of the standard error
            of the difference between two means. Default is `False`.
+        n_obs: Number of observations in the control and in the treatment together.
+            Default is `None`.
         n_resamples: The number of resamples performed to form the bootstrap
            distribution of a statistic. Default is `10_000`.
         power: Statistical power. Default is 0.8.
@@ -176,4 +184,4 @@ def config_context(
     try:
         yield
     finally:
-        set_config(**old_config)
+        _global_config.update(**old_config)
diff --git a/src/tea_tasting/metrics/mean.py b/src/tea_tasting/metrics/mean.py
index b885100..1d50d57 100644
--- a/src/tea_tasting/metrics/mean.py
+++ b/src/tea_tasting/metrics/mean.py
@@ -23,11 +23,14 @@
 if TYPE_CHECKING:
     from collections.abc import Callable
-    from typing import Literal
+    from typing import Literal, TypeVar
 
     from tea_tasting.metrics.base import PowerParameter
 
+    N = TypeVar("N", bound=float | int | None)
+
+
 MAX_ITER = 100
@@ -227,39 +230,35 @@ def __init__(  # noqa: PLR0913
             "Both `effect_size` and `rel_effect_size` are not `None`. "
             "Only one of them should be defined.",
         )
-        if effect_size is None:
-            self.effect_size = effect_size
-        else:
-            if not isinstance(effect_size, Sequence):
-                effect_size = (effect_size,)
-            self.effect_size = tuple(
+        if isinstance(effect_size, Sequence):
+            for x in effect_size:
                 tea_tasting.utils.check_scalar(
                     x, "effect_size", typ=float | int,
                     gt=float("-inf"), lt=float("inf"), ne=0,
                 )
-                for x in effect_size
+        elif effect_size is not None:
+            tea_tasting.utils.check_scalar(
+                effect_size, "effect_size", typ=float | int,
+                gt=float("-inf"), lt=float("inf"), ne=0,
             )
-        if rel_effect_size is None:
-            self.rel_effect_size = None
-        else:
-            if not isinstance(rel_effect_size, Sequence):
-                rel_effect_size = (rel_effect_size,)
-            self.rel_effect_size = tuple(
+        self.effect_size = effect_size
+        if isinstance(rel_effect_size, Sequence):
+            for x in rel_effect_size:
                 tea_tasting.utils.check_scalar(
                     x, "rel_effect_size", typ=float | int,
                     gt=float("-inf"), lt=float("inf"), ne=0,
                 )
-                for x in rel_effect_size
-            )
-        if n_obs is None:
-            self.n_obs = None
-        else:
-            if not isinstance(n_obs, Sequence):
-                n_obs = (n_obs,)
-            self.n_obs = tuple(
-                tea_tasting.utils.check_scalar(x, "n_obs", typ=int, gt=1)
-                for x in n_obs
+        elif rel_effect_size is not None:
+            tea_tasting.utils.check_scalar(
+                rel_effect_size, "rel_effect_size", typ=float | int,
+                gt=float("-inf"), lt=float("inf"), ne=0,
             )
+        self.rel_effect_size = rel_effect_size
+        self.n_obs = (
+            tea_tasting.utils.auto_check(n_obs, "n_obs")
+            if n_obs is not None
+            else tea_tasting.config.get_config("n_obs")
+        )
 
 
     @property
@@ -391,9 +390,9 @@ def _validate_power_parameters(
         parameter: PowerParameter,
     ) -> tuple[
         float | None,  # power
-        tuple[float | int | None, ...],  # effect_size
-        tuple[float | None, ...],  # rel_effect_size
-        tuple[int | None, ...],  # n_obs
+        Sequence[float | int | None],  # effect_size
+        Sequence[float | None],  # rel_effect_size
+        Sequence[int | None],  # n_obs
     ]:
         n_obs = None
         effect_size = None
@@ -410,14 +409,14 @@
             self.effect_size if self.rel_effect_size is None
             else tuple(
                 rel_effect_size * metric_mean
-                for rel_effect_size in self.rel_effect_size
+                for rel_effect_size in _to_seq(self.rel_effect_size)
             )
         )
         rel_effect_size = (
             self.rel_effect_size if self.effect_size is None
             else tuple(
                 effect_size / metric_mean
-                for effect_size in self.effect_size
+                for effect_size in _to_seq(self.effect_size)
             )
         )
 
@@ -427,14 +426,7 @@
 
         if parameter in {"effect_size", "rel_effect_size", "n_obs"}:
             power = self.power
-        if effect_size is None:
-            effect_size = (None,)
-        if rel_effect_size is None:
-            rel_effect_size = (None,)
-        if n_obs is None:
-            n_obs = (None,)
-
-        return power, effect_size, rel_effect_size, n_obs
+        return power, _to_seq(effect_size), _to_seq(rel_effect_size), _to_seq(n_obs)
 
 
     def _covariate_coef(self, aggr: tea_tasting.aggr.Aggregates) -> float:
@@ -685,6 +677,12 @@
     return b
 
 
+def _to_seq(x: N | Sequence[N]) -> Sequence[N]:
+    if isinstance(x, Sequence):
+        return x
+    return (x,)
+
+
 class Mean(RatioOfMeans):  # noqa: D101
     def __init__(  # noqa: PLR0913
         self,
@@ -698,9 +696,9 @@
         alpha: float | None = None,
         ratio: float | int | None = None,
         power: float | None = None,
-        effect_size: float | int | None = None,
-        rel_effect_size: float | None = None,
-        n_obs: int | None = None,
+        effect_size: float | int | Sequence[float | int] | None = None,
+        rel_effect_size: float | Sequence[float] | None = None,
+        n_obs: int | Sequence[int] | None = None,
     ) -> None:
         """Metric for the analysis of means.
 
diff --git a/src/tea_tasting/utils.py b/src/tea_tasting/utils.py
index cfd3a81..e09522e 100644
--- a/src/tea_tasting/utils.py
+++ b/src/tea_tasting/utils.py
@@ -4,6 +4,7 @@
 from __future__ import annotations
 
 import abc
+from collections.abc import Sequence
 import inspect
 import locale
 import math
@@ -13,7 +14,7 @@
 
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Iterator, Sequence
+    from collections.abc import Callable, Iterator
     from typing import Any, Literal, TypeVar
 
     R = TypeVar("R")
@@ -65,7 +66,7 @@ def check_scalar(  # noqa: PLR0913
     return value
 
 
-def auto_check(value: R, name: str) -> R:
+def auto_check(value: R, name: str) -> R:  # noqa: C901, PLR0912
     """Automatically check a parameter's type and value based on its name.
 
     Args:
@@ -85,9 +86,16 @@
         check_scalar(value, name, typ=bool)
     elif name == "equal_var":
         check_scalar(value, name, typ=bool)
+    elif name == "n_obs":
+        check_scalar(value, name, typ=int | Sequence | None)
+        if isinstance(value, int):
+            check_scalar(value, name, gt=1)
+        if isinstance(value, Sequence):
+            for val in value:
+                check_scalar(val, name, typ=int, gt=1)
     elif name == "n_resamples":
         check_scalar(value, name, typ=int, gt=0)
-    if name == "power":
+    elif name == "power":
         check_scalar(value, name, typ=float, gt=0, lt=1)
     elif name == "ratio":
         check_scalar(value, name, typ=float | int, gt=0)
diff --git a/tests/metrics/test_mean.py b/tests/metrics/test_mean.py
index 36313a8..b072f6b 100644
--- a/tests/metrics/test_mean.py
+++ b/tests/metrics/test_mean.py
@@ -125,8 +125,10 @@ def test_ratio_of_means_init_custom():
     assert metric.ratio == 0.5
     assert metric.power == 0.75
     assert metric.effect_size is None
-    assert metric.rel_effect_size == (0.08,)
+    assert metric.rel_effect_size == 0.08
     assert metric.n_obs == (5_000, 10_000)
+    metric = tea_tasting.metrics.mean.RatioOfMeans("a", effect_size=(1, 0.2))
+    assert metric.effect_size == (1, 0.2)
 
 def test_ratio_of_means_init_config():
     with tea_tasting.config.config_context(
@@ -137,6 +139,7 @@
         alpha=0.1,
         ratio=0.5,
         power=0.75,
+        n_obs=(5_000, 10_000),
     ):
         metric = tea_tasting.metrics.mean.RatioOfMeans("a")
         assert metric.alternative == "greater"
@@ -146,6 +149,7 @@
         assert metric.alpha == 0.1
         assert metric.ratio == 0.5
         assert metric.power == 0.75
+        assert metric.n_obs == (5_000, 10_000)
 
 
 def test_ratio_of_means_aggr_cols():
diff --git a/tests/test_utils.py b/tests/test_utils.py
index edd431d..8630992 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -84,6 +84,19 @@ def test_auto_check_equal_var():
     with pytest.raises(TypeError):
         tea_tasting.utils.auto_check(0, "equal_var")
 
+def test_auto_check_n_obs():
+    assert tea_tasting.utils.auto_check(2, "n_obs") == 2
+    assert tea_tasting.utils.auto_check((2, 3), "n_obs") == (2, 3)
+    assert tea_tasting.utils.auto_check(None, "n_obs") is None
+    with pytest.raises(TypeError):
+        tea_tasting.utils.auto_check(0.5, "n_obs")
+    with pytest.raises(TypeError):
+        tea_tasting.utils.auto_check((0.5, 2), "n_obs")
+    with pytest.raises(ValueError, match="must be >"):
+        tea_tasting.utils.auto_check(1, "n_obs")
+    with pytest.raises(ValueError, match="must be >"):
+        tea_tasting.utils.auto_check((1, 2), "n_obs")
+
 def test_auto_check_n_resamples():
     assert tea_tasting.utils.auto_check(1, "n_resamples") == 1
     with pytest.raises(ValueError, match="must be >"):
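
A rough usage sketch of the new `n_obs` global option, based on the signatures and tests in this diff. The metric column name "orders" and the specific numbers are only illustrations, not part of the change:

import tea_tasting.config
import tea_tasting.metrics.mean

# n_obs can now be set once, globally or within a context, instead of being
# passed to every metric when running a power analysis.
with tea_tasting.config.config_context(n_obs=(10_000, 20_000)):
    metric = tea_tasting.metrics.mean.Mean("orders", rel_effect_size=0.05)

# When n_obs is not passed explicitly, the metric picks the value up from
# the global config, as in test_ratio_of_means_init_config above.
print(metric.n_obs)  # (10000, 20000)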