Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kernel function extension #558

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b9043fd
Initial kernel change for pytorch
Srceh Jul 15, 2022
b5f48c5
Change torch kernel-based methods to support new kernel behaviours.
Srceh Jul 21, 2022
27be93b
Initial TF implementation added.
Srceh Jul 25, 2022
e05ee05
Modify generic detector class and associated tests.
Srceh Jul 27, 2022
753ba72
Fixed prediction behaviour for torch gpu with new base kernel.
Srceh Jul 31, 2022
8832bef
Fixed feature dimension selection function.
Srceh Aug 5, 2022
d53984f
Added support to passing multiple kernel parameters. Doc string refin…
Srceh Aug 8, 2022
657e7b8
(1) refine various points according to the review. (2) re-design the …
Srceh Aug 18, 2022
8322330
revert mmd_cifar10 notebook
Srceh Aug 18, 2022
0199bac
This commit includes a major re-design of the base kernel class, it n…
Srceh Sep 4, 2022
3d235fb
Refine the behaviour of the new base kernel class, added further erro…
Srceh Sep 8, 2022
4587499
Added extra treatments for different kernel class. Also refine the ty…
Srceh Sep 20, 2022
bd1dde9
Add additional tests for the new kernels, and fix notebooks with new …
Srceh Oct 14, 2022
52486da
Address reviewer comments on: (1) doc string, (2) outdated comments, …
Srceh Oct 17, 2022
d6af592
Address some discussion and comments from the reviewer, mainly on : (…
Srceh Nov 11, 2022
43b0b4c
pre-rebase minor fixes.
Srceh Dec 15, 2022
335fe2c
Initial kernel change for pytorch
Srceh Jul 15, 2022
d059f65
Change torch kernel-based methods to support new kernel behaviours.
Srceh Jul 21, 2022
a843973
Initial TF implementation added.
Srceh Jul 25, 2022
50197ab
Modify generic detector class and associated tests.
Srceh Jul 27, 2022
db824e8
Fixed prediction behaviour for torch gpu with new base kernel.
Srceh Jul 31, 2022
d8c8083
Fixed feature dimension selection function.
Srceh Aug 5, 2022
3de8f98
Added support to passing multiple kernel parameters. Doc string refin…
Srceh Aug 8, 2022
0f63f61
(1) refine various points according to the review. (2) re-design the …
Srceh Aug 18, 2022
e9e9874
revert mmd_cifar10 notebook
Srceh Aug 18, 2022
61032e0
This commit includes a major re-design of the base kernel class, it n…
Srceh Sep 4, 2022
82d9dd4
Refine the behaviour of the new base kernel class, added further erro…
Srceh Sep 8, 2022
7e64b57
Added extra treatments for different kernel class. Also refine the ty…
Srceh Sep 20, 2022
1084180
Add additional tests for the new kernels, and fix notebooks with new …
Srceh Oct 14, 2022
8af1c51
Address reviewer comments on: (1) doc string, (2) outdated comments, …
Srceh Oct 17, 2022
bdad9d3
Address some discussion and comments from the reviewer, mainly on : (…
Srceh Nov 11, 2022
bca489b
Initial rebase with the current master.
Srceh Jan 3, 2023
be7d9fa
Initial integrate with the Keops and serialisation
Srceh Feb 1, 2023
6cc1798
Add serialisation for all new kernel classes.
Srceh Feb 24, 2023
1e381f4
Add support for serialisation of composite kernels.
Srceh Mar 20, 2023
b14db5a
Move composite kernel validation functions to loading module. Fixes f…
Srceh Mar 20, 2023
b24655e
(1) add 'kernel_list' key in config dict for better management. (2) m…
Srceh Mar 22, 2023
fb922a2
Fix dimension selection for TF.
Srceh Apr 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions alibi_detect/cd/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def __init__(
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
sigma: Optional[np.ndarray] = None,
# sigma: Optional[np.ndarray] = None,
Srceh marked this conversation as resolved.
Show resolved Hide resolved
configure_kernel_from_x_ref: bool = True,
n_permutations: int = 100,
input_shape: Optional[tuple] = None,
Expand Down Expand Up @@ -502,12 +502,13 @@ def __init__(
if p_val is None:
logger.warning('No p-value set for the drift threshold. Need to set it to detect data drift.')

self.infer_sigma = configure_kernel_from_x_ref
if configure_kernel_from_x_ref and isinstance(sigma, np.ndarray):
self.infer_sigma = False
logger.warning('`sigma` is specified for the kernel and `configure_kernel_from_x_ref` '
'is set to True. `sigma` argument takes priority over '
'`configure_kernel_from_x_ref` (set to False).')
self.infer_parameter = configure_kernel_from_x_ref
# self.infer_sigma = configure_kernel_from_x_ref
# if configure_kernel_from_x_ref and isinstance(sigma, np.ndarray):
# self.infer_sigma = False
# logger.warning('`sigma` is specified for the kernel and `configure_kernel_from_x_ref` '
# 'is set to True. `sigma` argument takes priority over '
# '`configure_kernel_from_x_ref` (set to False).')

# optionally already preprocess reference data
self.p_val = p_val
Expand Down Expand Up @@ -612,7 +613,8 @@ def __init__(
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
sigma: Optional[np.ndarray] = None,
# kernel: BaseKernel = None,
Srceh marked this conversation as resolved.
Show resolved Hide resolved
# sigma: Optional[np.ndarray] = None,
n_permutations: int = 100,
n_kernel_centers: Optional[int] = None,
lambda_rd_max: float = 0.2,
Expand Down Expand Up @@ -665,7 +667,7 @@ def __init__(
self.x_ref = preprocess_fn(x_ref)
else:
self.x_ref = x_ref
self.sigma = sigma
# self.sigma = sigma
self.preprocess_x_ref = preprocess_x_ref
self.update_x_ref = update_x_ref
self.preprocess_fn = preprocess_fn
Expand Down
7 changes: 5 additions & 2 deletions alibi_detect/cd/context_aware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from typing import Callable, Dict, Optional, Union, Tuple
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
from alibi_detect.utils.pytorch.kernels import BaseKernel

if has_pytorch:
from alibi_detect.cd.pytorch.context_aware import ContextMMDDriftTorch
Expand All @@ -22,8 +23,10 @@ def __init__(
preprocess_x_ref: bool = True,
update_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
x_kernel: Callable = None,
c_kernel: Callable = None,
# x_kernel: Callable = None,
x_kernel: BaseKernel = None,
# c_kernel: Callable = None,
c_kernel: BaseKernel = None,
Srceh marked this conversation as resolved.
Show resolved Hide resolved
n_permutations: int = 1000,
prop_c_held: float = 0.25,
n_folds: int = 5,
Expand Down
3 changes: 2 additions & 1 deletion alibi_detect/cd/lsdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def __init__(
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
sigma: Optional[np.ndarray] = None,
# sigma: Optional[np.ndarray] = None,
# kernel: BaseKernel = None,
Srceh marked this conversation as resolved.
Show resolved Hide resolved
n_permutations: int = 100,
n_kernel_centers: Optional[int] = None,
lambda_rd_max: float = 0.2,
Expand Down
3 changes: 2 additions & 1 deletion alibi_detect/cd/lsdd_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(
window_size: int,
backend: str = 'tensorflow',
preprocess_fn: Optional[Callable] = None,
sigma: Optional[np.ndarray] = None,
# sigma: Optional[np.ndarray] = None,
# kernel: BaseKernel = None,
n_bootstraps: int = 1000,
n_kernel_centers: Optional[int] = None,
lambda_rd_max: float = 0.2,
Expand Down
4 changes: 2 additions & 2 deletions alibi_detect/cd/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
kernel: Callable = None,
sigma: Optional[np.ndarray] = None,
# sigma: Optional[np.ndarray] = None,
configure_kernel_from_x_ref: bool = True,
n_permutations: int = 100,
device: Optional[str] = None,
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
from alibi_detect.utils.tensorflow.kernels import GaussianRBF
else:
from alibi_detect.utils.pytorch.kernels import GaussianRBF # type: ignore
kwargs.update({'kernel': GaussianRBF})
kwargs.update({'kernel': GaussianRBF()})

if backend == 'tensorflow' and has_tensorflow:
kwargs.pop('device', None)
Expand Down
8 changes: 5 additions & 3 deletions alibi_detect/cd/mmd_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

if has_pytorch:
from alibi_detect.cd.pytorch.mmd_online import MMDDriftOnlineTorch
from alibi_detect.utils.pytorch.kernels import BaseKernel as BaseKernelTorch

if has_tensorflow:
from alibi_detect.cd.tensorflow.mmd_online import MMDDriftOnlineTF
from alibi_detect.utils.tensorflow.kernels import BaseKernel as BaseKernelTF


class MMDDriftOnline:
Expand All @@ -17,8 +19,8 @@ def __init__(
window_size: int,
backend: str = 'tensorflow',
preprocess_fn: Optional[Callable] = None,
kernel: Callable = None,
sigma: Optional[np.ndarray] = None,
kernel: Union[BaseKernelTorch, BaseKernelTF] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mauicv @ascillitoe I feel like this is confusing for typing purposes and somehow unnecessarily indicates coupling between kernel interfaces and backends. Feel like there should be a top-level class that is backend-agnostic for typing purposes, perhaps a Protocol is valid, although not if we have to rely on inheriting from BaseKernel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial idea for simplifying serialisation and perhaps the optional deps stuff was to look at backend-agnostic parent classes for kernels and even preprocessing (i.e. replace the combination of preprocess_drift and partial with a preprocessing class). These backend-agnostic classes would have get_config and from_config methods, and also do the dispatching to the relevant backend-specific classes. That being said I didn't get around to doing any sort of POC so it might be a non-starter... Doing something similar but only for typing, via a Protocol, sounds like an interesting one...

# sigma: Optional[np.ndarray] = None,
n_bootstraps: int = 1000,
device: Optional[str] = None,
verbose: bool = True,
Expand Down Expand Up @@ -82,7 +84,7 @@ def __init__(
from alibi_detect.utils.tensorflow.kernels import GaussianRBF
else:
from alibi_detect.utils.pytorch.kernels import GaussianRBF # type: ignore
kwargs.update({'kernel': GaussianRBF})
kwargs.update({'kernel': GaussianRBF()})

if backend == 'tensorflow' and has_tensorflow:
kwargs.pop('device', None)
Expand Down
60 changes: 32 additions & 28 deletions alibi_detect/cd/pytorch/context_aware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,36 @@
from typing import Callable, Dict, Optional, Tuple, Union
from alibi_detect.cd.base import BaseContextMMDDrift
from alibi_detect.utils.pytorch import get_device
from alibi_detect.utils.pytorch.kernels import GaussianRBF
from alibi_detect.utils.pytorch.kernels import BaseKernel, GaussianRBF
from alibi_detect.cd._domain_clf import _SVCDomainClf
from tqdm import tqdm

logger = logging.getLogger(__name__)


def _sigma_median_diag(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
Srceh marked this conversation as resolved.
Show resolved Hide resolved
"""
Private version of the bandwidth estimation function :py:func:`~alibi_detect.utils.pytorch.kernels.sigma_median`,
with the +n (and -1) term excluded to account for the diagonal of the kernel matrix.

Parameters
----------
x
Tensor of instances with dimension [Nx, features].
y
Tensor of instances with dimension [Ny, features].
dist
Tensor with dimensions [Nx, Ny], containing the pairwise distances between `x` and `y`.

Returns
-------
The computed bandwidth, `sigma`.
"""
n_median = np.prod(dist.shape) // 2
sigma = (.5 * dist.flatten().sort().values[n_median].unsqueeze(dim=-1)) ** .5
return sigma


class ContextMMDDriftTorch(BaseContextMMDDrift):
lams: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

Expand All @@ -22,8 +45,10 @@ def __init__(
preprocess_x_ref: bool = True,
update_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
x_kernel: Callable = GaussianRBF,
c_kernel: Callable = GaussianRBF,
# x_kernel: Callable = GaussianRBF,
x_kernel: BaseKernel = GaussianRBF(init_fn_sigma=_sigma_median_diag),
Srceh marked this conversation as resolved.
Show resolved Hide resolved
# c_kernel: Callable = GaussianRBF,
c_kernel: BaseKernel = GaussianRBF(init_fn_sigma=_sigma_median_diag),
n_permutations: int = 1000,
prop_c_held: float = 0.25,
n_folds: int = 5,
Expand Down Expand Up @@ -98,8 +123,10 @@ def __init__(
self.device = get_device(device)

# initialize kernel
self.x_kernel = x_kernel(init_sigma_fn=_sigma_median_diag) if x_kernel == GaussianRBF else x_kernel
self.c_kernel = c_kernel(init_sigma_fn=_sigma_median_diag) if c_kernel == GaussianRBF else c_kernel
# self.x_kernel = x_kernel(init_sigma_fn=_sigma_median_diag) if x_kernel == GaussianRBF else x_kernel
# self.c_kernel = c_kernel(init_sigma_fn=_sigma_median_diag) if c_kernel == GaussianRBF else c_kernel
self.x_kernel = x_kernel
self.c_kernel = c_kernel

# Initialize classifier (hardcoded for now)
self.clf = _SVCDomainClf(self.c_kernel)
Expand Down Expand Up @@ -244,26 +271,3 @@ def _pick_lam(self, lams: torch.Tensor, K: torch.Tensor, L: torch.Tensor, n_fold
kxx = torch.ones_like(lWk).to(lWk.device) * torch.max(K)
losses += (lWKWl + kxx - 2*lWk).sum(-1)
return lams[torch.argmin(losses)]


def _sigma_median_diag(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
"""
Private version of the bandwidth estimation function :py:func:`~alibi_detect.utils.pytorch.kernels.sigma_median`,
with the +n (and -1) term excluded to account for the diagonal of the kernel matrix.

Parameters
----------
x
Tensor of instances with dimension [Nx, features].
y
Tensor of instances with dimension [Ny, features].
dist
Tensor with dimensions [Nx, Ny], containing the pairwise distances between `x` and `y`.

Returns
-------
The computed bandwidth, `sigma`.
"""
n_median = np.prod(dist.shape) // 2
sigma = (.5 * dist.flatten().sort().values[n_median].unsqueeze(dim=-1)) ** .5
return sigma
25 changes: 14 additions & 11 deletions alibi_detect/cd/pytorch/lsdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
sigma: Optional[np.ndarray] = None,
# sigma: Optional[np.ndarray] = None,
# kernel: BaseKernel = GaussianRBF(),
n_permutations: int = 100,
n_kernel_centers: Optional[int] = None,
lambda_rd_max: float = 0.2,
Expand Down Expand Up @@ -67,7 +68,8 @@ def __init__(
preprocess_x_ref=preprocess_x_ref,
update_x_ref=update_x_ref,
preprocess_fn=preprocess_fn,
sigma=sigma,
# sigma=sigma,
# kernel=kernel,
n_permutations=n_permutations,
n_kernel_centers=n_kernel_centers,
lambda_rd_max=lambda_rd_max,
Expand All @@ -83,24 +85,25 @@ def __init__(
# in the method signature, so we can't cast it to torch.Tensor unless we change the signature
# to also accept torch.Tensor. We also can't redefine it's type as that would involve enabling
# --allow-redefinitions in mypy settings (which we might do eventually).
self.kernel = GaussianRBF()
if self.preprocess_x_ref or self.preprocess_fn is None:
x_ref = torch.as_tensor(self.x_ref).to(self.device) # type: ignore[assignment]
self._configure_normalization(x_ref) # type: ignore[arg-type]
x_ref = self._normalize(x_ref)
self._initialize_kernel(x_ref) # type: ignore[arg-type]
# self._initialize_kernel(x_ref) # type: ignore[arg-type]
self._configure_kernel_centers(x_ref) # type: ignore[arg-type]
self.x_ref = x_ref.cpu().numpy() # type: ignore[union-attr]
# For stability in high dimensions we don't divide H by (pi*sigma^2)^(d/2)
# Results in an alternative test-stat of LSDD*(pi*sigma^2)^(d/2). Same p-vals etc.
self.H = GaussianRBF(np.sqrt(2.) * self.kernel.sigma)(self.kernel_centers, self.kernel_centers)

def _initialize_kernel(self, x_ref: torch.Tensor):
if self.sigma is None:
self.kernel = GaussianRBF()
_ = self.kernel(x_ref, x_ref, infer_sigma=True)
else:
sigma = torch.from_numpy(self.sigma)
self.kernel = GaussianRBF(sigma)
# def _initialize_kernel(self, x_ref: torch.Tensor):
# if self.sigma is None:
# self.kernel = GaussianRBF()
# _ = self.kernel(x_ref, x_ref, infer_sigma=True)
# else:
# sigma = torch.from_numpy(self.sigma)
# self.kernel = GaussianRBF(sigma)

def _configure_normalization(self, x_ref: torch.Tensor, eps: float = 1e-12):
x_ref_means = x_ref.mean(0)
Expand Down Expand Up @@ -140,7 +143,7 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]:
if self.preprocess_fn is not None and self.preprocess_x_ref is False:
self._configure_normalization(x_ref) # type: ignore[arg-type]
x_ref = self._normalize(x_ref)
self._initialize_kernel(x_ref) # type: ignore[arg-type]
# self._initialize_kernel(x_ref) # type: ignore[arg-type]
self._configure_kernel_centers(x_ref) # type: ignore[arg-type]
self.H = GaussianRBF(np.sqrt(2.) * self.kernel.sigma)(self.kernel_centers, self.kernel_centers)

Expand Down
23 changes: 13 additions & 10 deletions alibi_detect/cd/pytorch/lsdd_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Any, Callable, Optional, Union
from alibi_detect.cd.base_online import BaseMultiDriftOnline
from alibi_detect.utils.pytorch import get_device
from alibi_detect.utils.pytorch import GaussianRBF, permed_lsdds, quantile
from alibi_detect.utils.pytorch import permed_lsdds, quantile
from alibi_detect.utils.pytorch.kernels import GaussianRBF


class LSDDDriftOnlineTorch(BaseMultiDriftOnline):
Expand All @@ -14,7 +15,8 @@ def __init__(
ert: float,
window_size: int,
preprocess_fn: Optional[Callable] = None,
sigma: Optional[np.ndarray] = None,
# sigma: Optional[np.ndarray] = None,
# kernel: BaseKernel = GaussianRBF(),
n_bootstraps: int = 1000,
n_kernel_centers: Optional[int] = None,
lambda_rd_max: float = 0.2,
Expand Down Expand Up @@ -86,14 +88,15 @@ def __init__(
self._configure_normalization()

# initialize kernel
if sigma is None:
x_ref = torch.from_numpy(self.x_ref).to(self.device) # type: ignore[assignment]
self.kernel = GaussianRBF()
_ = self.kernel(x_ref, x_ref, infer_sigma=True)
else:
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment]
np.ndarray) else None
self.kernel = GaussianRBF(sigma) # type: ignore[arg-type]
# if sigma is None:
# x_ref = torch.from_numpy(self.x_ref).to(self.device) # type: ignore[assignment]
# self.kernel = GaussianRBF()
# _ = self.kernel(x_ref, x_ref, infer_sigma=True)
# else:
# sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment]
# np.ndarray) else None
# self.kernel = GaussianRBF(sigma) # type: ignore[arg-type]
self.kernel = GaussianRBF()

if self.n_kernel_centers is None:
self.n_kernel_centers = 2 * window_size
Expand Down
27 changes: 15 additions & 12 deletions alibi_detect/cd/pytorch/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from alibi_detect.cd.base import BaseMMDDrift
from alibi_detect.utils.pytorch import get_device
from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix
from alibi_detect.utils.pytorch.kernels import GaussianRBF
from alibi_detect.utils.pytorch.kernels import BaseKernel, GaussianRBF

logger = logging.getLogger(__name__)

Expand All @@ -18,8 +18,9 @@ def __init__(
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
kernel: Callable = GaussianRBF,
sigma: Optional[np.ndarray] = None,
# kernel: Callable = GaussianRBF,
kernel: BaseKernel = GaussianRBF(),
# sigma: Optional[np.ndarray] = None,
configure_kernel_from_x_ref: bool = True,
n_permutations: int = 100,
device: Optional[str] = None,
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(
preprocess_x_ref=preprocess_x_ref,
update_x_ref=update_x_ref,
preprocess_fn=preprocess_fn,
sigma=sigma,
# sigma=sigma,
configure_kernel_from_x_ref=configure_kernel_from_x_ref,
n_permutations=n_permutations,
input_shape=input_shape,
Expand All @@ -78,21 +79,23 @@ def __init__(
self.device = get_device(device)

# initialize kernel
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment]
np.ndarray) else None
self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel
# sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment]
# np.ndarray) else None
# self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel
self.kernel = kernel

# compute kernel matrix for the reference data
if self.infer_sigma or isinstance(sigma, torch.Tensor):
# if self.infer_sigma or isinstance(sigma, torch.Tensor):
if self.infer_parameter:
x = torch.from_numpy(self.x_ref).to(self.device)
self.k_xx = self.kernel(x, x, infer_sigma=self.infer_sigma)
self.infer_sigma = False
self.k_xx = self.kernel(x, x, infer_parameter=self.infer_parameter)
self.infer_parameter = False
else:
self.k_xx, self.infer_sigma = None, True
self.k_xx, self.infer_parameter = None, True

def kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
""" Compute and return full kernel matrix between arrays x and y. """
k_xy = self.kernel(x, y, self.infer_sigma)
k_xy = self.kernel(x, y, self.infer_parameter)
k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x)
k_yy = self.kernel(y, y)
kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
Expand Down
Loading