Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add closed-form CRPS for exponential distribution #27

Merged
merged 2 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/api/crps.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ When the true forecast CDF is not fully known, but represented by a finite ensem

<h2>Analytical formulations</h2>

::: scoringrules.crps_normal
::: scoringrules.crps_exponential

::: scoringrules.crps_lognormal

::: scoringrules.crps_normal

<h2>Ensemble-based estimators</h2>

Expand Down
2 changes: 2 additions & 0 deletions scoringrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from scoringrules._brier import brier_score
from scoringrules._crps import (
crps_ensemble,
crps_exponential,
crps_logistic,
crps_lognormal,
crps_normal,
Expand Down Expand Up @@ -35,6 +36,7 @@
"backends",
"crps_ensemble",
"crps_normal",
"crps_exponential",
"crps_lognormal",
"crps_logistic",
"crps_quantile",
Expand Down
48 changes: 45 additions & 3 deletions scoringrules/_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,48 @@ def vrcrps_ensemble(
)


def crps_exponential(
observation: "ArrayLike",
rate: "ArrayLike",
/,
*,
backend: "Backend" = None,
) -> "ArrayLike":
r"""Compute the closed form of the CRPS for the exponential distribution.

It is based on the following formulation from
[Jordan et al. (2019)](https://www.jstatsoft.org/article/view/v090i12):

$$\mathrm{CRPS}(F_{\lambda}, y) = |y| - \frac{2F_{\lambda}(y)}{\lambda} + \frac{1}{2 \lambda},$$

where $F_{\lambda}$ is exponential distribution function with rate parameter $\lambda > 0$.

Parameters
----------
observation:
The observed values.
rate:
Rate parameter of the forecast exponential distribution.

Returns
-------
score:
The CRPS between Exp(rate) and obs.

Examples
--------
```pycon
>>> import scoringrules as sr
>>> import numpy as np
>>> sr.crps_exponential(0.8, 3.0)
0.360478635526275
>>> sr.crps_exponential(np.array([0.8, 0.9]), np.array([3.0, 2.0]))
array([0.36047864, 0.24071795])
```
"""
return crps.exponential(observation, rate, backend=backend)


def crps_normal(
observation: "ArrayLike",
mu: "ArrayLike",
Expand Down Expand Up @@ -398,11 +440,11 @@ def crps_lognormal(

Parameters
----------
observations: ArrayLike
observation:
The observed values.
mulog: ArrayLike
mulog:
Mean of the normal underlying distribution.
sigmalog: ArrayLike
sigmalog:
Standard deviation of the underlying normal distribution.

Returns
Expand Down
6 changes: 6 additions & 0 deletions scoringrules/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ def mean(
) -> "Array":
"""Calculate the arithmetic mean of the input array ``x``."""

@abc.abstractmethod
def max(
self, x: "Array", axis: int | tuple[int, ...] | None, keepdims: bool = False
) -> "Array":
"""Return the maximum value of an input array ``x``."""

@abc.abstractmethod
def moveaxis(
self,
Expand Down
5 changes: 5 additions & 0 deletions scoringrules/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def mean(
) -> "Array":
return jnp.mean(x, axis=axis, keepdims=keepdims)

def max(
self, x: "Array", axis: int | tuple[int, ...] | None, keepdims: bool = False
) -> "Array":
return jnp.max(x, axis=axis, keepdims=keepdims)

def moveaxis(
self,
x: "Array",
Expand Down
5 changes: 5 additions & 0 deletions scoringrules/backend/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def mean(
) -> "NDArray":
return np.mean(x, axis=axis, keepdims=keepdims)

def max(
self, x: "NDArray", axis: int | tuple[int, ...] | None, keepdims: bool = False
) -> "NDArray":
return np.max(x, axis=axis, keepdims=keepdims)

def moveaxis(
self,
x: "NDArray",
Expand Down
5 changes: 5 additions & 0 deletions scoringrules/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def mean(
) -> "Tensor":
return tf.math.reduce_mean(x, axis=axis, keepdims=keepdims)

def max(
self, x: "Tensor", axis: int | tuple[int, ...] | None, keepdims: bool = False
) -> "Tensor":
return tf.math.reduce_max(x, axis=axis, keepdims=keepdims)

def moveaxis(
self,
x: "Tensor",
Expand Down
5 changes: 5 additions & 0 deletions scoringrules/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def mean(
) -> "Tensor":
return torch.mean(x, axis=axis, keepdim=keepdims)

def max(
self, x: "Tensor", axis: int | tuple[int, ...] | None, keepdims: bool = False
) -> "Tensor":
return torch.max(x, axis=axis, keepdim=keepdims)

def moveaxis(
self,
x: "Tensor",
Expand Down
7 changes: 4 additions & 3 deletions scoringrules/core/crps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from ._approx import ensemble, ow_ensemble, quantile_pinball, vr_ensemble
from ._closed import logistic, lognormal, normal
from ._closed import exponential, logistic, lognormal, normal
from ._gufuncs import estimator_gufuncs, quantile_pinball_gufunc

__all__ = [
"ensemble",
"ow_ensemble",
"vr_ensemble",
"normal",
"lognormal",
"exponential",
"logistic",
"lognormal",
"normal",
"estimator_gufuncs",
"quantile_pinball",
"quantile_pinball_gufunc",
Expand Down
12 changes: 11 additions & 1 deletion scoringrules/core/crps/_closed.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import typing as tp

from scoringrules.backend import backends
from scoringrules.core.stats import _logis_cdf, _norm_cdf, _norm_pdf
from scoringrules.core.stats import _exp_cdf, _logis_cdf, _norm_cdf, _norm_pdf

if tp.TYPE_CHECKING:
from scoringrules.core.typing import Array, ArrayLike, Backend


def exponential(
obs: "ArrayLike", rate: "ArrayLike", backend: "Backend" = None
) -> "Array":
"""Compute the CRPS for the exponential distribution."""
B = backends.active if backend is None else backends[backend]
rate, obs = map(B.asarray, (rate, obs))
s = B.abs(obs) - (2 * _exp_cdf(obs, rate, backend=backend) / rate) + 1 / (2 * rate)
return s


def normal(
obs: "ArrayLike", mu: "ArrayLike", sigma: "ArrayLike", backend: "Backend" = None
) -> "Array":
Expand Down
Loading