Skip to content

Commit

Permalink
lint tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremy-baier committed Sep 26, 2024
1 parent 4136dc2 commit 88e9bfa
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions tests/test_gp_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,21 +297,17 @@ def test_powerlaw_genmodes_prior(self):
# test shape
msg = "F matrix shape incorrect"
assert rnm.get_basis(params).shape == F.shape, msg

def test_chromatic_fourier_basis_varied_idx(self):
"""Test the set up of variable index chromatic bases and make sure that the caching is the same as no caching"""
idx = parameter.Uniform(2.5, 7)
uncached_basis = gp_bases.createfourierdesignmatrix_chromatic(self.psr.toas,
self.psr.freqs,
nmodes=100,
idx=idx)
fmat_red, Ffreqs, fref_over_radio_freqs = gp_bases.construct_chromatic_cached_parts(self.psr.toas,
self.psr.freqs,
nmodes=100)
cached_basis = gp_bases.createfourierdesignmatrix_chromatic_with_additional_caching(fmat_red=fmat_red,
Ffreqs=Ffreqs,
fref_over_radio_freqs=fref_over_radio_freqs,
idx=idx)
uncached_basis = gp_bases.createfourierdesignmatrix_chromatic(
self.psr.toas, self.psr.freqs, nmodes=100, idx=idx
)
fmat_red, Ffreqs, nus = gp_bases.construct_chromatic_cached_parts(self.psr.toas, self.psr.freqs, nmodes=100)
cached_basis = gp_bases.createfourierdesignmatrix_chromatic_with_additional_caching(
fmat_red=fmat_red, Ffreqs=Ffreqs, fref_over_radio_freqs=nus, idx=idx
)
pr = gp_priors.powerlaw(log10_A=parameter.Uniform(-18, -11), gamma=parameter.Uniform(1, 7))
uncached = gp_signals.BasisGP(priorFunction=pr, basisFunction=uncached_basis, name="chrom_gp")
cached = gp_signals.BasisGP(priorFunction=pr, basisFunction=cached_basis, name="chrom_gp")
Expand All @@ -321,23 +317,22 @@ def test_chromatic_fourier_basis_varied_idx(self):
efac = parameter.Normal(1.0, 0.1)
backend = selections.Selection(selections.by_backend)
equad = parameter.Uniform(-8.5, -5)
wn = white_signals.MeasurementNoise(efac=efac, log10_t2equad=equad,
selection=backend, name=None)
wn = white_signals.MeasurementNoise(efac=efac, log10_t2equad=equad, selection=backend, name=None)
mod1 = uncached + rn + wn
mod2 = cached + rn + wn
uncached_pta = signal_base.PTA([mod1(self.psr)])
cached_pta = signal_base.PTA([mod2(self.psr)])

# check that both of the chromatic bases have the chromatic index as a parameter
msg = "chromatic index missing from pta parameter list"
assert "B1855+09_chrom_gp_idx" in uncached_pta.param_names, msg
assert "B1855+09_chrom_gp_idx" in cached_pta.param_names, msg

# test to make sure the likelihood evaluations agree for 10 calls
msg = "the likelihood from cached chromatic basis disagrees with the uncached chroamtic basis likelihood"
x0 = [np.hstack([p.sample() for p in cached_pta.params]) for _ in range(10)]
no_cache_lnlike = [uncached_pta.get_lnlikelihood(x0[i] ) for i in range(10)]
cache_lnlike = [cached_pta.get_lnlikelihood(x0[i] ) for i in range(10)]
no_cache_lnlike = [uncached_pta.get_lnlikelihood(x0[i]) for i in range(10)]
cache_lnlike = [cached_pta.get_lnlikelihood(x0[i]) for i in range(10)]
assert np.all(no_cache_lnlike == cache_lnlike), msg

# check that both the cached and the uncached basis yield the same basis
Expand Down

0 comments on commit 88e9bfa

Please sign in to comment.