From cc3c2f4f05ef2b2ef928f00adac142c8bf96308f Mon Sep 17 00:00:00 2001 From: Stephen Bailey Date: Wed, 24 Apr 2024 17:06:03 -0700 Subject: [PATCH] add tests; fix wavemax vs. wavemin bug --- py/redrock/archetypes.py | 3 --- py/redrock/targets.py | 3 --- py/redrock/test/test_utils.py | 21 +++++++++++++++++++++ py/redrock/utils.py | 3 +-- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/py/redrock/archetypes.py b/py/redrock/archetypes.py index 708b48a0..0e4b7a97 100644 --- a/py/redrock/archetypes.py +++ b/py/redrock/archetypes.py @@ -157,9 +157,6 @@ def eval(self, subtype, dwave, coeff, wave, z): deg_legendre = (coeff!=0.).size-1 index = np.arange(self._narch)[self._subtype==subtype][0] - #w = np.concatenate([ w for w in dwave.values() ]) - #wave_min = w.min() - #wave_max = w.max() legendre = np.array([scipy.special.legendre(i)(reduced_wavelength(w)) for i in range(deg_legendre)]) binned = trapz_rebin((1+z)*self.wave, self.flux[index], wave)*transmission_Lyman(z,wave,model=self.igm_model) flux = np.append(binned[None,:],legendre, axis=0) diff --git a/py/redrock/targets.py b/py/redrock/targets.py index da0acac5..de4c39bb 100644 --- a/py/redrock/targets.py +++ b/py/redrock/targets.py @@ -180,9 +180,6 @@ def legendre(self, nleg, use_gpu=False): if (self._legendre is not None and self.nleg == nleg): return self._legendre dwave = { s.wavehash:s.wave for s in self.spectra } - #wave = np.concatenate([ w for w in dwave.values() ]) - #wmin = wave.min() - #wmax = wave.max() self._legendre = { hs:np.array([scipy.special.legendre(i)(reduced_wavelength(w)) for i in range(nleg)]) for hs, w in dwave.items() } self.nleg = nleg if (use_gpu): diff --git a/py/redrock/test/test_utils.py b/py/redrock/test/test_utils.py index 5039b00d..dc15bbd7 100644 --- a/py/redrock/test/test_utils.py +++ b/py/redrock/test/test_utils.py @@ -38,6 +38,27 @@ def test_distribute_work(self): dist = utils.distribute_work(nproc, ids, capacities=capacities) self.assertEqual(list(map(len, dist)), [1, 3]) + def test_reduced_wavelength(self): + x = utils.reduced_wavelength(np.arange(10)) + self.assertEqual(x[0], -1.0) + self.assertEqual(x[-1], 1.0) + x = utils.reduced_wavelength(np.linspace(3600, 5800, 10)) + self.assertEqual(x[0], -1.0) + self.assertEqual(x[-1], 1.0) + + #- even out-of-order non-linear ok + x = utils.reduced_wavelength(np.random.uniform(-5, 20, size=100)) + self.assertEqual(np.min(x), -1.0) + self.assertEqual(np.max(x), 1.0) + + #- also works on cupy if installed, + #- and answer remains on GPU as a cupy array + if cp_available: + x = utils.reduced_wavelength(cp.linspace(3600, 5800, 10)) + self.assertEqual(x[0], -1.0) + self.assertEqual(x[-1], 1.0) + self.assertTrue(isinstance(x, cp.ndarray)) + def test_suite(): """Allows testing of only this module with the command:: diff --git a/py/redrock/utils.py b/py/redrock/utils.py index 13989c20..0ae5f842 100644 --- a/py/redrock/utils.py +++ b/py/redrock/utils.py @@ -258,7 +258,6 @@ def reduced_wavelength(wave): Return: reduced wavelength in [-1,1] range """ - wave = np.asarray(wave) wavemax = wave.max() wavemin = wave.min() - return 2*(wave - wavemax) / (wavemax - wavemin) - 1 + return 2*(wave - wavemin) / (wavemax - wavemin) - 1