[ENH] Replace prts metrics #2400

Open · wants to merge 29 commits into base: main
Changes from 18 commits (29 commits in total)

Commits
301641d
Pre-commit fixes
aryanpola Nov 26, 2024
c4db216
Merge branch 'aeon-toolkit:main' into recall
aryanpola Nov 26, 2024
cc1101a
Position parameter in calculate_bias
aryanpola Nov 26, 2024
31ed73d
Merge remote-tracking branch 'origin/recall' into recall
aryanpola Nov 26, 2024
1028942
Added recall metric
aryanpola Nov 30, 2024
d4dc5ca
merged into into one file
aryanpola Dec 3, 2024
4db8027
test added
aryanpola Dec 20, 2024
4baaec7
Merge branch 'main' into recall
aryanpola Dec 20, 2024
43cd9ac
Merge branch 'main' into recall
aryanpola Dec 23, 2024
c098731
Changes in test and range_metrics
aryanpola Dec 29, 2024
497362f
list of list running but error!
aryanpola Dec 29, 2024
ab87680
flattening lists, all cases passed
aryanpola Dec 30, 2024
446e058
Merge branch 'main' into recall
aryanpola Dec 30, 2024
c18af4f
Empty-Commit
aryanpola Dec 30, 2024
010d994
Merge remote-tracking branch 'origin/recall' into recall
aryanpola Dec 30, 2024
9c23582
changes
aryanpola Jan 14, 2025
df42934
Protected functions
aryanpola Jan 14, 2025
dfa9046
Merge branch 'main' into recall
aryanpola Jan 14, 2025
b5bfab4
Changes in documentation
aryanpola Jan 15, 2025
576aaae
Merge remote-tracking branch 'origin/recall' into recall
aryanpola Jan 15, 2025
da81823
Changed test cases into seperate functions
aryanpola Jan 15, 2025
f9732eb
test cases added and added range recall
aryanpola Jan 17, 2025
48238f3
udf_gamma removed from precision
aryanpola Jan 17, 2025
0561981
changes
aryanpola Jan 17, 2025
4f4f617
more changes
aryanpola Jan 17, 2025
26b5029
recommended changes
aryanpola Jan 20, 2025
fa60406
changes
aryanpola Jan 20, 2025
c48d426
Added Parameters
aryanpola Jan 23, 2025
b13ba4a
removed udf_gamma from precision
aryanpola Jan 24, 2025
8 changes: 8 additions & 0 deletions aeon/benchmarking/metrics/anomaly_detection/__init__.py
@@ -14,6 +14,9 @@
"range_pr_auc_score",
"range_pr_vus_score",
"range_roc_vus_score",
"ts_precision",
"ts_recall",
"ts_fscore",
]

from aeon.benchmarking.metrics.anomaly_detection._binary import (
@@ -35,3 +38,8 @@
range_roc_auc_score,
range_roc_vus_score,
)
from aeon.benchmarking.metrics.anomaly_detection.range_metrics import (
ts_fscore,
ts_precision,
ts_recall,
)
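With these exports in place, the three new metrics become part of the public anomaly detection metrics API. A minimal import sketch, assuming this branch is installed (illustrative only, not part of the diff):

from aeon.benchmarking.metrics.anomaly_detection import (
    ts_fscore,
    ts_precision,
    ts_recall,
)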
222 changes: 222 additions & 0 deletions aeon/benchmarking/metrics/anomaly_detection/range_metrics.py
@@ -0,0 +1,222 @@
"""Calculate Precision, Recall, and F1-Score for time series anomaly detection."""

__maintainer__ = []
__all__ = ["ts_precision", "ts_recall", "ts_fscore"]


def _flatten_ranges(ranges):
"""
If the input is a list of lists, it flattens it into a single list.

Parameters
----------
ranges : list of tuples or list of lists of tuples
The ranges to flatten.

Returns
-------
list of tuples
A flattened list of ranges.
"""
if not ranges:
return []
if isinstance(ranges[0], list):
flat = []
for sublist in ranges:
for pred in sublist:
flat.append(pred)
return flat
return ranges


def _calculate_bias(position, length, bias_type="flat"):
"""Calculate bias value based on position and length.

Parameters
----------
position : int
Current position in the range
length : int
Total length of the range
bias_type : str
Type of bias to apply, Should be one of ["flat", "front", "middle", "back"].
(default: "flat")
"""
if bias_type == "flat":
return 1.0
elif bias_type == "front":
return 1.0 - (position - 1) / length
elif bias_type == "middle":
return 1.0 - abs(2 * (position - 1) / (length - 1) - 1) if length > 1 else 1.0
elif bias_type == "back":
return position / length
else:
raise ValueError(f"Invalid bias type: {bias_type}")


def _gamma_select(cardinality, gamma, udf_gamma=None):
"""Select a gamma value based on the cardinality type."""
if gamma == "one":
return 1.0
elif gamma == "reciprocal":
return 1 / cardinality if cardinality > 1 else 1.0
elif gamma == "udf_gamma":
if udf_gamma is not None:
return 1.0 / udf_gamma
else:
raise ValueError("udf_gamma must be provided for 'udf_gamma' gamma type.")
else:
raise ValueError("Invalid gamma type.")


def ts_precision(y_pred, y_real, gamma="one", bias_type="flat", udf_gamma=None):
"""
Calculate Global Precision for time series anomaly detection.

Parameters
----------
y_pred : list of tuples or list of lists of tuples
The predicted ranges.
y_real : list of tuples or list of lists of tuples
The real (actual) ranges.
gamma : str
Cardinality type. Should be one of ["reciprocal", "one", "udf_gamma"].
(default: "one")
bias_type : str
Type of bias to apply. Should be one of ["flat", "front", "middle", "back"].
(default: "flat")
udf_gamma : int or None
User-defined gamma value. (default: None)

Returns
-------
float
Global Precision
"""
# Flattening y_pred and y_real to resolve nested lists
flat_y_pred = _flatten_ranges(y_pred)
flat_y_real = _flatten_ranges(y_real)

overlapping_weighted_positions = 0.0
total_pred_weight = 0.0

for pred_range in flat_y_pred:
start_pred, end_pred = pred_range
length_pred = end_pred - start_pred + 1

for i in range(1, length_pred + 1):
pos = start_pred + i - 1
bias = _calculate_bias(i, length_pred, bias_type)

# Check if the position is in any real range
in_real = any(
real_start <= pos <= real_end for real_start, real_end in flat_y_real
)

if in_real:
gamma_value = _gamma_select(1, gamma, udf_gamma)
overlapping_weighted_positions += bias * gamma_value

total_pred_weight += bias

precision = (
overlapping_weighted_positions / total_pred_weight
if total_pred_weight > 0
else 0.0
)
return precision


def ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0, udf_gamma=None):
"""
Calculate Global Recall for time series anomaly detection.

Parameters
----------
y_pred : list of tuples or list of lists of tuples
The predicted ranges.
y_real : list of tuples or list of lists of tuples
The real (actual) ranges.
gamma : str
Cardinality type. Should be one of ["reciprocal", "one", "udf_gamma"].
(default: "one")
bias_type : str
Type of bias to apply. Should be one of ["flat", "front", "middle", "back"].
(default: "flat")
    alpha : float
        Weight for the existence reward in the recall calculation; not used by
        the current implementation. (default: 0.0)
udf_gamma : int or None
User-defined gamma value. (default: None)

Returns
-------
float
Global Recall
"""
# Flattening y_pred and y_real
flat_y_pred = _flatten_ranges(y_pred)
flat_y_real = _flatten_ranges(y_real)

overlapping_weighted_positions = 0.0
total_real_weight = 0.0

for real_range in flat_y_real:
start_real, end_real = real_range
length_real = end_real - start_real + 1

for i in range(1, length_real + 1):
pos = start_real + i - 1
bias = _calculate_bias(i, length_real, bias_type)

# Check if the position is in any predicted range
in_pred = any(
pred_start <= pos <= pred_end for pred_start, pred_end in flat_y_pred
)

if in_pred:
gamma_value = _gamma_select(1, gamma, udf_gamma)
overlapping_weighted_positions += bias * gamma_value

total_real_weight += bias

recall = (
overlapping_weighted_positions / total_real_weight
if total_real_weight > 0
else 0.0
)
return recall


def ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0, udf_gamma=None):
"""
Calculate F1-Score for time series anomaly detection.

Parameters
----------
y_pred : list of tuples or list of lists of tuples
The predicted ranges.
y_real : list of tuples or list of lists of tuples
The real (actual) ranges.
gamma : str
Cardinality type. Should be one of ["reciprocal", "one", "udf_gamma"].
(default: "one")
    bias_type : str
        Type of bias to apply. Should be one of ["flat", "front", "middle", "back"].
        (default: "flat")
    alpha : float
        Weight for the existence reward in the recall calculation, passed through
        to ``ts_recall``. (default: 0.0)
    udf_gamma : int or None
        User-defined gamma value. (default: None)

Returns
-------
float
F1-Score
"""
precision = ts_precision(y_pred, y_real, gamma, bias_type, udf_gamma=udf_gamma)
recall = ts_recall(y_pred, y_real, gamma, bias_type, alpha, udf_gamma=udf_gamma)

if precision + recall > 0:
fscore = 2 * (precision * recall) / (precision + recall)
else:
fscore = 0.0

return fscore
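For reference, a short usage sketch of the new module on the first parametrized case from the tests below ([(1, 4)] predicted vs. [(2, 6)] real). The expected values are taken from that test case and the keyword arguments shown are the defaults; this is illustrative only and not part of the diff:

from aeon.benchmarking.metrics.anomaly_detection.range_metrics import (
    ts_fscore,
    ts_precision,
    ts_recall,
)

y_pred = [(1, 4)]  # one predicted range covering positions 1-4 (inclusive)
y_real = [(2, 6)]  # one real range covering positions 2-6 (inclusive)

# Positions 2, 3 and 4 overlap: 3 of 4 predicted positions and 3 of 5 real positions.
print(ts_precision(y_pred, y_real, gamma="one", bias_type="flat"))          # 0.75
print(ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0))  # 0.6
print(ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0))  # ~0.666667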
79 changes: 79 additions & 0 deletions — new test module for the range-based metrics
@@ -0,0 +1,79 @@
"""Test cases for the range-based anomaly detection metrics."""

import numpy as np
import pytest

from aeon.benchmarking.metrics.anomaly_detection.range_metrics import (
ts_fscore,
ts_precision,
ts_recall,
)


# Test cases for metrics
@pytest.mark.parametrize(
"y_pred, y_real, expected_precision, expected_recall, expected_f1",
[
([(1, 4)], [(2, 6)], 0.750000, 0.600000, 0.666667), # Single Overlapping Range
(
[(1, 2), (7, 8)],
[(3, 4), (9, 10)],
0.000000,
0.000000,
0.000000,
), # Multiple Non-Overlapping Ranges
(
[(1, 3), (5, 7)],
[(2, 6), (8, 10)],
0.666667,
0.500000,
0.571429,
), # Multiple Overlapping Ranges
(
[[(1, 3), (5, 7)], [(10, 12)]],
[(2, 6), (8, 10)],
0.555556,
0.625000,
0.588235,
), # Nested Lists of Predictions
(
[(1, 10)],
[(2, 3), (5, 6), (8, 9)],
0.600000,
1.000000,
0.750000,
), # All Encompassing Range
(
[(1, 2)],
[(1, 1)],
0.5,
1.000000,
0.666667,
        ),  # Converted Binary to Range-Based (Existing example)
],
)
def test_metrics(y_pred, y_real, expected_precision, expected_recall, expected_f1):
"""Test the range-based anomaly detection metrics."""
precision = ts_precision(y_pred, y_real, gamma="one", bias_type="flat")
recall = ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)
f1_score = ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)

# Use assertions with detailed error messages for debugging
np.testing.assert_almost_equal(
precision,
expected_precision,
decimal=6,
err_msg=f"Precision failed! Expected={expected_precision}, Got={precision}",
)
np.testing.assert_almost_equal(
recall,
expected_recall,
decimal=6,
err_msg=f"Recall failed! Expected={expected_recall}, Got={recall}",
)
np.testing.assert_almost_equal(
f1_score,
expected_f1,
decimal=6,
err_msg=f"F1-Score failed! Expected={expected_f1}, Got={f1_score}",
)
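The parametrized cases above only exercise gamma="one" with a flat bias. The other bias_type and gamma options can be probed with a small exploratory sketch like the one below, reusing the nested-list case from the parametrization; these outputs are not asserted anywhere in this PR, so no expected values are given:

from aeon.benchmarking.metrics.anomaly_detection.range_metrics import ts_fscore

y_pred = [[(1, 3), (5, 7)], [(10, 12)]]  # nested predictions are flattened internally
y_real = [(2, 6), (8, 10)]

for bias in ("flat", "front", "middle", "back"):
    for gamma in ("one", "reciprocal"):
        score = ts_fscore(y_pred, y_real, gamma=gamma, bias_type=bias)
        print(f"bias={bias:<6} gamma={gamma:<10} f1={score:.6f}")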