Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/more kernels v2 electric boogaloo #742

Merged
merged 11 commits into from
Aug 23, 2024
1 change: 1 addition & 0 deletions .cspell/custom_misc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ kernelised
kernelized
KSD
linewidth
Matérn
ml.p3.8xlarge
ndmin
parsable
Expand Down
2 changes: 2 additions & 0 deletions .cspell/library_terms.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ mathbb
mathbf
mathrm
matplotlib
maxval
meshgrid
mimread
mimsave
minval
modindex
myst
nabla
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added supervised coreset construction algorithm in `coreax.solvers.GreedyKernelPoints`
- Added `coreax.kernels.PowerKernel` to replace repeated calls of `coreax.kernels.ProductKernel`
within the `**` magic method of `coreax.kernel.ScalarValuedKernel`
- Added scalar-valued kernel functions `coreax.kernels.PoissonKernel` and `coreax.kernels.MaternKernel`


### Fixed
Expand Down
4 changes: 4 additions & 0 deletions coreax/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
LaplacianKernel,
LinearKernel,
LocallyPeriodicKernel,
MaternKernel,
PCIMQKernel,
PeriodicKernel,
PoissonKernel,
PolynomialKernel,
RationalQuadraticKernel,
SquaredExponentialKernel,
Expand All @@ -54,4 +56,6 @@
"DuoCompositeKernel",
"UniCompositeKernel",
"PowerKernel",
"PoissonKernel",
"MaternKernel",
]
143 changes: 142 additions & 1 deletion coreax/kernels/scalar_valued.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import equinox as eqx
import jax.numpy as jnp
from jax import Array
from jax import Array, vmap
from jax.scipy.special import factorial
from jax.typing import ArrayLike
from typing_extensions import override

Expand Down Expand Up @@ -181,6 +182,146 @@ def divergence_x_grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
return scale * k * (d - scale * squared_distance(x, y))


class PoissonKernel(ScalarValuedKernel):
r"""
pc532627 marked this conversation as resolved.
Show resolved Hide resolved
Define a Poisson kernel.

Given :math:`r=` ``index``, :math:`0 < r < 1`, and :math:`\rho =` ``output_scale``,
the Poisson kernel is defined as
:math:`k: [0, 2\pi) \times [0, 2\pi) \to \mathbb{R}`,
:math:`k(x, y) = \frac{\rho}{1 - 2r\cos(x-y) + r^2}`.

.. warning::
Unlike many other kernels in Coreax, the Poisson kernel is not defined on
arbitrary :math:`\mathbb{R}^d`, but instead a subset of the positive real line
:math:`[0, 2\pi)`. We do not check that inputs to methods in this class lie in
the correct domain, therefore unexpected behaviour may occur. For example,
passing :math:`n`-vectors to the `compute` method will be interpreted as one
observation of a `:math:`n`- dimensional vector, and not :math:`n` observations
of a one dimensional vector, and therefore would be an invalid use of this
kernel function.

:param index: Kernel parameter indexing the family of Poisson kernel functions
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
"""

index: float = eqx.field(default=0.5, converter=float)
output_scale: float = eqx.field(default=1.0, converter=float)

def __check_init__(self):
"""Check attributes are valid."""
if self.index <= 0 or self.index >= 1:
raise ValueError("'index' must be be between 0 and 1 exclusive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")

@override
def compute_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
return self.output_scale / (
1
- 2 * self.index * jnp.cos(jnp.linalg.norm(jnp.subtract(x, y)))
+ self.index**2
)

@override
def grad_x_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
return -self.grad_y_elementwise(x, y)

@override
def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
# Note that we do not take a norm here in order to maintain the dimensionality
# of the vectors x and y, this ensures calls to 'grad_y' and 'grad_x' have
# expected dimensionality.
distance = jnp.subtract(x, y)
return (2 * self.output_scale * self.index * jnp.sin(distance)) / (
1 - 2 * self.index * jnp.cos(distance) + self.index**2
) ** 2

@override
def divergence_x_grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
distance = jnp.linalg.norm(jnp.subtract(x, y))
div = 1 - 2 * self.index * jnp.cos(distance) + self.index**2
first_term = (2 * self.output_scale * self.index * jnp.cos(distance)) / div**2
second_term = (
8 * self.output_scale * self.index**2 * jnp.sin(distance) ** 2
) / div**3
return first_term - second_term


class MaternKernel(ScalarValuedKernel):
r"""
Define Matérn kernel with smoothness parameter a multiple of :math:`\frac{1}{2}`.

Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
Matérn kernel with smoothness parameter :math:`\nu` set to be a multiple of
:math:`\frac{1}{2}`, i.e. :math:`\nu = p + \frac{1}{2}` where
:math:`p`=` ``degree`` `:`math:`\in\mathbb{N}`, is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,

.. math::
k(x, y) = \rho^2 * \exp\left(-\frac{\sqrt{2p+1}||x-y||}{\lambda}\right)
\frac{p!}{(2p)!}\sum_{i=0}^p\frac{(p+i)!}{i!(p-i)!}
\left(2\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)^{p-i}

where :math:`||\cdot||` is the usual :math:`L_2`-norm.

:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
:param degree: Kernel degree, :math:`p`, must be a non-negative integer
"""

length_scale: float = eqx.field(default=1.0, converter=float)
output_scale: float = eqx.field(default=1.0, converter=float)
degree: int = 1

def __check_init__(self):
"""Check attributes are valid."""
if self.length_scale <= 0:
raise ValueError("'length_scale' must be positive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
if not isinstance(self.degree, int) or self.degree < 0:
raise ValueError("'degree' must be a non-negative integer")

def _compute_summation_term(self, body: float, iteration: ArrayLike) -> Array:
r"""
Compute the summation term of the Matérn kernel for a given iteration.

Given :math:`p`=``degree``:math:`\in\mathbb{N}`, compute

.. math::
\sum_{i=0}^p\frac{(p+i)!}{i!(p-i)!}
\left(2\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)^{p-i}.

:param body: Float representing
:math:`\left(\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)`
:param iteration: Current iteration
"""
factorial_term = factorial(self.degree + iteration) / (
factorial(iteration) * factorial(self.degree - iteration)
)
distance_term = (2 * body) ** (self.degree - iteration)
return factorial_term * distance_term

@override
def compute_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
norm = jnp.linalg.norm(jnp.subtract(x, y))
body = (jnp.sqrt(2 * self.degree + 1) * norm) / self.length_scale
factor = (
self.output_scale**2
* jnp.exp(-body)
* factorial(self.degree)
/ factorial(2 * self.degree)
)

summation = 1.0
if self.degree > 0:
pc532627 marked this conversation as resolved.
Show resolved Hide resolved
mapped_function = vmap(self._compute_summation_term, in_axes=(None, 0))
summation = mapped_function(body, jnp.arange(self.degree + 1)).sum()
return factor * summation


class ExponentialKernel(ScalarValuedKernel):
r"""
Define an exponential kernel.
Expand Down
Loading
Loading