Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CRPS for the gamma distribution #40

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/crps.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ When the true forecast CDF is not fully known, but represented by a finite ensem

::: scoringrules.crps_exponentialM

::: scoringrules.crps_gamma

::: scoringrules.crps_lognormal

::: scoringrules.crps_normal
Expand Down
2 changes: 2 additions & 0 deletions scoringrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
crps_ensemble,
crps_exponential,
crps_exponentialM,
crps_gamma,
crps_logistic,
crps_lognormal,
crps_normal,
Expand Down Expand Up @@ -44,6 +45,7 @@
"crps_normal",
"crps_exponential",
"crps_exponentialM",
"crps_gamma",
"crps_lognormal",
"crps_logistic",
"crps_quantile",
Expand Down
60 changes: 60 additions & 0 deletions scoringrules/_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,66 @@ def crps_exponentialM(
return crps.exponentialM(observation, mass, location, scale, backend=backend)


def crps_gamma(
observation: "ArrayLike",
shape: "ArrayLike",
/,
rate: "ArrayLike | None" = None,
*,
scale: "ArrayLike | None" = None,
backend: "Backend" = None,
) -> "ArrayLike":
r"""Compute the closed form of the CRPS for the gamma distribution.

It is based on the following formulation from
[Scheuerer and Möller (2015)](doi: doi:10.1214/15-AOAS843):

$$ \mathrm{CRPS}(F_{\alpha, \beta}, y) = y(2F_{\alpha, \beta}(y) - 1)
- \frac{\alpha}{\beta} (2 F_{\alpha + 1, \beta}(y) - 1)
- \frac{1}{\beta B(1/2, \alpha)}. $$

where $F_{\alpha, \beta}$ is gamma distribution function with shape
parameter $\alpha > 0$ and rate parameter $\beta > 0$ (equivalently,
with scale parameter $1/\beta$).

Parameters
----------
observation:
The observed values.
shape:
Shape parameter of the forecast gamma distribution.
rate:
Rate parameter of the forecast rate distribution.
scale:
Scale parameter of the forecast scale distribution, where `scale = 1 / rate`.

Returns
-------
score:
The CRPS between obs and Gamma(shape, rate).

Examples
--------
>>> import scoringrules as sr
>>> sr.crps_gamma(0.2, 1.1, 0.1)
5.503536008961291

Raises
------
ValueError
If both `rate` and `scale` are provided, or if neither is provided.
"""
if (scale is None and rate is None) or (scale is not None and rate is not None):
raise ValueError(
"Either `rate` or `scale` must be provided, but not both or neither."
)

if rate is None:
rate = 1.0 / scale

return crps.gamma(observation, shape, rate, backend=backend)


def crps_normal(
observation: "ArrayLike",
mu: "ArrayLike",
Expand Down
4 changes: 4 additions & 0 deletions scoringrules/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def mbessel1(self, x: "Array", /) -> "Array":
def gamma(self, x: "Array", /) -> "Array":
"""Calculate the gamma function at each element ``x_i`` of the input array ``x``."""

@abc.abstractmethod
def gammainc(self, x: "Array", y: "Array", /) -> "Array":
"""Calculate the regularised lower incomplete gamma function at each element ``x_i`` of the input array ``x``."""

@abc.abstractmethod
def gammalinc(self, x: "Array", y: "Array", /) -> "Array":
"""Calculate the lower incomplete gamma function at each element ``x_i`` of the input array ``x``."""
Expand Down
3 changes: 3 additions & 0 deletions scoringrules/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def mbessel1(self, x: "Array") -> "Array":
def gamma(self, x: "Array") -> "Array":
return jsp.special.gamma(x)

def gammainc(self, x: "Array", y: "Array") -> "Array":
return jsp.special.gammainc(x, y)

def gammalinc(self, x: "Array", y: "Array") -> "Array":
return jsp.special.gammainc(x, y) * jsp.special.gamma(x)

Expand Down
3 changes: 3 additions & 0 deletions scoringrules/backend/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def mbessel1(self, x: "NDArray") -> "NDArray":
def gamma(self, x: "NDArray") -> "NDArray":
return gamma(x)

def gammainc(self, x: "NDArray", y: "NDArray") -> "NDArray":
return gammainc(x, y)

def gammalinc(self, x: "NDArray", y: "NDArray") -> "NDArray":
return gammainc(x, y) * gamma(x)

Expand Down
3 changes: 3 additions & 0 deletions scoringrules/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ def mbessel1(self, x: "Tensor") -> "Tensor":
def gamma(self, x: "Tensor") -> "Tensor":
return tf.math.exp(tf.math.lgamma(x))

def gammainc(self, x: "Tensor", y: "Tensor") -> "Tensor":
return tf.math.igamma(x, y)

def gammalinc(self, x: "Tensor", y: "Tensor") -> "Tensor":
return tf.math.igamma(x, y) * tf.math.exp(tf.math.lgamma(x))

Expand Down
8 changes: 3 additions & 5 deletions scoringrules/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ def mbessel1(self, x: "Tensor") -> "Tensor":
def gamma(self, x: "Tensor") -> "Tensor":
return torch.exp(torch.lgamma(x))

def gammainc(self, x: "Tensor", y: "Tensor") -> "Tensor":
return torch.special.gammainc(x, y)

def gammalinc(self, x: "Tensor", y: "Tensor") -> "Tensor":
return torch.special.gammainc(x, y) * torch.exp(torch.lgamma(x))

Expand All @@ -229,8 +232,3 @@ def expi(self, x: "Tensor") -> "Tensor":

def where(self, condition: "Tensor", x: "Tensor", y: "Tensor") -> "Tensor":
return torch.where(condition, x, y)


if __name__ == "__main__":
B = TorchBackend()
out = B.mean(torch.ones(10))
2 changes: 2 additions & 0 deletions scoringrules/core/crps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
binomial,
exponential,
exponentialM,
gamma,
logistic,
lognormal,
normal,
Expand All @@ -18,6 +19,7 @@
"binomial",
"exponential",
"exponentialM",
"gamma",
"logistic",
"lognormal",
"normal",
Expand Down
20 changes: 20 additions & 0 deletions scoringrules/core/crps/_closed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
_binom_cdf,
_binom_pdf,
_exp_cdf,
_gamma_cdf,
_logis_cdf,
_norm_cdf,
_norm_pdf,
Expand Down Expand Up @@ -145,6 +146,25 @@ def exponentialM(
return s


def gamma(
obs: "ArrayLike",
shape: "ArrayLike",
rate: "ArrayLike",
backend: "Backend" = None,
) -> "Array":
"""Compute the CRPS for the gamma distribution."""
B = backends.active if backend is None else backends[backend]
obs, shape, rate = map(B.asarray, (obs, shape, rate))
F_ab = _gamma_cdf(obs, shape, rate, backend=backend)
F_ab1 = _gamma_cdf(obs, shape + 1, rate, backend=backend)
s = (
obs * (2 * F_ab - 1)
- (shape / rate) * (2 * F_ab1 - 1)
- 1 / (rate * B.beta(B.asarray(0.5), shape))
)
return s


def normal(
obs: "ArrayLike", mu: "ArrayLike", sigma: "ArrayLike", backend: "Backend" = None
) -> "Array":
Expand Down
3 changes: 2 additions & 1 deletion scoringrules/core/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def _gamma_cdf(
) -> "Array":
"""Cumulative distribution function for the gamma distribution."""
B = backends.active if backend is None else backends[backend]
return B.max(B.li_gamma(shape, rate * x) / B.gamma(shape), 0)
zero = B.asarray(0.0)
return B.maximum(B.gammainc(shape, rate * B.maximum(x, zero)), zero)


def _pois_cdf(x: "ArrayLike", mean: "ArrayLike", backend: "Backend" = None) -> "Array":
Expand Down
20 changes: 20 additions & 0 deletions tests/test_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,23 @@ def test_exponentialM(backend):
res = _crps.crps_exponentialM(obs, mass, location, scale, backend=backend)
expected = 0.751013
assert np.isclose(res, expected)


@pytest.mark.parametrize("backend", BACKENDS)
def test_gamma(backend):
obs, shape, rate = 0.2, 1.1, 0.7
expected = 0.6343718

res = _crps.crps_gamma(obs, shape, rate, backend=backend)
assert np.isclose(res, expected)

res = _crps.crps_gamma(obs, shape, scale=1 / rate, backend=backend)
assert np.isclose(res, expected)

with pytest.raises(ValueError):
_crps.crps_gamma(obs, shape, rate, scale=1 / rate, backend=backend)
return

with pytest.raises(ValueError):
_crps.crps_gamma(obs, shape, backend=backend)
return
Loading