Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions pymc/gp/hsgp_approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ def prior(
self,
name: str,
X: TensorLike,
dims: str | None = None,
hsgp_coeffs_dims: str | None = None,
gp_dims: str | None = None,
*args,
**kwargs,
):
Expand All @@ -444,10 +444,11 @@ def prior(
Name of the random variable
X: array-like
Function input values.
dims: str, default None
Dimension name for the GP random variable.
hsgp_coeffs_dims: str, default None
Dimension name for the HSGP basis vectors.
gp_dims: str, default None
Dimension name for the GP random variable.

"""
phi, sqrt_psd = self.prior_linearized(X)
self._sqrt_psd = sqrt_psd
Expand All @@ -469,7 +470,7 @@ def prior(
)
f = self.mean_func(X) + phi @ self._beta

self.f = pm.Deterministic(name, f, dims=gp_dims)
self.f = pm.Deterministic(name, f, dims=dims)
return self.f

def _build_conditional(self, Xnew):
Expand Down Expand Up @@ -695,7 +696,9 @@ def prior_linearized(self, X: TensorLike):
psd = self.scale * self.cov_func.power_spectral_density_approx(J)
return (phi_cos, phi_sin), psd

def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ignore[override]
def prior( # type: ignore[override]
self, name: str, X: TensorLike, dims: str | None = None, hsgp_coeffs_dims: str | None = None
):
R"""
Return the (approximate) GP prior distribution evaluated over the input locations `X`.

Expand All @@ -709,11 +712,13 @@ def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ign
Function input values.
dims: None
Dimension name for the GP random variable.
hsgp_coeffs_dims: str | None = None
Dimension name for the HSGPPeriodic basis vectors.
"""
(phi_cos, phi_sin), psd = self.prior_linearized(X)

m = self._m
self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m * 2 - 1))
self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m * 2 - 1), dims=hsgp_coeffs_dims)
# The first eigenfunction for the sine component is zero
# and so does not contribute to the approximation.
f = (
Expand Down
Loading