Skip to content

Commit

Permalink
add tests; fix wavemax vs. wavemin bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Stephen Bailey authored and Stephen Bailey committed Apr 25, 2024
1 parent d1e3929 commit cc3c2f4
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
3 changes: 0 additions & 3 deletions py/redrock/archetypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,6 @@ def eval(self, subtype, dwave, coeff, wave, z):
deg_legendre = (coeff!=0.).size-1
index = np.arange(self._narch)[self._subtype==subtype][0]

#w = np.concatenate([ w for w in dwave.values() ])
#wave_min = w.min()
#wave_max = w.max()
legendre = np.array([scipy.special.legendre(i)(reduced_wavelength(w)) for i in range(deg_legendre)])
binned = trapz_rebin((1+z)*self.wave, self.flux[index], wave)*transmission_Lyman(z,wave,model=self.igm_model)
flux = np.append(binned[None,:],legendre, axis=0)
Expand Down
3 changes: 0 additions & 3 deletions py/redrock/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,6 @@ def legendre(self, nleg, use_gpu=False):
if (self._legendre is not None and self.nleg == nleg):
return self._legendre
dwave = { s.wavehash:s.wave for s in self.spectra }
#wave = np.concatenate([ w for w in dwave.values() ])
#wmin = wave.min()
#wmax = wave.max()
self._legendre = { hs:np.array([scipy.special.legendre(i)(reduced_wavelength(w)) for i in range(nleg)]) for hs, w in dwave.items() }
self.nleg = nleg
if (use_gpu):
Expand Down
21 changes: 21 additions & 0 deletions py/redrock/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,27 @@ def test_distribute_work(self):
dist = utils.distribute_work(nproc, ids, capacities=capacities)
self.assertEqual(list(map(len, dist)), [1, 3])

def test_reduced_wavelength(self):
x = utils.reduced_wavelength(np.arange(10))
self.assertEqual(x[0], -1.0)
self.assertEqual(x[-1], 1.0)
x = utils.reduced_wavelength(np.linspace(3600, 5800, 10))
self.assertEqual(x[0], -1.0)
self.assertEqual(x[-1], 1.0)

#- even out-of-order non-linear ok
x = utils.reduced_wavelength(np.random.uniform(-5, 20, size=100))
self.assertEqual(np.min(x), -1.0)
self.assertEqual(np.max(x), 1.0)

#- also works on cupy if installed,
#- and answer remains on GPU as a cupy array
if cp_available:
x = utils.reduced_wavelength(cp.linspace(3600, 5800, 10))
self.assertEqual(x[0], -1.0)
self.assertEqual(x[-1], 1.0)
self.assertTrue(isinstance(x, cp.ndarray))

def test_suite():
"""Allows testing of only this module with the command::
Expand Down
3 changes: 1 addition & 2 deletions py/redrock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ def reduced_wavelength(wave):
Return:
reduced wavelength in [-1,1] range
"""
wave = np.asarray(wave)
wavemax = wave.max()
wavemin = wave.min()
return 2*(wave - wavemax) / (wavemax - wavemin) - 1
return 2*(wave - wavemin) / (wavemax - wavemin) - 1

0 comments on commit cc3c2f4

Please sign in to comment.