diff --git a/mizani/transforms.py b/mizani/transforms.py index eaa7724..1907556 100644 --- a/mizani/transforms.py +++ b/mizani/transforms.py @@ -56,7 +56,7 @@ ) if TYPE_CHECKING: - from typing import Any, Callable, Sequence, Type + from typing import Any, Sequence, Type from mizani.typing import ( BreaksFunction, @@ -101,6 +101,7 @@ ] UTC = ZoneInfo("UTC") +REGISTRY: dict[str, Type[trans]] = {} @dataclass(kw_only=True) @@ -122,6 +123,11 @@ class trans(ABC): minor_breaks_func: MinorBreaksFunction | None = None "Callable to calculate minor breaks" + def __init_subclass__(cls, *args, **kwargs): + # Register all subclasses + super().__init_subclass__(*args, **kwargs) + REGISTRY[cls.__name__] = cls + # Use type variables for trans.transform and trans.inverse # to help upstream packages avoid type mismatches. e.g. # transform(tuple[float, float]) -> tuple[float, float] @@ -907,15 +913,13 @@ def inverse(self, x: FloatArrayLike) -> NDArrayFloat: return np.sign(x) * (np.exp(np.abs(x)) - 1) # type: ignore -def gettrans( - t: str | Callable[[], Type[trans]] | Type[trans] | trans | None = None, -): +def gettrans(t: str | Type[trans] | trans | None = None): """ Return a trans object Parameters ---------- - t : str | callable | type | trans + t : str | type | trans Name of transformation function. If None, returns an identity transform. @@ -923,20 +927,16 @@ def gettrans( ------- out : trans """ - obj = t - # Make sure trans object is instantiated - if t is None: + if isinstance(t, str): + names = (f"{t}_trans", t) + for name in names: + if t := REGISTRY.get(name): + return t() + elif isinstance(t, trans): + return t + elif isinstance(t, type) and issubclass(t, trans): + return t() + elif t is None: return identity_trans() - if isinstance(obj, str): - name = "{}_trans".format(obj) - obj = globals()[name]() - if callable(obj): - obj = obj() - if isinstance(obj, type): - obj = obj() - - if not isinstance(obj, trans): - raise ValueError("Could not get transform object.") - - return obj + raise ValueError(f"Could not get transform object. {t}") diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c8ca878..e4a19d8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -20,9 +20,11 @@ log2_trans, log10_trans, log_trans, + logit_trans, modulus_trans, pd_timedelta_trans, probability_trans, + probit_trans, pseudo_log_trans, reciprocal_trans, reverse_trans, @@ -46,7 +48,9 @@ def test_gettrans(): t2 = gettrans(identity_trans) t3 = gettrans("identity") t4 = gettrans() - assert all(isinstance(x, identity_trans) for x in (t0, t1, t2, t3, t4)) + assert all( + x.__class__.__name__ == "identity_trans" for x in (t0, t1, t2, t3, t4) + ) with pytest.raises(ValueError): gettrans(object) @@ -197,6 +201,10 @@ def test_probability_trans(): npt.assert_allclose(xt[:3], 1 - xt[-3:][::-1]) npt.assert_allclose(x, x2) + # Cover the paths these create as well + logit_trans() + probit_trans() + def test_datetime_trans(): UTC = ZoneInfo("UTC")