Skip to content

Commit

Permalink
update apogeenet pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Jul 8, 2024
1 parent d02cba4 commit a9300c3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
12 changes: 3 additions & 9 deletions python/astra/pipelines/apogeenet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,30 +321,24 @@ def apogeenet(

log_G,log_Teff,FeH,log_G_std,log_Teff_std,Feh_std = make_prediction(flux, e_flux, None, num_uncertainty_draws,model,device)
except:
log.exception(f"Exception when running ANet on {spectrum}")
yield ANet(
log.exception(f"Exception when running ApogeeNet on {spectrum}")
yield ApogeeNet(
spectrum_pk=spectrum.spectrum_pk,
source_pk=spectrum.source_pk,
flag_runtime_exception=True
)
else:
yield ANet(
yield ApogeeNet(
spectrum_pk=spectrum.spectrum_pk,
source_pk=spectrum.source_pk,
fe_h=FeH,
e_fe_h=Feh_std,
logg=log_G,
e_logg=log_G_std,
teff=10**log_Teff,
#e_teff=10**log_Teff_std,
#e_teff=10^logteff*e_logteff*ln(10)
e_teff=10**log_Teff * log_Teff_std * np.log(10)
)

# to make correction in psql:
#update a_net set e_teff=teff * log10(e_teff) * 2.302585092994046 where task_pk = 1;
# then check, and apply to all with task_pk > 1


'''
path = expand_path(
Expand Down
17 changes: 14 additions & 3 deletions python/astra/pipelines/apogeenet_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
from typing import Iterable, Optional
from typing import Iterable, Optional, Union
from astra import task
from astra.models import ApogeeNetV2
from astra.models import ApogeeNetV2, ApogeeVisitSpectrumInApStar, ApogeeCoaddedSpectrumInApStar
from astra.pipelines.apogeenet_v2.network import read_network
from astra.pipelines.apogeenet_v2.base import _prepare_data, _worker, parallel_batch_read, _inference
from peewee import JOIN, ModelSelect

__all__ = ["apogeenet"]

@task
def apogeenet_v2(
spectra: Iterable,
spectra: Optional[Iterable[Union[ApogeeVisitSpectrumInApStar, ApogeeCoaddedSpectrumInApStar]]] = (
ApogeeCoaddedSpectrumInApStar
.select()
.join(ApogeeNetV2, JOIN.LEFT_OUTER, on=(ApogeeCoaddedSpectrumInApStar.spectrum_pk == ApogeeNetV2.spectrum_pk))
.where(ApogeeNetV2.spectrum_pk.is_null())
),
network_path: str = "$MWM_ASTRA/pipelines/APOGEENet/model.pt",
large_error: Optional[float] = 1e10,
num_uncertainty_draws: Optional[int] = 100,
parallel: Optional[bool] = False,
limit: Optional[int] = None,
**kwargs
) -> Iterable[ApogeeNetV2]:
"""
Estimate astrophysical parameters for a stellar spectrum given a pre-trained neural network.
"""

if isinstance(spectra, ModelSelect):
if limit is not None:
spectra = spectra.limit(limit)

network = read_network(network_path)

if parallel:
Expand Down
3 changes: 3 additions & 0 deletions python/astra/pipelines/apogeenet_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def _inference(network, batch, num_uncertainty_draws):
e_teff=teff_std[i],
e_logg=logg_std[i],
e_fe_h=fe_h_std[i],
raw_e_teff=teff_std[i],
raw_e_logg=logg_std[i],
raw_e_fe_h=fe_h_std[i],
teff_sample_median=teff_median[i],
logg_sample_median=logg_median[i],
fe_h_sample_median=fe_h_median[i],
Expand Down

0 comments on commit a9300c3

Please sign in to comment.