Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
gpleiss committed Jun 2, 2023
1 parent 6cfea77 commit 963c84c
Show file tree
Hide file tree
Showing 20 changed files with 263 additions and 53 deletions.
18 changes: 11 additions & 7 deletions gpytorch/kernels/additive_structure_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ class AdditiveStructureKernel(Kernel):
Passed down to the `base_kernel`.
"""

@property
def is_stationary(self) -> bool:
"""
Kernel is stationary if the base kernel is stationary.
"""
return self.base_kernel.is_stationary

def __init__(
self,
base_kernel: Kernel,
Expand All @@ -51,6 +44,17 @@ def __init__(
self.base_kernel = base_kernel
self.num_dims = num_dims

@property
def _lazily_evaluate(self) -> bool:
return self.base_kernel._lazily_evaluate

@property
def is_stationary(self) -> bool:
"""
Kernel is stationary if the base kernel is stationary.
"""
return self.base_kernel.is_stationary

def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
if last_dim_is_batch:
raise RuntimeError("AdditiveStructureKernel does not accept the last_dim_is_batch argument.")
Expand Down
6 changes: 4 additions & 2 deletions gpytorch/kernels/cosine_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class CosineKernel(Kernel):
>>> covar = covar_module(x) # Output: LazyVariable of size (2 x 10 x 10)
"""

is_stationary = True

def __init__(
self,
period_length_prior: Optional[Prior] = None,
Expand Down Expand Up @@ -85,6 +83,10 @@ def __init__(

self.register_constraint("raw_period_length", period_length_constraint)

@property
def is_stationary(self):
return True

@property
def period_length(self):
return self.raw_period_length_constraint.transform(self.raw_period_length)
Expand Down
4 changes: 1 addition & 3 deletions gpytorch/kernels/cylindrical_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import torch

from .. import settings
from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel
Expand Down Expand Up @@ -152,8 +151,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
else:
angular_kernel = angular_kernel + self.angular_weights[..., p, None].mul(gram_mat.pow(p))

with settings.lazily_evaluate_kernels(False):
radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
radial_kernel = self.radial_base_kernel.forward(self.kuma(r1), self.kuma(r2), diag=diag, **params)
return radial_kernel.mul(angular_kernel)

def kuma(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
7 changes: 7 additions & 0 deletions gpytorch/kernels/grid_interpolation_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ def __init__(
)
self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))

@property
def _lazily_evaluate(self) -> bool:
# GridInterpolationKernels should not lazily evaluate; there are few gains (the inducing point kernel
# matrix always needs to be evaluated; regardless of the size of x1 and x2), and the
# InterpolatedLinearOperator structure is needed for fast predictions.
return False

@property
def _tight_grid_bounds(self):
grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))
Expand Down
11 changes: 9 additions & 2 deletions gpytorch/kernels/grid_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ class GridKernel(Kernel):
http://www.cs.cmu.edu/~andrewgw/manet.pdf
"""

is_stationary = True

def __init__(
self,
base_kernel: Kernel,
Expand All @@ -66,6 +64,15 @@ def __init__(
if not self.interpolation_mode:
self.register_buffer("full_grid", create_data_from_grid(grid))

@property
def _lazily_evaluate(self) -> bool:
# Toeplitz structure is very efficient; no need to lazily evaluate
return False

@property
def is_stationary(self) -> bool:
return True

def _clear_cache(self):
if hasattr(self, "_cached_kernel_mat"):
del self._cached_kernel_mat
Expand Down
6 changes: 6 additions & 0 deletions gpytorch/kernels/index_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def __init__(

self.register_constraint("raw_var", var_constraint)

@property
def _lazily_evaluate(self) -> bool:
# IndexKernel does not need lazy evaluation, since the complete BB^T + D_v` is always
# computed regardless of x1 and x2
return False

@property
def var(self):
return self.raw_var_constraint.transform(self.raw_var)
Expand Down
6 changes: 6 additions & 0 deletions gpytorch/kernels/inducing_point_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def _clear_cache(self):
if hasattr(self, "_cached_kernel_inv_root"):
del self._cached_kernel_inv_root

@property
def _lazily_evaluate(self) -> bool:
# InducingPointKernels kernels should not lazily evaluate; to use the Woodbury formula,
# we want the Kernel to return a LowRankLinearOperator, not a KernelLinaerOperator.
return False

@property
def _inducing_mat(self):
if not self.training and hasattr(self, "_cached_kernel_mat"):
Expand Down
4 changes: 2 additions & 2 deletions gpytorch/kernels/keops/matern_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math

import torch
from linear_operator.operators import KeOpsLinearOperator
from linear_operator.operators import KernelLinearOperator

from ... import settings
from .keops_kernel import KeOpsKernel
Expand Down Expand Up @@ -92,7 +92,7 @@ def forward(self, x1, x2, diag=False, **params):
return self.covar_func(x1_, x2_, diag=True)

covar_func = lambda x1, x2, diag=False: self.covar_func(x1, x2, diag)
return KeOpsLinearOperator(x1_, x2_, covar_func)
return KernelLinearOperator(x1_, x2_, covar_func=covar_func)

except ImportError:

Expand Down
4 changes: 2 additions & 2 deletions gpytorch/kernels/keops/rbf_kernel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

import torch
from linear_operator.operators import KeOpsLinearOperator
from linear_operator.operators import KernelLinearOperator

from ... import settings
from ..rbf_kernel import postprocess_rbf
Expand Down Expand Up @@ -54,7 +54,7 @@ def forward(self, x1, x2, diag=False, **params):
if diag:
return covar_func(x1_, x2_, diag=True)

return KeOpsLinearOperator(x1_, x2_, covar_func)
return KernelLinearOperator(x1_, x2_, covar_func=covar_func)

except ImportError:

Expand Down
Loading

0 comments on commit 963c84c

Please sign in to comment.