Skip to content

Commit

Permalink
Just raise and error if training data not in altaz
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasBeiske committed Oct 27, 2023
1 parent b868cbd commit c83a9c6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 29 deletions.
32 changes: 5 additions & 27 deletions ctapipe/reco/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import astropy.units as u
import joblib
import numpy as np
from astropy.coordinates import AltAz, SkyCoord
from astropy.coordinates import AltAz
from astropy.table import QTable, Table, vstack
from astropy.utils.decorators import lazyproperty
from sklearn.metrics import accuracy_score, r2_score, roc_auc_score
Expand All @@ -22,7 +22,6 @@

from ..containers import (
ArrayEventContainer,
CoordinateFrameType,
DispContainer,
ParticleClassificationContainer,
ReconstructedEnergyContainer,
Expand Down Expand Up @@ -778,34 +777,13 @@ def predict_table(self, key, table: Table) -> Dict[ReconstructionProperty, Table
fov_lon = table["hillas_fov_lon"].quantity + disp * np.cos(psi)
fov_lat = table["hillas_fov_lat"].quantity + disp * np.sin(psi)

self.log.warning("FIXME: Assuming constant and parallel pointing for each run")
if np.all(table["subarray_pointing_frame"] is CoordinateFrameType.ALTAZ):
pointing_alt = table["subarray_pointing_lat"]
pointing_az = table["subarray_pointing_lon"]
elif np.all(table["subarray_pointing_frame"] is CoordinateFrameType.ICRS):
pointing_altaz = SkyCoord(
ra=table["subarray_pointing_lon"],
dec=table["subarray_pointing_lat"],
frame="icrs",
).transform_to(AltAz())
pointing_alt = pointing_altaz.alt
pointing_az = pointing_altaz.az
elif np.all(table["subarray_pointing_frame"] is CoordinateFrameType.GALACTIC):
pointing_altaz = SkyCoord(
l=table["subarray_pointing_lon"],
b=table["subarray_pointing_lat"],
frame="galactic",
).transform_to(AltAz())
pointing_alt = pointing_altaz.alt
pointing_az = pointing_altaz.az
else:
raise KeyError("Unknown observation coordinate frame")

# FIXME: Assume constant and parallel pointing for each run
self.log.warning("Assuming constant and parallel pointing for each run")
alt, az = telescope_to_horizontal(
lon=fov_lon,
lat=fov_lat,
pointing_alt=pointing_alt,
pointing_az=pointing_az,
pointing_alt=table["subarray_pointing_lat"],
pointing_az=table["subarray_pointing_lon"],
)

altaz_result = Table(
Expand Down
8 changes: 8 additions & 0 deletions ctapipe/tools/train_disp_reconstructor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import astropy.units as u
import numpy as np

from ctapipe.containers import CoordinateFrameType
from ctapipe.core import Tool
from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path
from ctapipe.exceptions import TooFewEvents
Expand Down Expand Up @@ -104,6 +105,13 @@ def start(self):
self.log.info("Loading events for %s", tel_type)
table = self._read_table(tel_type)

if not np.all(
table["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value
):
raise ValueError(
"Pointing information for training data has to be provided in horizontal coordinates"
)

self.log.info("Train models on %s events", len(table))
self.cross_validate(tel_type, table)

Expand Down
4 changes: 2 additions & 2 deletions docs/changes/2431.bugfix.rst
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Check the coordinate frame in which the array pointing is given
before using it in ``DispReconstructor``.
Check that the array pointing is given in horizontal coordinates
before training a ``DispReconstructor``.

0 comments on commit c83a9c6

Please sign in to comment.