From 03c70f347e99f35c13b6f0f00e58f9bf05d93ad6 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Wed, 12 Jun 2024 11:34:02 -0500 Subject: [PATCH 1/6] Support vector lengthscales for RBF and Matern kernels --- numpyro/contrib/hsgp/laplacian.py | 5 +- numpyro/contrib/hsgp/spectral_densities.py | 22 +++-- numpyro/contrib/hsgp/util.py | 10 +++ test/contrib/hsgp/test_approximation.py | 93 +++++++++++++++++++--- 4 files changed, 107 insertions(+), 23 deletions(-) create mode 100644 numpyro/contrib/hsgp/util.py diff --git a/numpyro/contrib/hsgp/laplacian.py b/numpyro/contrib/hsgp/laplacian.py index 1ff4fa089..16c1d6c31 100644 --- a/numpyro/contrib/hsgp/laplacian.py +++ b/numpyro/contrib/hsgp/laplacian.py @@ -7,14 +7,13 @@ from __future__ import annotations -from typing import Union, get_args +from typing import get_args from jaxlib.xla_extension import ArrayImpl -import numpy as np import jax.numpy as jnp -ARRAY_TYPE = Union[ArrayImpl, np.ndarray] +from numpyro.contrib.hsgp.util import ARRAY_TYPE def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl: diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index 0d4d3db3d..d1787a722 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -16,8 +16,12 @@ from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues +def align_param(dim, param): + return jnp.broadcast_arrays(param, jnp.zeros(dim))[0] + + def spectral_density_squared_exponential( - dim: int, w: ArrayImpl, alpha: float, length: float + dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl ) -> float: """ Spectral density of the squared exponential kernel. @@ -44,13 +48,14 @@ def spectral_density_squared_exponential( :return: spectral density value :rtype: float """ - c = alpha * (jnp.sqrt(2 * jnp.pi) * length) ** dim - e = jnp.exp(-0.5 * (length**2) * jnp.dot(w, w)) + length = align_param(dim, length) + c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length, axis=-1) + e = jnp.exp(-0.5 * jnp.sum(w**2 * length**2, axis=-1)) return c * e def spectral_density_matern( - dim: int, nu: float, w: ArrayImpl, alpha: float, length: float + dim: int, nu: float, w: ArrayImpl, alpha: float, length: float | ArrayImpl ) -> float: """ Spectral density of the Matérn kernel. @@ -79,6 +84,7 @@ def spectral_density_matern( :return: spectral density value :rtype: float """ # noqa: E501 + length = align_param(dim, length) c1 = ( alpha * (2 ** (dim)) @@ -86,15 +92,15 @@ def spectral_density_matern( * ((2 * nu) ** nu) * special.gamma(nu + dim / 2) ) - c2 = (2 * nu / (length**2) + jnp.dot(w, w)) ** (-nu - dim / 2) - c3 = special.gamma(nu) * length ** (2 * nu) + s = jnp.sum(length**2 * w**2, axis=-1) + c2 = jnp.prod(length, axis=-1) * (2 * nu + s) ** (-nu - dim / 2) + c3 = special.gamma(nu) return c1 * c2 / c3 -# TODO support length-D kernel hyperparameters def diag_spectral_density_squared_exponential( alpha: float, - length: float, + length: float | list[float], ell: float | int | list[float | int], m: int | list[int], dim: int, diff --git a/numpyro/contrib/hsgp/util.py b/numpyro/contrib/hsgp/util.py new file mode 100644 index 000000000..5afbcfb83 --- /dev/null +++ b/numpyro/contrib/hsgp/util.py @@ -0,0 +1,10 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import numpy as np + +import jax + +ARRAY_TYPE = Union[jax.Array, np.ndarray] # jax.Array covers tracers diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index a941652b1..dedf80938 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -74,23 +74,50 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]: @pytest.mark.parametrize( - argnames="x1, x2, length, ell", + argnames="x1, x2, length, ell, xfail", argvalues=[ - (np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0), + (np.array([[1.0]]), np.array([[0.0]]), 1.0, 5.0, False), ( np.array([[1.5, 1.25]]), np.array([[0.0, 0.0]]), - np.array([1.0]), + 1.0, + 5.0, + False, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + np.array([1.0, 0.5]), + 5.0, + False, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + np.array( + [[1.0, 0.5], [0.5, 1.0]] + ), # different length scale for each point/dimension 5.0, + False, + ), + ( + np.array([[1.5, 1.25, 1.0]]), + np.array([[0.0, 0.0, 0.0]]), + np.array([[1.0, 0.5], [0.5, 1.0]]), # invalid length scale + 5.0, + True, ), ], ids=[ - "1d", - "2d,1d-length", + "1d,scalar-length", + "2d,scalar-length", + "2d,vector-length", + "2d,matrix-length", + "2d,invalid-length", ], ) def test_kernel_approx_squared_exponential( - x1: ArrayImpl, x2: ArrayImpl, length: ArrayImpl, ell: float + x1: ArrayImpl, x2: ArrayImpl, length: float | ArrayImpl, ell: float, xfail: bool ): """ensure that the approximation of the squared exponential kernel is accurate, matching the exact kernel implementation from sklearn. @@ -100,13 +127,26 @@ def test_kernel_approx_squared_exponential( assert x1.shape == x2.shape m = 100 # large enough to ensure the approximation is accurate dim = x1.shape[-1] + if xfail: + with pytest.raises(ValueError): + diag_spectral_density_squared_exponential(1.0, length, ell, m, dim) + return spd = diag_spectral_density_squared_exponential(1.0, length, ell, m, dim) eig_f1 = eigenfunctions(x1, ell=ell, m=m) eig_f2 = eigenfunctions(x2, ell=ell, m=m) approx = (eig_f1 * eig_f2) @ spd - exact = RBF(length)(x1, x2) - assert jnp.isclose(approx, exact, rtol=1e-3) + + def _exact_rbf(length): + return RBF(length)(x1, x2).squeeze(axis=-1) + + if isinstance(length, float | int): + exact = _exact_rbf(length) + elif length.ndim == 1: + exact = _exact_rbf(length) + else: + exact = np.apply_along_axis(_exact_rbf, axis=0, arr=length) + assert jnp.isclose(approx, exact, rtol=1e-3).all() @pytest.mark.parametrize( @@ -118,14 +158,32 @@ def test_kernel_approx_squared_exponential( np.array([[1.5, 1.25]]), np.array([[0.0, 0.0]]), 3 / 2, - np.array([1.0]), + np.array([0.25, 0.5]), 5.0, ), ( np.array([[1.5, 1.25]]), np.array([[0.0, 0.0]]), 5 / 2, - np.array([1.0]), + np.array([0.25, 0.5]), + 5.0, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + 3 / 2, + np.array( + [[1.0, 0.5], [0.5, 1.0]] + ), # different length scale for each point/dimension + 5.0, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + 5 / 2, + np.array( + [[1.0, 0.5], [0.5, 1.0]] + ), # different length scale for each point/dimension 5.0, ), ], @@ -134,6 +192,8 @@ def test_kernel_approx_squared_exponential( "1d,nu=5/2", "2d,nu=3/2,1d-length", "2d,nu=5/2,1d-length", + "2d,nu=3/2,2d-length", + "2d,nu=5/2,2d-length", ], ) def test_kernel_approx_squared_matern( @@ -154,8 +214,17 @@ def test_kernel_approx_squared_matern( eig_f1 = eigenfunctions(x1, ell=ell, m=m) eig_f2 = eigenfunctions(x2, ell=ell, m=m) approx = (eig_f1 * eig_f2) @ spd - exact = Matern(length_scale=length, nu=nu)(x1, x2) - assert jnp.isclose(approx, exact, rtol=1e-3) + + def _exact_matern(length): + return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1) + + if isinstance(length, float | int): + exact = _exact_matern(length) + elif length.ndim == 1: + exact = _exact_matern(length) + else: + exact = np.apply_along_axis(_exact_matern, axis=0, arr=length) + assert jnp.isclose(approx, exact, rtol=1e-3).all() @pytest.mark.parametrize( From f0cbaa90f7e1d445d36f90fa8a6b3317a97a2a8d Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Mon, 24 Jun 2024 08:56:09 -0500 Subject: [PATCH 2/6] Use broadcast_shapes to align params --- numpyro/contrib/hsgp/spectral_densities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index d1787a722..4762d5340 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -17,7 +17,7 @@ def align_param(dim, param): - return jnp.broadcast_arrays(param, jnp.zeros(dim))[0] + return jnp.broadcast_to(param, jnp.broadcast_shapes(jnp.shape(param), (dim,))) def spectral_density_squared_exponential( From 1d16f4e8234d9dd668d1bb527cbe62a79d406f94 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Mon, 24 Jun 2024 10:04:51 -0500 Subject: [PATCH 3/6] Remove union shorthand --- test/contrib/hsgp/test_approximation.py | 28 ++++++++++++++----------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index dedf80938..424a423ee 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -5,7 +5,7 @@ from functools import reduce from operator import mul -from typing import Literal +from typing import Literal, Union import numpy as np import pytest @@ -117,7 +117,11 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]: ], ) def test_kernel_approx_squared_exponential( - x1: ArrayImpl, x2: ArrayImpl, length: float | ArrayImpl, ell: float, xfail: bool + x1: ArrayImpl, + x2: ArrayImpl, + length: Union[float, ArrayImpl], + ell: float, + xfail: bool, ): """ensure that the approximation of the squared exponential kernel is accurate, matching the exact kernel implementation from sklearn. @@ -140,7 +144,7 @@ def test_kernel_approx_squared_exponential( def _exact_rbf(length): return RBF(length)(x1, x2).squeeze(axis=-1) - if isinstance(length, float | int): + if isinstance(length, Union[float, int]): exact = _exact_rbf(length) elif length.ndim == 1: exact = _exact_rbf(length) @@ -218,7 +222,7 @@ def test_kernel_approx_squared_matern( def _exact_matern(length): return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1) - if isinstance(length, float | int): + if isinstance(length, Union[float, int]): exact = _exact_matern(length) elif length.ndim == 1: exact = _exact_matern(length) @@ -280,8 +284,8 @@ def test_approximation_squared_exponential( x: ArrayImpl, alpha: float, length: float, - ell: int | float | list[int | float], - m: int | list[int], + ell: Union[int, float, list[Union[int, float]]], + m: Union[int, list[int]], non_centered: bool, ): def model(x, alpha, length, ell, m, non_centered): @@ -332,8 +336,8 @@ def test_approximation_matern( nu: float, alpha: float, length: float, - ell: int | float | list[int | float], - m: int | list[int], + ell: Union[int, float, list[Union[int, float]]], + m: Union[int, list[int]], non_centered: bool, ): def model(x, nu, alpha, length, ell, m, non_centered): @@ -375,8 +379,8 @@ def model(x, nu, alpha, length, ell, m, non_centered): def test_squared_exponential_gp_model( synthetic_one_dim_data, synthetic_two_dim_data, - ell: float | int | list[float | int], - m: int | list[int], + ell: Union[float, int, list[Union[float, int]]], + m: Union[int, list[int]], non_centered: bool, num_dim: Literal[1, 2], ): @@ -433,8 +437,8 @@ def test_matern_gp_model( synthetic_one_dim_data, synthetic_two_dim_data, nu: float, - ell: int | float | list[float | int], - m: int | list[int], + ell: Union[int, float, list[Union[float, int]]], + m: Union[int, list[int]], non_centered: bool, num_dim: Literal[1, 2], ): From 6982850b20afb190b9d22e4bbf5de32a7a7e9945 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria <61420+samanklesaria@users.noreply.github.com> Date: Tue, 25 Jun 2024 10:32:48 -0500 Subject: [PATCH 4/6] Update test/contrib/hsgp/test_approximation.py Co-authored-by: Juan Orduz --- test/contrib/hsgp/test_approximation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index 424a423ee..825659756 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -144,7 +144,7 @@ def test_kernel_approx_squared_exponential( def _exact_rbf(length): return RBF(length)(x1, x2).squeeze(axis=-1) - if isinstance(length, Union[float, int]): + if isinstance(length, int) | isinstance(length, float): exact = _exact_rbf(length) elif length.ndim == 1: exact = _exact_rbf(length) From 82c76a9eaae37d1e4f5f736e5173529f391d95d2 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 25 Jun 2024 11:13:33 -0500 Subject: [PATCH 5/6] Run make format --- test/contrib/hsgp/test_approximation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index 825659756..ffc1a173b 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -144,7 +144,7 @@ def test_kernel_approx_squared_exponential( def _exact_rbf(length): return RBF(length)(x1, x2).squeeze(axis=-1) - if isinstance(length, int) | isinstance(length, float): + if isinstance(length, int) | isinstance(length, float): exact = _exact_rbf(length) elif length.ndim == 1: exact = _exact_rbf(length) From 74c35931df35a93ef3790e383e6f1d20a34536a6 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 25 Jun 2024 12:01:49 -0500 Subject: [PATCH 6/6] Remove union in isinstance check --- test/contrib/hsgp/test_approximation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index ffc1a173b..79ec1dd88 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -222,7 +222,7 @@ def test_kernel_approx_squared_matern( def _exact_matern(length): return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1) - if isinstance(length, Union[float, int]): + if isinstance(length, float) | isinstance(length, int): exact = _exact_matern(length) elif length.ndim == 1: exact = _exact_matern(length)