Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add quantile loss metric #1559

Merged
merged 24 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
247b7ef
first implementation of quantile loss
JanFidor Feb 11, 2023
c1d5317
add quantile loss to metrics ___init__ and tests
JanFidor Feb 11, 2023
1f5db41
refactor
JanFidor Feb 11, 2023
8d4a3ae
rename pinball loss to quantile loss
JanFidor Feb 11, 2023
c1ff762
black
JanFidor Feb 11, 2023
9f85c3e
use reduction to aggregate losses and update docs
JanFidor Feb 12, 2023
d0bbf0d
black + isort
JanFidor Feb 14, 2023
dcb8265
rollback to simple mean instead of reduction param
JanFidor Feb 15, 2023
189c71c
change overlooked copy-paste comment
JanFidor Feb 15, 2023
941b484
black enter
JanFidor Feb 15, 2023
9decba7
Merge branch 'master' into feature/quantile-loss-metric
hrzn Feb 16, 2023
54e6124
docs changes
JanFidor Feb 16, 2023
681af77
Merge branch 'feature/quantile-loss-metric' of https://github.com/Jan…
JanFidor Feb 16, 2023
c313e3c
Merge branch 'master' into feature/quantile-loss-metric
madtoinou Feb 17, 2023
9ccf0f9
flake8
JanFidor Feb 17, 2023
8507e7f
Merge branch 'feature/quantile-loss-metric' of https://github.com/Jan…
JanFidor Feb 17, 2023
4e8fe87
Merge branch 'master' into feature/quantile-loss-metric
JanFidor Feb 17, 2023
6f81ea9
Merge branch 'master' into feature/quantile-loss-metric
hrzn Feb 21, 2023
b55d5ab
Merge branch 'master' into feature/quantile-loss-metric
JanFidor Feb 24, 2023
8ac5d1c
Merge branch 'master' into feature/quantile-loss-metric
madtoinou Feb 26, 2023
f7f7b66
Merge branch 'master' into feature/quantile-loss-metric
dennisbader Feb 28, 2023
25448fe
Merge branch 'master' into feature/quantile-loss-metric
madtoinou Feb 28, 2023
170e59f
Merge branch 'master' into feature/quantile-loss-metric
madtoinou Mar 5, 2023
f5c4098
Merge branch 'master' into feature/quantile-loss-metric
madtoinou Mar 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions darts/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
mase,
mse,
ope,
quantile_loss,
r2_score,
rho_risk,
rmse,
Expand Down
78 changes: 78 additions & 0 deletions darts/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,3 +1212,81 @@ def rho_risk(

rho_loss = 2 * (z_true - z_hat_rho) * (rho * pred_below - (1 - rho) * pred_above)
return rho_loss / z_true


# Quantile Loss (Pinball Loss)
@multi_ts_support
@multivariate_support
def quantile_loss(
actual_series: Union[TimeSeries, Sequence[TimeSeries]],
pred_series: Union[TimeSeries, Sequence[TimeSeries]],
tau: float = 0.5,
intersect: bool = True,
*,
reduction: Callable[[np.ndarray], float] = np.mean,
inter_reduction: Callable[[np.ndarray], Union[float, np.ndarray]] = lambda x: x,
n_jobs: int = 1,
verbose: bool = False
) -> float:
"""
Given a time series of actual values :math:`y` of length :math:`T` and a time series of stochastic predictions
(containing N samples) :math:`y'` of shape :math:`T x N`, quantile loss is a metric that quantifies the
accuracy of a specific quantile :math:`q` from the predicted value distribution.
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
actual_series
The (sequence of) actual series.
pred_series
The (sequence of) predicted series.
tau
The quantile (float [0, 1]) of interest for the loss.
intersect
For time series that are overlapping in time without having the same time index, setting `True`
will consider the values only over their common time interval (intersection in time).
reduction
Function taking as input a ``np.ndarray`` and returning a scalar value. This function is used to aggregate
the metrics of different components in case of multivariate ``TimeSeries`` instances.
inter_reduction
Function taking as input a ``np.ndarray`` and returning either a scalar value or a ``np.ndarray``.
This function can be used to aggregate the metrics of different series in case the metric is evaluated on a
``Sequence[TimeSeries]``. Defaults to the identity function, which returns the pairwise metrics for each pair
of ``TimeSeries`` received in input. Example: ``inter_reduction=np.mean``, will return the average of the
pairwise metrics.
n_jobs
The number of jobs to run in parallel. Parallel jobs are created only when a ``Sequence[TimeSeries]`` is
passed as input, parallelising operations regarding different ``TimeSeries``. Defaults to `1`
(sequential). Setting the parameter to `-1` means using all the available processors.
verbose
Optionally, whether to print operations progress

Returns
-------
float
The quantile loss metric
"""

raise_if_not(
pred_series.is_stochastic,
"quantile (quantile) loss should only be computed for stochastic predicted TimeSeries.",
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
)

y, y_hat = _get_values_or_raise(
actual_series,
pred_series,
intersect,
stochastic_quantile=None,
remove_nan_union=True,
)

ts_length = y.shape[0]
sample_size = 1 if len(y_hat.shape) < 3 else y_hat.shape[2]
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

y = y.reshape(ts_length, -1, 1).repeat(sample_size, axis=2)
y_hat = y_hat.reshape(
ts_length, -1, sample_size
) # make sure y shape == y_hat shape

errors = y - y_hat
losses = np.maximum((tau - 1) * errors, tau * errors)
return losses.mean()
30 changes: 30 additions & 0 deletions darts/tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,36 @@ def test_rho_risk(self):
self.assertAlmostEqual(metrics.rho_risk(s1, s12_stochastic, rho=0.0), 0.0)
self.assertAlmostEqual(metrics.rho_risk(s2, s12_stochastic, rho=1.0), 0.0)

def test_quantile_loss(self):
# deterministic not supported
with self.assertRaises(ValueError):
metrics.quantile_loss(self.series1, self.series1)

# general univariate, multivariate and multi-ts tests
self.helper_test_multivariate_duplication_equality(
metrics.quantile_loss, is_stochastic=True
)
self.helper_test_multiple_ts_duplication_equality(
metrics.quantile_loss, is_stochastic=True
)
self.helper_test_nan(metrics.quantile_loss, is_stochastic=True)

# test perfect predictions -> risk = 0
for tau in [0.25, 0.5]:
self.assertAlmostEqual(
metrics.quantile_loss(self.series1, self.series11_stochastic, tau=tau),
0.0,
)

# test whether stochastic sample from two TimeSeries (ts) represents the individual ts at 0. and 1. quantiles
s1 = self.series1
s2 = self.series1 * 2
s12_stochastic = TimeSeries.from_times_and_values(
s1.time_index, np.stack([s1.values(), s2.values()], axis=2)
)
self.assertAlmostEqual(metrics.quantile_loss(s1, s12_stochastic, tau=1.0), 0.0)
self.assertAlmostEqual(metrics.quantile_loss(s2, s12_stochastic, tau=0.0), 0.0)

def test_metrics_arguments(self):
series00 = self.series0.stack(self.series0)
series11 = self.series1.stack(self.series1)
Expand Down