Skip to content

Commit

Permalink
Add crps and log score for mixture of normal distributions (#73)
Browse files Browse the repository at this point in the history
* fix lint issues in crps for mixture of normals

* add log score for mixture of normals
  • Loading branch information
sallen12 authored Oct 3, 2024
1 parent 97ebb7f commit 49ee8de
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/api/crps.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ also be viewed as the Brier score integrated over all real-valued thresholds.

::: scoringrules.crps_lognormal

::: scoringrules.crps_mixnorm

::: scoringrules.crps_negbinom

::: scoringrules.crps_normal
Expand Down
2 changes: 2 additions & 0 deletions docs/api/logarithmic.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

::: scoringrules.logs_lognormal

::: scoringrules.logs_mixnorm

::: scoringrules.logs_negbinom

::: scoringrules.logs_normal
Expand Down
4 changes: 4 additions & 0 deletions scoringrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
crps_loglaplace,
crps_loglogistic,
crps_lognormal,
crps_mixnorm,
crps_negbinom,
crps_normal,
crps_2pnormal,
Expand Down Expand Up @@ -58,6 +59,7 @@
logs_logistic,
logs_loglogistic,
logs_lognormal,
logs_mixnorm,
logs_negbinom,
logs_normal,
logs_2pnormal,
Expand Down Expand Up @@ -106,6 +108,7 @@
"crps_loglaplace",
"crps_loglogistic",
"crps_lognormal",
"crps_mixnorm",
"crps_negbinom",
"crps_normal",
"crps_2pnormal",
Expand All @@ -128,6 +131,7 @@
"logs_logistic",
"logs_loglogistic",
"logs_lognormal",
"logs_mixnorm",
"logs_negbinom",
"logs_normal",
"logs_2pnormal",
Expand Down
63 changes: 63 additions & 0 deletions scoringrules/_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,68 @@ def crps_lognormal(
return crps.lognormal(observation, mulog, sigmalog, backend=backend)


def crps_mixnorm(
observation: "ArrayLike",
m: "ArrayLike",
s: "ArrayLike",
/,
w: "ArrayLike" = None,
axis: "ArrayLike" = -1,
*,
backend: "Backend" = None,
) -> "ArrayLike":
r"""Compute the closed form of the CRPS for a mixture of normal distributions.
It is based on the following formulation from
[Grimit et al. (2006)](https://doi.org/10.1256/qj.05.235):
$$ \mathrm{CRPS}(F, y) = \sum_{i=1}^{M} w_{i} A(y - \mu_{i}, \sigma_{i}^{2}) - \frac{1}{2} \sum_{i=1}^{M} \sum_{j=1}^{M} w_{i} w_{j} A(\mu_{i} - \mu_{j}, \sigma_{i}^{2} + \sigma_{j}^{2}), $$
where $F(x) = \sum_{i=1}^{M} w_{i} \Phi \left( \frac{x - \mu_{i}}{\sigma_{i}} \right)$,
and $A(\mu, \sigma^{2}) = \mu (2 \Phi(\frac{\mu}{\sigma}) - 1) + 2\sigma \phi(\frac{\mu}{\sigma}).$
Parameters
----------
observation: ArrayLike
The observed values.
m: ArrayLike
Means of the component normal distributions.
s: ArrayLike
Standard deviations of the component normal distributions.
w: ArrayLike
Non-negative weights assigned to each component.
axis: int
The axis corresponding to the mixture components. Default is the last axis.
backend:
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.
Returns
-------
score:
The CRPS between MixNormal(m, s) and obs.
Examples
--------
>>> import scoringrules as sr
>>> sr.crps_mixnormal(0.0, [0.1, -0.3, 1.0], [0.4, 2.1, 0.7], [0.1, 0.2, 0.7])
"""
B = backends.active if backend is None else backends[backend]
observation, m, s = map(B.asarray, (observation, m, s))

if w is None:
M: int = m.shape[axis]
w = B.zeros(m.shape) + 1 / M
else:
w = B.asarray(w)

if axis != -1:
m = B.moveaxis(m, axis, -1)
s = B.moveaxis(s, axis, -1)
w = B.moveaxis(w, axis, -1)

return crps.mixnorm(observation, m, s, w, backend=backend)


def crps_negbinom(
observation: "ArrayLike",
n: "ArrayLike",
Expand Down Expand Up @@ -1853,6 +1915,7 @@ def crps_uniform(
"crps_loglaplace",
"crps_loglogistic",
"crps_lognormal",
"crps_mixnorm",
"crps_negbinom",
"crps_normal",
"crps_2pnormal",
Expand Down
59 changes: 59 additions & 0 deletions scoringrules/_logs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as tp

from scoringrules.backend import backends
from scoringrules.core import logarithmic

if tp.TYPE_CHECKING:
Expand Down Expand Up @@ -523,6 +524,62 @@ def logs_lognormal(
return logarithmic.lognormal(observation, mulog, sigmalog, backend=backend)


def logs_mixnorm(
observation: "ArrayLike",
m: "ArrayLike",
s: "ArrayLike",
/,
w: "ArrayLike" = None,
axis: "ArrayLike" = -1,
*,
backend: "Backend" = None,
) -> "ArrayLike":
r"""Compute the logarithmic score for a mixture of normal distributions.
This score is equivalent to the negative log likelihood of the normal mixture distribution
Parameters
----------
observation: ArrayLike
The observed values.
m: ArrayLike
Means of the component normal distributions.
s: ArrayLike
Standard deviations of the component normal distributions.
w: ArrayLike
Non-negative weights assigned to each component.
axis: int
The axis corresponding to the mixture components. Default is the last axis.
backend:
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.
Returns
-------
score:
The LS between MixNormal(m, s) and obs.
Examples
--------
>>> import scoringrules as sr
>>> sr.logs_mixnormal(0.0, [0.1, -0.3, 1.0], [0.4, 2.1, 0.7], [0.1, 0.2, 0.7])
"""
B = backends.active if backend is None else backends[backend]
observation, m, s = map(B.asarray, (observation, m, s))

if w is None:
M: int = m.shape[axis]
w = B.zeros(m.shape) + 1 / M
else:
w = B.asarray(w)

if axis != -1:
m = B.moveaxis(m, axis, -1)
s = B.moveaxis(s, axis, -1)
w = B.moveaxis(w, axis, -1)

return logarithmic.mixnorm(observation, m, s, w, backend=backend)


def logs_negbinom(
observation: "ArrayLike",
n: "ArrayLike",
Expand Down Expand Up @@ -899,6 +956,8 @@ def logs_uniform(
"logs_logistic",
"logs_loglogistic",
"logs_lognormal",
"logs_mixnorm",
"logs_negbinom",
"logs_normal",
"logs_2pnormal",
"logs_poisson",
Expand Down
2 changes: 2 additions & 0 deletions scoringrules/core/crps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
loglaplace,
loglogistic,
lognormal,
mixnorm,
negbinom,
normal,
poisson,
Expand Down Expand Up @@ -51,6 +52,7 @@
"loglaplace",
"loglogistic",
"lognormal",
"mixnorm",
"negbinom",
"normal",
"poisson",
Expand Down
25 changes: 25 additions & 0 deletions scoringrules/core/crps/_closed.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,31 @@ def lognormal(
)


def mixnorm(
obs: "ArrayLike",
m: "ArrayLike",
s: "ArrayLike",
w: "ArrayLike",
backend: "Backend" = None,
) -> "Array":
"""Compute the CRPS for a mixture of normal distributions."""
B = backends.active if backend is None else backends[backend]
m, s, w, obs = map(B.asarray, (m, s, w, obs))

m_y = obs[..., None] - m
m_X = m[..., None] - m[..., None, :]
s_X = B.sqrt(s[..., None] ** 2 + s[..., None, :] ** 2)
w_X = w[..., None] * w[..., None, :]

A_y = m_y * (2 * _norm_cdf(m_y / s) - 1) + 2 * s * _norm_pdf(m_y / s)
A_X = m_X * (2 * _norm_cdf(m_X / s_X) - 1) + 2 * s_X * _norm_pdf(m_X / s_X)

sc_1 = B.sum(w * A_y, axis=-1)
sc_2 = B.sum(w_X * A_X, axis=(-1, -2))

return sc_1 - 0.5 * sc_2


def negbinom(
obs: "ArrayLike",
n: "ArrayLike",
Expand Down
17 changes: 17 additions & 0 deletions scoringrules/core/logarithmic.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,23 @@ def lognormal(
return s


def mixnorm(
obs: "ArrayLike",
m: "ArrayLike",
s: "ArrayLike",
w: "ArrayLike",
backend: "Backend" = None,
) -> "Array":
"""Compute the logarithmic score for a mixture of normal distributions."""
B = backends.active if backend is None else backends[backend]
m, s, w, obs = map(B.asarray, (m, s, w, obs))

z = (obs[..., None] - m) / s
prob = _norm_pdf(z, backend=backend) / s
prob = B.sum(w * prob, axis=-1)
return -B.log(prob)


def negbinom(
obs: "ArrayLike",
n: "ArrayLike",
Expand Down
26 changes: 26 additions & 0 deletions tests/test_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,32 @@ def test_lognormal(backend):
assert not np.any(res - 0.0 > 0.0001)


@pytest.mark.parametrize("backend", BACKENDS)
def test_mixnorm(backend):
obs, m, s, w = 0.3, [0.0, -2.9, 0.9], [0.5, 1.4, 0.7], [1 / 3, 1 / 3, 1 / 3]
res = _crps.crps_mixnorm(obs, m, s, w, backend=backend)
expected = 0.4510451
assert np.isclose(res, expected)

res0 = _crps.crps_mixnorm(obs, m, s, backend=backend)
assert np.isclose(res, res0)

w = [0.3, 0.1, 0.6]
res = _crps.crps_mixnorm(obs, m, s, w, backend=backend)
expected = 0.2354619
assert np.isclose(res, expected)

obs = [-1.6, 0.3]
m = [[0.0, -2.9], [0.6, 0.0], [-1.1, -2.3]]
s = [[0.5, 1.7], [1.1, 0.7], [1.4, 1.5]]
res1 = _crps.crps_mixnorm(obs, m, s, axis=0, backend=backend)

m = [[0.0, 0.6, -1.1], [-2.9, 0.0, -2.3]]
s = [[0.5, 1.1, 1.4], [1.7, 0.7, 1.5]]
res2 = _crps.crps_mixnorm(obs, m, s, backend=backend)
assert np.allclose(res1, res2)


@pytest.mark.parametrize("backend", BACKENDS)
def test_negbinom(backend):
if backend in ["jax", "torch", "tensorflow"]:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,32 @@ def test_loglaplace(backend):
assert np.isclose(res, expected)


@pytest.mark.parametrize("backend", BACKENDS)
def test_mixnorm(backend):
obs, m, s, w = 0.3, [0.0, -2.9, 0.9], [0.5, 1.4, 0.7], [1 / 3, 1 / 3, 1 / 3]
res = _logs.logs_mixnorm(obs, m, s, w, backend=backend)
expected = 1.019742
assert np.isclose(res, expected)

res0 = _logs.logs_mixnorm(obs, m, s, backend=backend)
assert np.isclose(res, res0)

w = [0.3, 0.1, 0.6]
res = _logs.logs_mixnorm(obs, m, s, w, backend=backend)
expected = 0.8235977
assert np.isclose(res, expected)

obs = [-1.6, 0.3]
m = [[0.0, -2.9], [0.6, 0.0], [-1.1, -2.3]]
s = [[0.5, 1.7], [1.1, 0.7], [1.4, 1.5]]
res1 = _logs.logs_mixnorm(obs, m, s, axis=0, backend=backend)

m = [[0.0, 0.6, -1.1], [-2.9, 0.0, -2.3]]
s = [[0.5, 1.1, 1.4], [1.7, 0.7, 1.5]]
res2 = _logs.logs_mixnorm(obs, m, s, backend=backend)
assert np.allclose(res1, res2)


@pytest.mark.parametrize("backend", BACKENDS)
def test_negbinom(backend):
if backend in ["jax", "torch"]:
Expand Down

0 comments on commit 49ee8de

Please sign in to comment.