Skip to content

Commit e03f7f2

Browse files
effigiesjhlegarreta
authored andcommitted
type: Annotate Kernel.diag() argument X
Annotate `Kernel.diag()` argument `X`: use `npt.ArrayLike`. Fixes: ``` src/nifreeze/model/gpr.py:335: error: Argument 1 of "diag" is incompatible with supertype "Kernel"; supertype defines the argument type as "Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes]" [override] src/nifreeze/model/gpr.py:335: note: This violates the Liskov substitution principle src/nifreeze/model/gpr.py:335: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides src/nifreeze/model/gpr.py:445: error: Argument 1 of "diag" is incompatible with supertype "Kernel"; supertype defines the argument type as "Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes]" [override] src/nifreeze/model/gpr.py:445: note: This violates the Liskov substitution principle src/nifreeze/model/gpr.py:335: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides ``` raised for example in: https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:93 Documentation: https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
1 parent 26ec410 commit e03f7f2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/nifreeze/model/gpr.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from typing import Callable, ClassVar, Literal, Mapping, Optional, Sequence, Union
2929

3030
import numpy as np
31+
import numpy.typing as npt
3132
from scipy import optimize
3233
from scipy.optimize import Bounds
3334
from sklearn.gaussian_process import GaussianProcessRegressor
@@ -334,7 +335,7 @@ def __call__(
334335

335336
return self.beta_l * C_theta, K_gradient
336337

337-
def diag(self, X: np.ndarray) -> np.ndarray:
338+
def diag(self, X: npt.ArrayLike) -> np.ndarray:
338339
"""Returns the diagonal of the kernel k(X, X).
339340
340341
The result of this method is identical to np.diag(self(X)); however,
@@ -351,7 +352,7 @@ def diag(self, X: np.ndarray) -> np.ndarray:
351352
K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
352353
Diagonal of kernel k(X, X)
353354
"""
354-
return self.beta_l * np.ones(X.shape[0])
355+
return self.beta_l * np.ones(np.asanyarray(X).shape[0])
355356

356357
def is_stationary(self) -> bool:
357358
"""Returns whether the kernel is stationary."""
@@ -444,7 +445,7 @@ def __call__(
444445

445446
return self.beta_l * C_theta, K_gradient
446447

447-
def diag(self, X: np.ndarray) -> np.ndarray:
448+
def diag(self, X: npt.ArrayLike) -> np.ndarray:
448449
"""Returns the diagonal of the kernel k(X, X).
449450
450451
The result of this method is identical to np.diag(self(X)); however,
@@ -461,7 +462,7 @@ def diag(self, X: np.ndarray) -> np.ndarray:
461462
K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
462463
Diagonal of kernel k(X, X)
463464
"""
464-
return self.beta_l * np.ones(X.shape[0])
465+
return self.beta_l * np.ones(np.asanyarray(X).shape[0])
465466

466467
def is_stationary(self) -> bool:
467468
"""Returns whether the kernel is stationary."""

0 commit comments

Comments
 (0)