diff --git a/README.md b/README.md
index 00065e80..09109300 100644
--- a/README.md
+++ b/README.md
@@ -335,9 +335,9 @@ The currently implemented detectors are listed in the following table.
Nishida and Yamauchi (2007) |
- Data drift |
- Batch |
- Distance based |
+ Data drift |
+ Batch |
+ Distance based |
U |
N |
Anderson-Darling test |
@@ -355,6 +355,12 @@ The currently implemented detectors are listed in the following table.
Earth Mover's distance |
Rubner et al. (2000) |
+
+ U |
+ N |
+ Energy distance |
+ Székely et al. (2013) |
+
U |
N |
diff --git a/docs/source/api_reference/detectors/data_drift/batch.md b/docs/source/api_reference/detectors/data_drift/batch.md
index 0f87e646..0d764e83 100644
--- a/docs/source/api_reference/detectors/data_drift/batch.md
+++ b/docs/source/api_reference/detectors/data_drift/batch.md
@@ -26,6 +26,7 @@ The {mod}`frouros.detectors.data_drift.batch` module contains batch data drift d
BhattacharyyaDistance
EMD
+ EnergyDistance
HellingerDistance
HINormalizedComplement
JS
diff --git a/frouros/detectors/data_drift/__init__.py b/frouros/detectors/data_drift/__init__.py
index 95baa837..211c7557 100644
--- a/frouros/detectors/data_drift/__init__.py
+++ b/frouros/detectors/data_drift/__init__.py
@@ -6,6 +6,7 @@
ChiSquareTest,
CVMTest,
EMD,
+ EnergyDistance,
HellingerDistance,
HINormalizedComplement,
JS,
@@ -25,6 +26,7 @@
"ChiSquareTest",
"CVMTest",
"EMD",
+ "EnergyDistance",
"HellingerDistance",
"HINormalizedComplement",
"IncrementalKSTest",
diff --git a/frouros/detectors/data_drift/batch/__init__.py b/frouros/detectors/data_drift/batch/__init__.py
index 36adc7e0..20c09cb4 100644
--- a/frouros/detectors/data_drift/batch/__init__.py
+++ b/frouros/detectors/data_drift/batch/__init__.py
@@ -3,6 +3,7 @@
from .distance_based import (
BhattacharyyaDistance,
EMD,
+ EnergyDistance,
HellingerDistance,
HINormalizedComplement,
JS,
@@ -25,6 +26,7 @@
"ChiSquareTest",
"CVMTest",
"EMD",
+ "EnergyDistance",
"HellingerDistance",
"HINormalizedComplement",
"JS",
diff --git a/frouros/detectors/data_drift/batch/distance_based/__init__.py b/frouros/detectors/data_drift/batch/distance_based/__init__.py
index 8f46aee2..57ab88ad 100644
--- a/frouros/detectors/data_drift/batch/distance_based/__init__.py
+++ b/frouros/detectors/data_drift/batch/distance_based/__init__.py
@@ -2,6 +2,7 @@
from .bhattacharyya_distance import BhattacharyyaDistance
from .emd import EMD
+from .energy_distance import EnergyDistance
from .hellinger_distance import HellingerDistance
from .hi_normalized_complement import HINormalizedComplement
from .js import JS
@@ -12,6 +13,7 @@
__all__ = [
"BhattacharyyaDistance",
"EMD",
+ "EnergyDistance",
"HellingerDistance",
"HINormalizedComplement",
"JS",
diff --git a/frouros/detectors/data_drift/batch/distance_based/energy_distance.py b/frouros/detectors/data_drift/batch/distance_based/energy_distance.py
new file mode 100644
index 00000000..650b7bb0
--- /dev/null
+++ b/frouros/detectors/data_drift/batch/distance_based/energy_distance.py
@@ -0,0 +1,77 @@
+"""Energy Distance module."""
+
+from typing import Optional, Union
+
+import numpy as np # type: ignore
+from scipy.stats import energy_distance # type: ignore
+
+from frouros.callbacks.batch.base import BaseCallbackBatch
+from frouros.detectors.data_drift.base import UnivariateData
+from frouros.detectors.data_drift.batch.distance_based.base import (
+ BaseDistanceBased,
+ DistanceResult,
+)
+
+
+class EnergyDistance(BaseDistanceBased):
+ """EnergyDistance [szekely2013energy]_ detector.
+
+ :param callbacks: callbacks, defaults to None
+ :type callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]]
+ :param kwargs: additional keyword arguments to pass to scipy.stats.energy_distance
+ :type kwargs: Dict[str, Any]
+
+ :References:
+
+ .. [szekely2013energy] Székely, Gábor J., and Maria L. Rizzo.
+ "Energy statistics: A class of statistics based on distances."
+ Journal of statistical planning and inference 143.8 (2013): 1249-1272.
+
+ :Example:
+
+ >>> from frouros.detectors.data_drift import EnergyDistance
+ >>> import numpy as np
+ >>> np.random.seed(seed=31)
+ >>> X = np.random.normal(loc=0, scale=1, size=100)
+ >>> Y = np.random.normal(loc=1, scale=1, size=100)
+ >>> detector = EnergyDistance()
+ >>> _ = detector.fit(X=X)
+ >>> detector.compare(X=Y)[0]
+ DistanceResult(distance=0.8359206395514527)
+ """ # noqa: E501
+
+ def __init__( # noqa: D107
+ self,
+ callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ statistical_type=UnivariateData(),
+ statistical_method=self._energy_distance,
+ statistical_kwargs=kwargs,
+ callbacks=callbacks,
+ )
+ self.kwargs = kwargs
+
+ def _distance_measure(
+ self,
+ X_ref: np.ndarray, # noqa: N803
+ X: np.ndarray, # noqa: N803
+ **kwargs,
+ ) -> DistanceResult:
+ emd = self._energy_distance(X=X_ref, Y=X, **self.kwargs)
+ distance = DistanceResult(distance=emd)
+ return distance
+
+ @staticmethod
+ def _energy_distance(
+ X: np.ndarray, # noqa: N803
+ Y: np.ndarray,
+ **kwargs,
+ ) -> float:
+ energy = energy_distance(
+ u_values=X.flatten(),
+ v_values=Y.flatten(),
+ **kwargs,
+ )
+ return energy
diff --git a/frouros/tests/integration/test_callback.py b/frouros/tests/integration/test_callback.py
index 44c87fbd..e7c64c94 100644
--- a/frouros/tests/integration/test_callback.py
+++ b/frouros/tests/integration/test_callback.py
@@ -30,6 +30,7 @@
BhattacharyyaDistance,
CVMTest,
EMD,
+ EnergyDistance,
HellingerDistance,
HINormalizedComplement,
JS,
@@ -48,6 +49,7 @@
[
(BhattacharyyaDistance, 0.55516059, 0.0),
(EMD, 3.85346006, 0.0),
+ (EnergyDistance, 2.11059982, 0.0),
(HellingerDistance, 0.74509099, 0.0),
(HINormalizedComplement, 0.78, 0.0),
(JS, 0.67010107, 0.0),
diff --git a/frouros/tests/integration/test_data_drift.py b/frouros/tests/integration/test_data_drift.py
index 37973059..0ecf6d6c 100644
--- a/frouros/tests/integration/test_data_drift.py
+++ b/frouros/tests/integration/test_data_drift.py
@@ -8,6 +8,7 @@
from frouros.detectors.data_drift.batch import (
BhattacharyyaDistance,
EMD,
+ EnergyDistance,
HellingerDistance,
HINormalizedComplement,
PSI,
@@ -64,6 +65,7 @@ def test_batch_distance_based_categorical(
"detector, expected_distance",
[
(EMD(), 3.85346006),
+ (EnergyDistance(), 2.11059982),
(JS(), 0.67010107),
(KL(), np.inf),
(HINormalizedComplement(), 0.78),