[RFC, WIP]: Remove evaluate kernel #2342

Draft: wants to merge 6 commits into main

12 changes: 0 additions & 12 deletions docs/source/kernels.rst
@@ -119,24 +119,12 @@ Composition/Decoration Kernels
.. autoclass:: MultiDeviceKernel
:members:

:hidden:`AdditiveStructureKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdditiveStructureKernel
:members:

:hidden:`ProductKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ProductKernel
:members:

:hidden:`ProductStructureKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ProductStructureKernel
:members:

:hidden:`ScaleKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~

6 changes: 0 additions & 6 deletions gpytorch/kernels/__init__.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
from . import keops
from .additive_structure_kernel import AdditiveStructureKernel
from .arc_kernel import ArcKernel
from .constant_kernel import ConstantKernel
from .cosine_kernel import CosineKernel
@@ -19,12 +18,10 @@
from .matern_kernel import MaternKernel
from .multi_device_kernel import MultiDeviceKernel
from .multitask_kernel import MultitaskKernel
from .newton_girard_additive_kernel import NewtonGirardAdditiveKernel
from .periodic_kernel import PeriodicKernel
from .piecewise_polynomial_kernel import PiecewisePolynomialKernel
from .polynomial_kernel import PolynomialKernel
from .polynomial_kernel_grad import PolynomialKernelGrad
from .product_structure_kernel import ProductStructureKernel
from .rbf_kernel import RBFKernel
from .rbf_kernel_grad import RBFKernelGrad
from .rbf_kernel_gradgrad import RBFKernelGradGrad
@@ -39,7 +36,6 @@
"Kernel",
"ArcKernel",
"AdditiveKernel",
"AdditiveStructureKernel",
"ConstantKernel",
"CylindricalKernel",
"MultiDeviceKernel",
@@ -55,13 +51,11 @@
"LinearKernel",
"MaternKernel",
"MultitaskKernel",
"NewtonGirardAdditiveKernel",
"PeriodicKernel",
"PiecewisePolynomialKernel",
"PolynomialKernel",
"PolynomialKernelGrad",
"ProductKernel",
"ProductStructureKernel",
"RBFKernel",
"RFFKernel",
"RBFKernelGrad",
73 changes: 0 additions & 73 deletions gpytorch/kernels/additive_structure_kernel.py

This file was deleted.

10 changes: 0 additions & 10 deletions gpytorch/kernels/constant_kernel.py
@@ -90,25 +90,18 @@ def forward(
x1: Tensor,
x2: Tensor,
diag: Optional[bool] = False,
last_dim_is_batch: Optional[bool] = False,
) -> Tensor:
"""Evaluates the constant kernel.

Args:
x1: First input tensor of shape (batch_shape x n1 x d).
x2: Second input tensor of shape (batch_shape x n2 x d).
diag: If True, returns the diagonal of the covariance matrix.
last_dim_is_batch: If True, the last dimension of size `d` of the input
tensors are treated as a batch dimension.

Returns:
A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
constant covariance values if diag is False, resp. True.
"""
if last_dim_is_batch:
x1 = x1.transpose(-1, -2).unsqueeze(-1)
x2 = x2.transpose(-1, -2).unsqueeze(-1)

dtype = torch.promote_types(x1.dtype, x2.dtype)
batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
@@ -117,7 +110,4 @@ def forward(
if not diag:
constant = constant.unsqueeze(-1)

if last_dim_is_batch:
constant = constant.unsqueeze(-1)

return constant.expand(shape)
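
The trimmed-down signature can be exercised directly; the following is a minimal sketch (not part of the diff), assuming the default ConstantKernel constructor, showing the shapes described in the docstring above:

import torch
from gpytorch.kernels import ConstantKernel

kernel = ConstantKernel()
x1 = torch.randn(5, 3)   # n1 x d
x2 = torch.randn(4, 3)   # n2 x d

full = kernel.forward(x1, x2)             # (5, 4) tensor filled with the kernel's constant
diag = kernel.forward(x1, x1, diag=True)  # (5,) tensor, the diagonal only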
6 changes: 4 additions & 2 deletions gpytorch/kernels/cosine_kernel.py
@@ -56,8 +56,6 @@ class CosineKernel(Kernel):
>>> covar = covar_module(x) # Output: LazyVariable of size (2 x 10 x 10)
"""

is_stationary = True

def __init__(
self,
period_length_prior: Optional[Prior] = None,
@@ -85,6 +83,10 @@ def __init__(

self.register_constraint("raw_period_length", period_length_constraint)

@property
def is_stationary(self):
return True

@property
def period_length(self):
return self.raw_period_length_constraint.transform(self.raw_period_length)
4 changes: 1 addition & 3 deletions gpytorch/kernels/cylindrical_kernel.py
@@ -4,7 +4,6 @@

import torch

from .. import settings
from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel
@@ -152,8 +151,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
else:
angular_kernel = angular_kernel + self.angular_weights[..., p, None].mul(gram_mat.pow(p))

with settings.lazily_evaluate_kernels(False):
radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
radial_kernel = self.radial_base_kernel.forward(self.kuma(r1), self.kuma(r2), diag=diag, **params)
return radial_kernel.mul(angular_kernel)

def kuma(self, x: torch.Tensor) -> torch.Tensor:
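
The cylindrical-kernel change above replaces a kernel call made under lazily_evaluate_kernels(False) with a direct .forward() call. A minimal sketch of the distinction (not part of the diff), using RBFKernel as a stand-in for the radial base kernel and describing the pre-change default behavior:

import torch
from gpytorch.kernels import RBFKernel

k = RBFKernel()
r1, r2 = torch.rand(5, 1), torch.rand(4, 1)

lazy_out = k(r1, r2)           # __call__ path: deferred, returns a lazy kernel tensor by default
eager_out = k.forward(r1, r2)  # bypasses the lazy machinery and evaluates immediately

assert torch.allclose(lazy_out.to_dense(), eager_out)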
79 changes: 40 additions & 39 deletions gpytorch/kernels/grid_interpolation_kernel.py
@@ -1,10 +1,12 @@
#!/usr/bin/env python3

from typing import List, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union

import torch
from jaxtyping import Float
from linear_operator import to_linear_operator
from linear_operator.operators import InterpolatedLinearOperator
from linear_operator.operators import InterpolatedLinearOperator, LinearOperator
from torch import Tensor

from ..models.exact_prediction_strategies import InterpolatedPredictionStrategy
from ..utils.grid import create_grid
@@ -25,14 +27,14 @@ class GridInterpolationKernel(GridKernel):
.. math::

\begin{equation*}
k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}}
k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{\boldsymbol Z, \boldsymbol Z} \mathbf{w_{x_2}}
\end{equation*}

where

* :math:`U` is the set of gridded inducing points
* :math:`\boldsymbol Z` is the set of gridded inducing points

* :math:`K_{U,U}` is the kernel matrix between the inducing points
* :math:`K_{\boldsymbol Z, \boldsymbol Z}` is the kernel matrix between the inducing points

* :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
:math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
@@ -50,20 +52,13 @@ class GridInterpolationKernel(GridKernel):
`GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
Periodic, Spectral Mixture, etc.)

Args:
base_kernel (Kernel):
The kernel to approximate with KISS-GP
grid_size (Union[int, List[int]]):
The size of the grid in each dimension.
If a single int is provided, then every dimension will have the same grid size.
num_dims (int):
The dimension of the input data. Required if `grid_bounds=None`
grid_bounds (tuple(float, float), optional):
The bounds of the grid, if known (high performance mode).
The length of the tuple must match the number of dimensions.
The entries represent the min/max values for each dimension.
active_dims (tuple of ints, optional):
Passed down to the `base_kernel`.
:param base_kernel: The kernel to approximate with KISS-GP.
:param grid_size: The size of the grid in each dimension.
If a single int is provided, then every dimension will have the same grid size.
:param num_dims: The dimension of the input data. Required if `grid_bounds=None`
:param grid_bounds: The bounds of the grid, if known (high performance mode).
The length of the tuple must match the number of dimensions.
The entries represent the min/max values for each dimension.

.. _Kernel Interpolation for Scalable Structured Gaussian Processes:
http://proceedings.mlr.press/v37/wilson15.pdf
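
For intuition, the interpolation identity in the docstring above can be checked numerically with dense stand-ins for the sparse weights. A toy sketch (not part of the diff; the hat-function weights and the 0.1 width are made up for illustration, whereas the real code stores sparse indices/values inside an InterpolatedLinearOperator):

import torch
from gpytorch.kernels import RBFKernel

grid = torch.linspace(0, 1, 25).unsqueeze(-1)   # Z: 25 gridded inducing points in 1D
base_kernel = RBFKernel()
K_zz = base_kernel(grid, grid).to_dense()       # K_ZZ: 25 x 25 inducing covariance

def toy_weights(x, grid):
    # Dense, purely illustrative interpolation weights: local hat functions, each row sums to 1.
    dist = (x - grid.squeeze(-1)).abs()          # n x 25 distances to the grid points
    w = torch.clamp(1.0 - dist / 0.1, min=0.0)
    return w / w.sum(dim=-1, keepdim=True)

x1, x2 = torch.rand(6, 1), torch.rand(4, 1)
W1, W2 = toy_weights(x1, grid), toy_weights(x2, grid)
K_approx = W1 @ K_zz @ W2.T                      # w_{x1}^T K_ZZ w_{x2} for all pairs: 6 x 4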
@@ -72,10 +67,10 @@
def __init__(
self,
base_kernel: Kernel,
grid_size: Union[int, List[int]],
grid_size: Union[int, Iterable[int]],
num_dims: Optional[int] = None,
grid_bounds: Optional[Tuple[float, float]] = None,
active_dims: Optional[Tuple[int, ...]] = None,
**kwargs,
):
has_initialized_grid = 0
grid_is_dynamic = True
@@ -116,11 +111,17 @@ def __init__(
super(GridInterpolationKernel, self).__init__(
base_kernel=base_kernel,
grid=grid,
interpolation_mode=True,
active_dims=active_dims,
**kwargs,
)
self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))

@property
def _lazily_evaluate(self) -> bool:
# GridInterpolationKernels should not lazily evaluate; there are few gains (the inducing point kernel
# matrix always needs to be evaluated; regardless of the size of x1 and x2), and the
# InterpolatedLinearOperator structure is needed for fast predictions.
return False

@property
def _tight_grid_bounds(self):
grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))
@@ -129,23 +130,26 @@ def _tight_grid_bounds(self):
for bound, spacing in zip(self.grid_bounds, grid_spacings)
)

def _compute_grid(self, inputs, last_dim_is_batch=False):
n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
if last_dim_is_batch:
inputs = inputs.transpose(-1, -2).unsqueeze(-1)
n_dimensions = 1
batch_shape = inputs.shape[:-2]

def _compute_grid(self, inputs):
*batch_shape, n_data, n_dimensions = inputs.shape
inputs = inputs.reshape(-1, n_dimensions)
interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
interp_indices = interp_indices.view(*batch_shape, n_data, -1)
interp_values = interp_values.view(*batch_shape, n_data, -1)
return interp_indices, interp_values

def _inducing_forward(self, last_dim_is_batch, **params):
return super().forward(self.grid, self.grid, last_dim_is_batch=last_dim_is_batch, **params)
def _create_or_update_full_grid(self, grid: Iterable[Tensor]):
pass

def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
def _validate_inputs(self, x: Float[Tensor, "... N D"]) -> bool:
return True

def _inducing_forward(self, **params):
return super().forward(None, None, **params)

def forward(
self, x1: Float[Tensor, "... N_1 D"], x2: Float[Tensor, "... N_2 D"], diag: bool = False, **params
) -> Float[Union[Tensor, LinearOperator], "... N_1 N_2"]:
# See if we need to update the grid or not
if self.grid_is_dynamic: # This is true if a grid_bounds wasn't passed in
if torch.equal(x1, x2):
@@ -180,16 +184,13 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
)
self.update_grid(grid)

base_lazy_tsr = to_linear_operator(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params))
if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1)

left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch)
base_lazy_tsr = to_linear_operator(self._inducing_forward(**params))
left_interp_indices, left_interp_values = self._compute_grid(x1)
if torch.equal(x1, x2):
right_interp_indices = left_interp_indices
right_interp_values = left_interp_values
else:
right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)
right_interp_indices, right_interp_values = self._compute_grid(x2)

batch_shape = torch.broadcast_shapes(
base_lazy_tsr.batch_shape,
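
A minimal construction sketch (not part of the diff) matching the parameter list documented above; note that forward now takes only x1, x2, and diag, since last_dim_is_batch is gone:

import torch
from gpytorch.kernels import GridInterpolationKernel, RBFKernel

covar_module = GridInterpolationKernel(
    base_kernel=RBFKernel(),
    grid_size=50,    # one int: the same grid size in every dimension
    num_dims=2,      # required here because grid_bounds is not supplied
)

x = torch.rand(10, 2)
covar = covar_module(x)      # a 10 x 10 LinearOperator (kept un-materialized)
dense = covar.to_dense()     # materialize only if a plain tensor is needed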