Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support vector lengthscales for RBF and Matern kernels #1819

Merged
merged 6 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions numpyro/contrib/hsgp/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@

from __future__ import annotations

from typing import Union, get_args
from typing import get_args

from jaxlib.xla_extension import ArrayImpl
import numpy as np

import jax.numpy as jnp

ARRAY_TYPE = Union[ArrayImpl, np.ndarray]
from numpyro.contrib.hsgp.util import ARRAY_TYPE


def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
Expand Down
22 changes: 14 additions & 8 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues


def align_param(dim, param):
return jnp.broadcast_to(param, jnp.broadcast_shapes(jnp.shape(param), (dim,)))


def spectral_density_squared_exponential(
dim: int, w: ArrayImpl, alpha: float, length: float
dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
"""
Spectral density of the squared exponential kernel.
Expand All @@ -44,13 +48,14 @@ def spectral_density_squared_exponential(
:return: spectral density value
:rtype: float
"""
c = alpha * (jnp.sqrt(2 * jnp.pi) * length) ** dim
e = jnp.exp(-0.5 * (length**2) * jnp.dot(w, w))
length = align_param(dim, length)
c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length, axis=-1)
e = jnp.exp(-0.5 * jnp.sum(w**2 * length**2, axis=-1))
return c * e


def spectral_density_matern(
dim: int, nu: float, w: ArrayImpl, alpha: float, length: float
dim: int, nu: float, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
"""
Spectral density of the Matérn kernel.
Expand Down Expand Up @@ -79,22 +84,23 @@ def spectral_density_matern(
:return: spectral density value
:rtype: float
""" # noqa: E501
length = align_param(dim, length)
c1 = (
alpha
* (2 ** (dim))
* (jnp.pi ** (dim / 2))
* ((2 * nu) ** nu)
* special.gamma(nu + dim / 2)
)
c2 = (2 * nu / (length**2) + jnp.dot(w, w)) ** (-nu - dim / 2)
c3 = special.gamma(nu) * length ** (2 * nu)
s = jnp.sum(length**2 * w**2, axis=-1)
c2 = jnp.prod(length, axis=-1) * (2 * nu + s) ** (-nu - dim / 2)
c3 = special.gamma(nu)
return c1 * c2 / c3


# TODO support length-D kernel hyperparameters
def diag_spectral_density_squared_exponential(
alpha: float,
length: float,
length: float | list[float],
ell: float | int | list[float | int],
m: int | list[int],
dim: int,
Expand Down
10 changes: 10 additions & 0 deletions numpyro/contrib/hsgp/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import numpy as np

import jax

ARRAY_TYPE = Union[jax.Array, np.ndarray] # jax.Array covers tracers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this type can be used in other NumPyro modules.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using ArrayImpl is only an issue when a model gets compiled and the arrays turn into tracers. isinstance(X, jax.Array) will work for both jax arrays and tracers.

Copy link
Contributor

@brendancooley brendancooley Jun 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some details here https://jax.readthedocs.io/en/latest/jax_array_migration.html

I believe this is best practice for typing jax arrays (as of last year), but I am not sure

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am definitively not an expert in type hints, so following the recommendation from the docs seems the safest path :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I mark this thread as resolved, as this seems to be in line with the recommendation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works for me!

115 changes: 94 additions & 21 deletions test/contrib/hsgp/test_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from functools import reduce
from operator import mul
from typing import Literal
from typing import Literal, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -74,23 +74,54 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]:


@pytest.mark.parametrize(
argnames="x1, x2, length, ell",
argnames="x1, x2, length, ell, xfail",
argvalues=[
(np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0),
(np.array([[1.0]]), np.array([[0.0]]), 1.0, 5.0, False),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array([1.0]),
1.0,
5.0,
False,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array([1.0, 0.5]),
5.0,
False,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
False,
),
(
np.array([[1.5, 1.25, 1.0]]),
np.array([[0.0, 0.0, 0.0]]),
np.array([[1.0, 0.5], [0.5, 1.0]]), # invalid length scale
5.0,
True,
),
],
ids=[
"1d",
"2d,1d-length",
"1d,scalar-length",
"2d,scalar-length",
"2d,vector-length",
"2d,matrix-length",
"2d,invalid-length",
],
)
def test_kernel_approx_squared_exponential(
x1: ArrayImpl, x2: ArrayImpl, length: ArrayImpl, ell: float
x1: ArrayImpl,
x2: ArrayImpl,
length: Union[float, ArrayImpl],
ell: float,
xfail: bool,
):
"""ensure that the approximation of the squared exponential kernel is accurate,
matching the exact kernel implementation from sklearn.
Expand All @@ -100,13 +131,26 @@ def test_kernel_approx_squared_exponential(
assert x1.shape == x2.shape
m = 100 # large enough to ensure the approximation is accurate
dim = x1.shape[-1]
if xfail:
with pytest.raises(ValueError):
diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)
return
spd = diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)

eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
approx = (eig_f1 * eig_f2) @ spd
exact = RBF(length)(x1, x2)
assert jnp.isclose(approx, exact, rtol=1e-3)

def _exact_rbf(length):
return RBF(length)(x1, x2).squeeze(axis=-1)

if isinstance(length, int) | isinstance(length, float):
exact = _exact_rbf(length)
elif length.ndim == 1:
exact = _exact_rbf(length)
else:
exact = np.apply_along_axis(_exact_rbf, axis=0, arr=length)
assert jnp.isclose(approx, exact, rtol=1e-3).all()


@pytest.mark.parametrize(
Expand All @@ -118,14 +162,32 @@ def test_kernel_approx_squared_exponential(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
3 / 2,
np.array([1.0]),
np.array([0.25, 0.5]),
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
5 / 2,
np.array([0.25, 0.5]),
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
3 / 2,
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
5 / 2,
np.array([1.0]),
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
),
],
Expand All @@ -134,6 +196,8 @@ def test_kernel_approx_squared_exponential(
"1d,nu=5/2",
"2d,nu=3/2,1d-length",
"2d,nu=5/2,1d-length",
"2d,nu=3/2,2d-length",
"2d,nu=5/2,2d-length",
],
)
def test_kernel_approx_squared_matern(
Expand All @@ -154,8 +218,17 @@ def test_kernel_approx_squared_matern(
eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
approx = (eig_f1 * eig_f2) @ spd
exact = Matern(length_scale=length, nu=nu)(x1, x2)
assert jnp.isclose(approx, exact, rtol=1e-3)

def _exact_matern(length):
return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1)

if isinstance(length, float) | isinstance(length, int):
exact = _exact_matern(length)
elif length.ndim == 1:
exact = _exact_matern(length)
else:
exact = np.apply_along_axis(_exact_matern, axis=0, arr=length)
assert jnp.isclose(approx, exact, rtol=1e-3).all()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -211,8 +284,8 @@ def test_approximation_squared_exponential(
x: ArrayImpl,
alpha: float,
length: float,
ell: int | float | list[int | float],
m: int | list[int],
ell: Union[int, float, list[Union[int, float]]],
m: Union[int, list[int]],
non_centered: bool,
):
def model(x, alpha, length, ell, m, non_centered):
Expand Down Expand Up @@ -263,8 +336,8 @@ def test_approximation_matern(
nu: float,
alpha: float,
length: float,
ell: int | float | list[int | float],
m: int | list[int],
ell: Union[int, float, list[Union[int, float]]],
m: Union[int, list[int]],
non_centered: bool,
):
def model(x, nu, alpha, length, ell, m, non_centered):
Expand Down Expand Up @@ -306,8 +379,8 @@ def model(x, nu, alpha, length, ell, m, non_centered):
def test_squared_exponential_gp_model(
synthetic_one_dim_data,
synthetic_two_dim_data,
ell: float | int | list[float | int],
m: int | list[int],
ell: Union[float, int, list[Union[float, int]]],
m: Union[int, list[int]],
non_centered: bool,
num_dim: Literal[1, 2],
):
Expand Down Expand Up @@ -364,8 +437,8 @@ def test_matern_gp_model(
synthetic_one_dim_data,
synthetic_two_dim_data,
nu: float,
ell: int | float | list[float | int],
m: int | list[int],
ell: Union[int, float, list[Union[float, int]]],
m: Union[int, list[int]],
non_centered: bool,
num_dim: Literal[1, 2],
):
Expand Down
Loading