Skip to content

Commit

Permalink
Merge pull request #320 from desihub/multitemplate
Browse files Browse the repository at this point in the history
support rrdesi --templates (list of templates)
  • Loading branch information
sbailey authored Nov 5, 2024
2 parents 0ba968f + 153ee37 commit 6e56d9b
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 5 deletions.
2 changes: 1 addition & 1 deletion py/redrock/external/desi.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def rrdesi(options=None, comm=None):
parser = argparse.ArgumentParser(description="Estimate redshifts from"
" DESI target spectra.")

parser.add_argument("-t", "--templates", type=str, default=None,
parser.add_argument("-t", "--templates", type=str, nargs='+', default=None,
required=False, help="template file or directory")

parser.add_argument("--archetypes", type=str, default=None,
Expand Down
9 changes: 8 additions & 1 deletion py/redrock/fitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,14 @@ def fitz(zchi2, redshifts, target, template, nminima=3, archetype=None, use_gpu=
ii = np.argsort([tmp['chi2'] for tmp in results])
results = [results[i] for i in ii]

assert len(results) > 0
if len(results) == 0:
#- Return blank arrays of 0 minima
xfloat = np.zeros((0,1), dtype=float)
xint = np.zeros((0,1), dtype=int)
xstr = np.zeros((0,1), dtype=str)
return dict(z=xfloat, zerr=xfloat, zwarn=xint, chi2=xfloat, zz=xfloat, zzchi2=xfloat,
coeff=xfloat, fitmethod=xstr, npixels=xint)

#- Convert list of dicts -> Table
#from astropy.table import Table
#results = Table(results)
Expand Down
3 changes: 3 additions & 0 deletions py/redrock/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def tophat(z, z0, s0):

prior = np.where(np.abs(z - z0) < s0/2, 0., np.NaN)

if np.all(np.isnan(prior)):
return prior

index_left, index_right = np.argwhere(prior>=0.0)[0], np.argwhere(prior>=0.0)[-1]

if index_left == 0:
Expand Down
12 changes: 10 additions & 2 deletions py/redrock/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,15 @@ def make_fulltype(spectype, subtype):
def find_templates(template_path=None):
"""Return list of Redrock template files
`template_path` can be one of 4 things:
`template_path` can be one of 5 things:
* None (use $RR_TEMPLATE_DIR instead)
* list/tuple/set/array of template files
* path to directory containing template files
* path to single template file to use
* path to text file listing which template files to use
* None (use $RR_TEMPLATE_DIR instead)
Returns list of template files to use
"""
if template_path is None:
if 'RR_TEMPLATE_DIR' in os.environ:
Expand All @@ -319,6 +322,11 @@ def find_templates(template_path=None):
else:
print(f'DEBUG: Reading templates from {template_path}')

# for symmetry with load_templates(template_path), also
# support list/tuple/etc of strings and just return that
if isinstance(template_path, (list, tuple, set, np.ndarray)):
return template_path

if os.path.isdir(template_path):
default_templates_file = f'{template_path}/templates-default.txt'
template_dir = template_path
Expand Down
5 changes: 5 additions & 0 deletions py/redrock/test/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def test_find_templates(self):
for filename in templates:
self.assertIn(os.path.basename(filename), self.default_templates)

# find templates also accepts a list of templates which it just returns
t2 = find_templates(templates[1:3])
self.assertListEqual(t2, templates[1:3])

# or read from text file a set of alternate templates
templates = find_templates(self.testTemplateDir+'/templates-alternate.txt')

# works without a path prefix
Expand Down
6 changes: 5 additions & 1 deletion py/redrock/zfind.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,10 @@ def zfind(targets, templates, mp_procs=1, nminima=3, archetypes=None, priors=Non
#np arrays instead of astropy Table
nmin = len(tmp['chi2'])

if nmin == 0:
print(f'WARNING: no {fulltype} chi2 vs. z minima for targetid {tid}')
continue

if np.isscalar(spectype):
tmp['spectype'] = np.full((nmin, 1), spectype)
tmp['subtype'] = np.full((nmin, 1), subtype)
Expand Down Expand Up @@ -551,7 +555,7 @@ def zfind(targets, templates, mp_procs=1, nminima=3, archetypes=None, priors=Non
ibad[k]=False

tzfit['zwarn'][ibad] |= ZW.NEGATIVE_MODEL

tzfit['zwarn'][ tzfit['npixels']==0 ] |= ZW.NODATA
tzfit['zwarn'][ (tzfit['npixels']<10*tzfit['ncoeff']) ] |= \
ZW.LITTLE_COVERAGE
Expand Down

0 comments on commit 6e56d9b

Please sign in to comment.