diff --git a/README.md b/README.md index 00065e80..09109300 100644 --- a/README.md +++ b/README.md @@ -335,9 +335,9 @@ The currently implemented detectors are listed in the following table. Nishida and Yamauchi (2007) - Data drift - Batch - Distance based + Data drift + Batch + Distance based U N Anderson-Darling test @@ -355,6 +355,12 @@ The currently implemented detectors are listed in the following table. Earth Mover's distance Rubner et al. (2000) + + U + N + Energy distance + Székely et al. (2013) + U N diff --git a/docs/source/api_reference/detectors/data_drift/batch.md b/docs/source/api_reference/detectors/data_drift/batch.md index 0f87e646..0d764e83 100644 --- a/docs/source/api_reference/detectors/data_drift/batch.md +++ b/docs/source/api_reference/detectors/data_drift/batch.md @@ -26,6 +26,7 @@ The {mod}`frouros.detectors.data_drift.batch` module contains batch data drift d BhattacharyyaDistance EMD + EnergyDistance HellingerDistance HINormalizedComplement JS diff --git a/frouros/detectors/data_drift/__init__.py b/frouros/detectors/data_drift/__init__.py index 95baa837..211c7557 100644 --- a/frouros/detectors/data_drift/__init__.py +++ b/frouros/detectors/data_drift/__init__.py @@ -6,6 +6,7 @@ ChiSquareTest, CVMTest, EMD, + EnergyDistance, HellingerDistance, HINormalizedComplement, JS, @@ -25,6 +26,7 @@ "ChiSquareTest", "CVMTest", "EMD", + "EnergyDistance", "HellingerDistance", "HINormalizedComplement", "IncrementalKSTest", diff --git a/frouros/detectors/data_drift/batch/__init__.py b/frouros/detectors/data_drift/batch/__init__.py index 36adc7e0..20c09cb4 100644 --- a/frouros/detectors/data_drift/batch/__init__.py +++ b/frouros/detectors/data_drift/batch/__init__.py @@ -3,6 +3,7 @@ from .distance_based import ( BhattacharyyaDistance, EMD, + EnergyDistance, HellingerDistance, HINormalizedComplement, JS, @@ -25,6 +26,7 @@ "ChiSquareTest", "CVMTest", "EMD", + "EnergyDistance", "HellingerDistance", "HINormalizedComplement", "JS", diff --git 
class EnergyDistance(BaseDistanceBased):
    """EnergyDistance [szekely2013energy]_ detector.

    Distance-based batch detector that delegates the computation to
    :func:`scipy.stats.energy_distance`. Both samples are flattened to one
    dimension before being handed to SciPy, and any extra keyword arguments
    given at construction time are forwarded on every call.

    :param callbacks: callbacks, defaults to None
    :type callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]]
    :param kwargs: additional keyword arguments to pass to scipy.stats.energy_distance
    :type kwargs: Dict[str, Any]

    :References:

    .. [szekely2013energy] Székely, Gábor J., and Maria L. Rizzo.
        "Energy statistics: A class of statistics based on distances."
        Journal of statistical planning and inference 143.8 (2013): 1249-1272.

    :Example:

    >>> from frouros.detectors.data_drift import EnergyDistance
    >>> import numpy as np
    >>> np.random.seed(seed=31)
    >>> X = np.random.normal(loc=0, scale=1, size=100)
    >>> Y = np.random.normal(loc=1, scale=1, size=100)
    >>> detector = EnergyDistance()
    >>> _ = detector.fit(X=X)
    >>> detector.compare(X=Y)[0]
    DistanceResult(distance=0.8359206395514527)
    """  # noqa: E501

    def __init__(  # noqa: D107
        self,
        callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            statistical_type=UnivariateData(),
            statistical_method=self._energy_distance,
            statistical_kwargs=kwargs,
            callbacks=callbacks,
        )
        # Kept on the instance so _distance_measure can forward the same
        # options to SciPy on every comparison.
        self.kwargs = kwargs

    def _distance_measure(
        self,
        X_ref: np.ndarray,  # noqa: N803
        X: np.ndarray,  # noqa: N803
        **kwargs,
    ) -> DistanceResult:
        """Wrap the energy distance between ``X_ref`` and ``X`` in a result.

        :param X_ref: reference data
        :type X_ref: numpy.ndarray
        :param X: data to compare against the reference
        :type X: numpy.ndarray
        :param kwargs: accepted but not used; the keyword arguments stored
            at construction time are forwarded instead
        :return: result holding the computed distance
        :rtype: DistanceResult
        """
        return DistanceResult(
            distance=self._energy_distance(X=X_ref, Y=X, **self.kwargs),
        )

    @staticmethod
    def _energy_distance(
        X: np.ndarray,  # noqa: N803
        Y: np.ndarray,
        **kwargs,
    ) -> float:
        """Compute the energy distance between two flattened samples.

        :param X: first sample, passed to SciPy as ``u_values``
        :type X: numpy.ndarray
        :param Y: second sample, passed to SciPy as ``v_values``
        :type Y: numpy.ndarray
        :param kwargs: extra keyword arguments for scipy.stats.energy_distance
        :return: value returned by scipy.stats.energy_distance
        :rtype: float
        """
        u_flat = X.flatten()
        v_flat = Y.flatten()
        return energy_distance(u_values=u_flat, v_values=v_flat, **kwargs)
b/frouros/tests/integration/test_data_drift.py @@ -8,6 +8,7 @@ from frouros.detectors.data_drift.batch import ( BhattacharyyaDistance, EMD, + EnergyDistance, HellingerDistance, HINormalizedComplement, PSI, @@ -64,6 +65,7 @@ def test_batch_distance_based_categorical( "detector, expected_distance", [ (EMD(), 3.85346006), + (EnergyDistance(), 2.11059982), (JS(), 0.67010107), (KL(), np.inf), (HINormalizedComplement(), 0.78),