Feat/conformal prediction #2552

Open
wants to merge 67 commits into
base: master
6071fc3
naive conformal prediction
dennisbader Jun 26, 2024
2dbf28b
first hist fc version works
dennisbader Jul 1, 2024
6c18c7e
add component names
dennisbader Jul 1, 2024
48d562a
add support for train length
dennisbader Jul 1, 2024
2c192ad
support for last points only
dennisbader Jul 1, 2024
cdbc6ce
add hist fc unit tests
dennisbader Jul 2, 2024
8c54a7d
add first conformal unit tests
dennisbader Jul 2, 2024
80dece9
overlap end checkpoint
dennisbader Jul 3, 2024
39adaf5
Merge branch 'master' into feat/conformal_prediction
dennisbader Jul 3, 2024
24bc75f
overlap end checkpoint 2
dennisbader Jul 4, 2024
b19b708
ignore start
dennisbader Jul 4, 2024
06d4a1c
Merge branch 'master' into feat/conformal_prediction
dennisbader Jul 5, 2024
7f02378
finalize hist fc test
dennisbader Jul 5, 2024
94acb96
start, train length tests
dennisbader Jul 5, 2024
f411178
Merge branch 'master' into feat/conformal_prediction
dennisbader Jul 5, 2024
c6f27ae
finalize start train length tests
dennisbader Jul 5, 2024
ba79d9a
fix residuals with overlap end
dennisbader Jul 8, 2024
c03eb17
refactor calibration for predict and hist fc
dennisbader Jul 9, 2024
d31f459
base and child conformal
dennisbader Jul 9, 2024
ff2beab
checks for calibration set
dennisbader Jul 10, 2024
522ee9d
rename conformal naive model
dennisbader Jul 10, 2024
1b40a40
add additional forecasting model logic
dennisbader Jul 10, 2024
7ee1488
add more unit tests
dennisbader Jul 11, 2024
6870580
add output chunk shift support
dennisbader Jul 12, 2024
01aaf0e
support train length with cal input
dennisbader Jul 19, 2024
c5dbf77
support train lenght part 2
dennisbader Jul 26, 2024
6847752
restructure hist fc logic
dennisbader Jul 27, 2024
13461c5
test with shorter covariates
dennisbader Jul 28, 2024
4143c20
add checks for min lengths
dennisbader Jul 30, 2024
5e2115c
corrections for minimum input
dennisbader Aug 30, 2024
d19e947
improve hist fc tests
dennisbader Aug 30, 2024
7a30a6d
Merge branch 'master' into feat/conformal_prediction
dennisbader Sep 15, 2024
c04a843
Merge branch 'master' into feat/conformal_prediction
dennisbader Sep 23, 2024
01e3d1e
make naive conformal model accept quantiles
dennisbader Sep 23, 2024
3f13619
add winkler score quantile interval metric
dennisbader Sep 25, 2024
5187630
update tests for quantile instead of alpha
dennisbader Sep 25, 2024
758f62e
Merge branch 'master' into feat/conformal_prediction
dennisbader Sep 25, 2024
fb6cfd2
add coverage metric and improve residuals and backtest
dennisbader Sep 26, 2024
880addb
add save load as in ensemble mode
dennisbader Sep 26, 2024
73bac08
quantile tests
dennisbader Sep 26, 2024
cc3e02b
remove checks
dennisbader Sep 26, 2024
e90431a
add non conformity scores for cqr
dennisbader Sep 27, 2024
5fb9d30
add conformalized quantile regression
dennisbader Sep 27, 2024
06a1c59
Merge branch 'master' into feat/conformal_prediction
dennisbader Sep 28, 2024
ff60254
allow all global prob models for ConformalQR
dennisbader Sep 29, 2024
5ab3631
add asymmetric naive model
dennisbader Sep 30, 2024
a4b0344
remove old code
dennisbader Sep 30, 2024
0cc20ac
add tests for asymetric naive mdoel
dennisbader Sep 30, 2024
f6802f0
add tests for cqr
dennisbader Oct 1, 2024
e8d922a
add progress bars
dennisbader Oct 1, 2024
9c5875c
add quantile sampler
dennisbader Oct 2, 2024
3761894
add predict lkl params and num samples
dennisbader Oct 3, 2024
a93ef39
add random method for handling randomness of non-torch models
dennisbader Oct 3, 2024
9318aea
fix all tests
dennisbader Oct 3, 2024
7c37f7d
code cleanup
dennisbader Oct 3, 2024
fe103e0
add probabilistic test
dennisbader Oct 3, 2024
60f9080
add conformal models to readme and covariates user guide
dennisbader Oct 3, 2024
298211d
fix failing tests
dennisbader Oct 3, 2024
a6f9056
improve docs
dennisbader Oct 3, 2024
6d1572d
add sketch of cp example notebook
dennisbader Oct 3, 2024
9881142
small update
dennisbader Oct 3, 2024
dd3e4c4
improve docs
dennisbader Oct 4, 2024
8e8ac1d
attempt to fix failing test on linux
dennisbader Oct 15, 2024
623b769
Merge branch 'master' into feat/conformal_prediction
dennisbader Nov 8, 2024
cc6fe07
update start logic
dennisbader Nov 10, 2024
c3d6abf
Merge branch 'master' into feat/conformal_prediction
dennisbader Nov 16, 2024
1272bfc
upgrade python target version
dennisbader Nov 16, 2024
86 changes: 46 additions & 40 deletions README.md

Large diffs are not rendered by default.

25 changes: 10 additions & 15 deletions darts/ad/anomaly_model/forecasting_am.py
@@ -120,10 +120,9 @@ def fit(
If set to 'value', `start` corresponds to the index value/label of the first predicted point. Will raise
an error if the value is not in `series`' index. Default: `'value'`
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1 for
deterministic models.
Number of times a prediction is sampled from a probabilistic model. Must be `1` for deterministic models.
verbose
Whether to print progress.
Whether to print the progress.
show_warnings
Whether to show warnings related to historical forecasts optimization, or parameters `start` and
`train_length`.
@@ -201,10 +200,9 @@ def score(
If set to 'value', `start` corresponds to the index value/label of the first predicted point. Will raise
an error if the value is not in `series`' index. Default: `'value'`
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1 for
deterministic models.
Number of times a prediction is sampled from a probabilistic model. Must be `1` for deterministic models.
verbose
Whether to print progress.
Whether to print the progress.
show_warnings
Whether to show warnings related to historical forecasts optimization, or parameters `start` and
`train_length`.
@@ -289,10 +287,9 @@ def predict_series(
If set to 'value', `start` corresponds to the index value/label of the first predicted point. Will raise
an error if the value is not in `series`' index. Default: `'value'`
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1 for
deterministic models.
Number of times a prediction is sampled from a probabilistic model. Must be `1` for deterministic models.
verbose
Whether to print progress.
Whether to print the progress.
show_warnings
Whether to show warnings related to historical forecasts optimization, or parameters `start` and
`train_length`.
@@ -385,10 +382,9 @@ def eval_metric(
If set to 'value', `start` corresponds to the index value/label of the first predicted point. Will raise
an error if the value is not in `series`' index. Default: `'value'`
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1 for
deterministic models.
Number of times a prediction is sampled from a probabilistic model. Must be `1` for deterministic models.
verbose
Whether to print progress.
Whether to print the progress.
show_warnings
Whether to show warnings related to historical forecasts optimization, or parameters `start` and
`train_length`.
@@ -491,10 +487,9 @@ def show_anomalies(
If set to 'value', `start` corresponds to the index value/label of the first predicted point. Will raise
an error if the value is not in `series`' index. Default: `'value'`
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1 for
deterministic models.
Number of times a prediction is sampled from a probabilistic model. Must be `1` for deterministic models.
verbose
Whether to print progress.
Whether to print the progress.
show_warnings
Whether to show warnings related to historical forecasts optimization, or parameters `start` and
`train_length`.
158 changes: 119 additions & 39 deletions darts/metrics/__init__.py
@@ -6,52 +6,67 @@
and quantile forecasts. For probabilistic and quantile forecasts, use parameter `q` to define the quantile(s) to
compute the deterministic metrics on:

- Aggregated over time:
Absolute metrics:
- :func:`MERR <darts.metrics.metrics.merr>`: Mean Error
- :func:`MAE <darts.metrics.metrics.mae>`: Mean Absolute Error
- :func:`MSE <darts.metrics.metrics.mse>`: Mean Squared Error
- :func:`RMSE <darts.metrics.metrics.rmse>`: Root Mean Squared Error
- :func:`RMSLE <darts.metrics.metrics.rmsle>`: Root Mean Squared Log Error

Relative metrics:
- :func:`MASE <darts.metrics.metrics.mase>`: Mean Absolute Scaled Error
- :func:`MSSE <darts.metrics.metrics.msse>`: Mean Squared Scaled Error
- :func:`RMSSE <darts.metrics.metrics.rmsse>`: Root Mean Squared Scaled Error
- :func:`MAPE <darts.metrics.metrics.mape>`: Mean Absolute Percentage Error
- :func:`sMAPE <darts.metrics.metrics.smape>`: symmetric Mean Absolute Percentage Error
- :func:`OPE <darts.metrics.metrics.ope>`: Overall Percentage Error
- :func:`MARRE <darts.metrics.metrics.marre>`: Mean Absolute Ranged Relative Error

Other metrics:
- :func:`R2 <darts.metrics.metrics.r2_score>`: Coefficient of Determination
- :func:`CV <darts.metrics.metrics.coefficient_of_variation>`: Coefficient of Variation

- Per time step:
Absolute metrics:
- :func:`ERR <darts.metrics.metrics.err>`: Error
- :func:`AE <darts.metrics.metrics.ae>`: Absolute Error
- :func:`SE <darts.metrics.metrics.se>`: Squared Error
- :func:`SLE <darts.metrics.metrics.sle>`: Squared Log Error

Relative metrics:
- :func:`ASE <darts.metrics.metrics.ase>`: Absolute Scaled Error
- :func:`SSE <darts.metrics.metrics.sse>`: Squared Scaled Error
- :func:`APE <darts.metrics.metrics.ape>`: Absolute Percentage Error
- :func:`sAPE <darts.metrics.metrics.sape>`: symmetric Absolute Percentage Error
- :func:`ARRE <darts.metrics.metrics.arre>`: Absolute Ranged Relative Error

For probabilistic forecasts (stochastic predictions with `num_samples >> 1`):
- Aggregated over time:
- Aggregated over time:
Absolute metrics:
- :func:`MERR <darts.metrics.metrics.merr>`: Mean Error
- :func:`MAE <darts.metrics.metrics.mae>`: Mean Absolute Error
- :func:`MSE <darts.metrics.metrics.mse>`: Mean Squared Error
- :func:`RMSE <darts.metrics.metrics.rmse>`: Root Mean Squared Error
- :func:`RMSLE <darts.metrics.metrics.rmsle>`: Root Mean Squared Log Error

Relative metrics:
- :func:`MASE <darts.metrics.metrics.mase>`: Mean Absolute Scaled Error
- :func:`MSSE <darts.metrics.metrics.msse>`: Mean Squared Scaled Error
- :func:`RMSSE <darts.metrics.metrics.rmsse>`: Root Mean Squared Scaled Error
- :func:`MAPE <darts.metrics.metrics.mape>`: Mean Absolute Percentage Error
- :func:`sMAPE <darts.metrics.metrics.smape>`: symmetric Mean Absolute Percentage Error
- :func:`OPE <darts.metrics.metrics.ope>`: Overall Percentage Error
- :func:`MARRE <darts.metrics.metrics.marre>`: Mean Absolute Ranged Relative Error

Other metrics:
- :func:`R2 <darts.metrics.metrics.r2_score>`: Coefficient of Determination
- :func:`CV <darts.metrics.metrics.coefficient_of_variation>`: Coefficient of Variation

- Per time step:
Absolute metrics:
- :func:`ERR <darts.metrics.metrics.err>`: Error
- :func:`AE <darts.metrics.metrics.ae>`: Absolute Error
- :func:`SE <darts.metrics.metrics.se>`: Squared Error
- :func:`SLE <darts.metrics.metrics.sle>`: Squared Log Error

Relative metrics:
- :func:`ASE <darts.metrics.metrics.ase>`: Absolute Scaled Error
- :func:`SSE <darts.metrics.metrics.sse>`: Squared Scaled Error
- :func:`APE <darts.metrics.metrics.ape>`: Absolute Percentage Error
- :func:`sAPE <darts.metrics.metrics.sape>`: symmetric Absolute Percentage Error
- :func:`ARRE <darts.metrics.metrics.arre>`: Absolute Ranged Relative Error

For probabilistic forecasts (stochastic predictions with `num_samples >> 1`) and quantile forecasts:

- Aggregated over time:
Quantile metrics:
- :func:`MQL <darts.metrics.metrics.mql>`: Mean Quantile Loss
- :func:`QR <darts.metrics.metrics.qr>`: Quantile Risk

Quantile interval metrics:
- :func:`MIW <darts.metrics.metrics.miw>`: Mean Interval Width
- Per time step:
- :func:`MIWS <darts.metrics.metrics.miws>`: Mean Interval Winkler Score
- :func:`MIC <darts.metrics.metrics.mic>`: Mean Interval Coverage
- :func:`MINCS_QR <darts.metrics.metrics.mincs_qr>`: Mean Interval Non-Conformity Score for Quantile Regression

- Per time step:
Quantile metrics:
- :func:`QL <darts.metrics.metrics.ql>`: Quantile Loss

Quantile interval metrics:
- :func:`IW <darts.metrics.metrics.iw>`: Interval Width
- :func:`IWS <darts.metrics.metrics.iws>`: Interval Winkler Score
- :func:`IC <darts.metrics.metrics.ic>`: Interval Coverage
- :func:`INCS_QR <darts.metrics.metrics.incs_qr>`: Interval Non-Conformity Score for Quantile Regression

For Dynamic Time Warping (DTW) (aggregated over time):
- :func:`DTW <darts.metrics.metrics.dtw_metric>`: Dynamic Time Warping Metric

- :func:`DTW <darts.metrics.metrics.dtw_metric>`: Dynamic Time Warping Metric
"""

from darts.metrics.metrics import (
@@ -62,13 +77,19 @@
coefficient_of_variation,
dtw_metric,
err,
ic,
incs_qr,
iw,
iws,
mae,
mape,
marre,
mase,
merr,
mic,
mincs_qr,
miw,
miws,
mql,
mse,
msse,
@@ -86,6 +107,44 @@
sse,
)

ALL_METRICS = {
ae,
ape,
arre,
ase,
coefficient_of_variation,
dtw_metric,
err,
iw,
iws,
mae,
mape,
marre,
mase,
merr,
miw,
miws,
mql,
mse,
msse,
ope,
ql,
qr,
r2_score,
rmse,
rmsle,
rmsse,
sape,
se,
sle,
smape,
sse,
ic,
mic,
incs_qr,
mincs_qr,
}

TIME_DEPENDENT_METRICS = {
ae,
ape,
@@ -98,8 +157,23 @@
sle,
sse,
iw,
iws,
ic,
incs_qr,
}

Q_INTERVAL_METRICS = {
iw,
iws,
miw,
miws,
ic,
mic,
incs_qr,
}

NON_Q_METRICS = {dtw_metric}

__all__ = [
"ae",
"ape",
@@ -130,4 +204,10 @@
"sse",
"iw",
"miw",
"iws",
"miws",
"ic",
"mic",
"incs_qr",
"mincs_qr",
]
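
Note on the quantile interval metrics added above (`iw`, `iws`, `ic`, `incs_qr` and their mean variants): the sketch below illustrates what such per-time-step quantities usually measure, using plain NumPy. It is not the code from this PR. The function names, the `alpha` parameter, and the sample arrays are hypothetical, and the non-conformity score shown is the textbook conformalized quantile regression form, which may differ from what `incs_qr` returns.

```python
import numpy as np

# Hypothetical helpers, not the darts implementation.

def interval_width(lo, hi):
    """Width of the interval [lo, hi] at each time step."""
    return hi - lo

def interval_coverage(y, lo, hi):
    """1.0 where the observation lies inside [lo, hi], else 0.0."""
    return ((y >= lo) & (y <= hi)).astype(float)

def winkler_score(y, lo, hi, alpha):
    """Interval width plus a penalty of 2/alpha times the miss distance.

    `alpha` is the nominal miscoverage, e.g. 0.2 for a 10%-90% interval.
    """
    below = np.maximum(lo - y, 0.0)  # distance below the lower bound
    above = np.maximum(y - hi, 0.0)  # distance above the upper bound
    return (hi - lo) + (2.0 / alpha) * (below + above)

def cqr_nonconformity(y, lo, hi):
    """Textbook CQR score: positive outside the interval, negative inside."""
    return np.maximum(lo - y, y - hi)

# Made-up observations and interval bounds for four time steps.
y = np.array([1.0, 2.0, 3.0, 4.0])
lo = np.array([0.5, 2.5, 2.0, 3.0])
hi = np.array([1.5, 3.5, 4.0, 3.5])

widths = interval_width(lo, hi)
coverage = interval_coverage(y, lo, hi)
winkler = winkler_score(y, lo, hi, alpha=0.2)
scores = cqr_nonconformity(y, lo, hi)

# The "mean" variants aggregate the per-time-step values over time.
mean_width = widths.mean()
mean_coverage = coverage.mean()
mean_winkler = winkler.mean()
```

Under this reading, the per-time-step functions belong in `TIME_DEPENDENT_METRICS`, while their mean variants reduce the same values to a single number per component.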