Skip to content

Commit

Permalink
add test for varied chromatic index fourier bases
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremy-baier committed Sep 26, 2024
1 parent 1adb01e commit 4136dc2
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/test_gp_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from enterprise.signals import gp_signals
from enterprise.signals import gp_priors
from enterprise.signals import gp_bases
from enterprise.signals import signal_base
from enterprise.signals import white_signals
from enterprise.signals import selections
import scipy.stats


Expand Down Expand Up @@ -294,3 +297,49 @@ def test_powerlaw_genmodes_prior(self):
# test shape
msg = "F matrix shape incorrect"
assert rnm.get_basis(params).shape == F.shape, msg

def test_chromatic_fourier_basis_varied_idx(self):
"""Test the set up of variable index chromatic bases and make sure that the caching is the same as no caching"""
idx = parameter.Uniform(2.5, 7)
uncached_basis = gp_bases.createfourierdesignmatrix_chromatic(self.psr.toas,
self.psr.freqs,
nmodes=100,
idx=idx)
fmat_red, Ffreqs, fref_over_radio_freqs = gp_bases.construct_chromatic_cached_parts(self.psr.toas,
self.psr.freqs,
nmodes=100)
cached_basis = gp_bases.createfourierdesignmatrix_chromatic_with_additional_caching(fmat_red=fmat_red,
Ffreqs=Ffreqs,
fref_over_radio_freqs=fref_over_radio_freqs,
idx=idx)
pr = gp_priors.powerlaw(log10_A=parameter.Uniform(-18, -11), gamma=parameter.Uniform(1, 7))
uncached = gp_signals.BasisGP(priorFunction=pr, basisFunction=uncached_basis, name="chrom_gp")
cached = gp_signals.BasisGP(priorFunction=pr, basisFunction=cached_basis, name="chrom_gp")
pr = gp_priors.powerlaw_genmodes(log10_A=parameter.Uniform(-18, -12), gamma=parameter.Uniform(1, 7))
basis = gp_bases.createfourierdesignmatrix_red(nmodes=30)
rn = gp_signals.BasisGP(priorFunction=pr, basisFunction=basis, name="red_noise")
efac = parameter.Normal(1.0, 0.1)
backend = selections.Selection(selections.by_backend)
equad = parameter.Uniform(-8.5, -5)
wn = white_signals.MeasurementNoise(efac=efac, log10_t2equad=equad,
selection=backend, name=None)
mod1 = uncached + rn + wn
mod2 = cached + rn + wn
uncached_pta = signal_base.PTA([mod1(self.psr)])
cached_pta = signal_base.PTA([mod2(self.psr)])

# check that both of the chromatic bases have the chromatic index as a parameter
msg = "chromatic index missing from pta parameter list"
assert "B1855+09_chrom_gp_idx" in uncached_pta.param_names, msg
assert "B1855+09_chrom_gp_idx" in cached_pta.param_names, msg

# test to make sure the likelihood evaluations agree for 10 calls
msg = "the likelihood from cached chromatic basis disagrees with the uncached chroamtic basis likelihood"
x0 = [np.hstack([p.sample() for p in cached_pta.params]) for _ in range(10)]
no_cache_lnlike = [uncached_pta.get_lnlikelihood(x0[i] ) for i in range(10)]
cache_lnlike = [cached_pta.get_lnlikelihood(x0[i] ) for i in range(10)]
assert np.all(no_cache_lnlike == cache_lnlike), msg

# check that both the cached and the uncached basis yield the same basis
msg = "the cached chromatic basis does not match the uncached chromatic basis"
assert np.all(uncached_pta.get_basis(params={})[0] == cached_pta.get_basis(params={})[0]), msg

0 comments on commit 4136dc2

Please sign in to comment.