Skip to content

Commit 2a3e420

Browse files
committed
Add test for covariance eigenvalue equivalence
1 parent 99d4616 commit 2a3e420

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

pymc/gp/cov.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -651,8 +651,12 @@ def power_spectral_density(self, omega: TensorLike) -> TensorVariable:
651651

652652
z = pt.sqrt(2 * alpha) * pt.sqrt(pt.dot(pt.square(omega), pt.square(ls)))
653653
coeff = 2.0 * pt.power(2.0 * np.pi * alpha, D / 2.0) * pt.prod(ls) / pt.gamma(alpha)
654-
term_z = pt.power(z / 2.0, nu) * pt.kv(nu, z)
655-
654+
655+
# Handle singularity at z=0
656+
safe_z = pt.switch(pt.eq(z, 0), 1.0, z)
657+
term_z = pt.power(safe_z / 2.0, nu) * pt.kv(nu, safe_z)
658+
term_z = pt.switch(pt.eq(z, 0), pt.gamma(nu) / 2.0, term_z)
659+
656660
return coeff * term_z
657661

658662

tests/gp/test_cov.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytensor
1818
import pytensor.tensor as pt
1919
import pytest
20+
import scipy.linalg
2021

2122
from scipy.special import gamma, iv, kv
2223

@@ -553,6 +554,36 @@ def test_psd(self):
553554
)
554555
npt.assert_allclose(true_1d_psd, test_1d_psd, atol=1e-5)
555556

557+
def test_psd_eigenvalues(self):
558+
# Test PSD implementation using Szegő’s Theorem
559+
alpha = 1.5
560+
ls = 0.5
561+
N = 500
562+
L = 50.0
563+
dx = L / N
564+
X = np.linspace(0, L, N)[:, None]
565+
566+
with pm.Model():
567+
cov = pm.gp.cov.RatQuad(1, alpha=alpha, ls=ls)
568+
569+
K = cov(X).eval()
570+
571+
evals = scipy.linalg.eigvalsh(K)
572+
evals = np.sort(evals)[::-1]
573+
574+
freqs = np.fft.fftfreq(N, d=dx)
575+
omegas = 2 * np.pi * freqs
576+
577+
psd = cov.power_spectral_density(omegas[:, None]).eval()
578+
579+
psd_scaled = psd.flatten() / dx
580+
psd_sorted = np.sort(psd_scaled)[::-1]
581+
582+
rel_err = np.abs((evals - psd_sorted) / psd_sorted)
583+
med_rel_err = np.median(rel_err)
584+
585+
assert med_rel_err < 0.1
586+
556587

557588
class TestExponential:
558589
def test_1d(self):

0 commit comments

Comments
 (0)