From bb7cc84150e4a60cf0a3681aad93798c29f0c404 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Wed, 25 Oct 2023 16:19:19 +0200 Subject: [PATCH 01/13] Load training data in chunks when training an energy regressor --- ctapipe/tools/train_energy_regressor.py | 72 ++++++++++++++++++------- 1 file changed, 53 insertions(+), 19 deletions(-) diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py index 89647f6c39b..e0f412f169b 100644 --- a/ctapipe/tools/train_energy_regressor.py +++ b/ctapipe/tools/train_energy_regressor.py @@ -2,6 +2,8 @@ Tool for training the EnergyRegressor """ import numpy as np +from astropy.table import vstack +from tqdm.auto import tqdm from ctapipe.core import Tool from ctapipe.core.traits import Int, IntTelescopeParameter, Path @@ -53,6 +55,12 @@ class TrainEnergyRegressor(Tool): ), ).tag(config=True) + chunk_size = Int( + default_value=100000, + allow_none=True, + help="How many subarray events to load at once before training on n_events.", + ).tag(config=True) + random_seed = Int( default_value=0, help="Random seed for sampling and cross validation" ).tag(config=True) @@ -61,6 +69,7 @@ class TrainEnergyRegressor(Tool): ("i", "input"): "TableLoader.input_url", ("o", "output"): "TrainEnergyRegressor.output_path", "n-events": "TrainEnergyRegressor.n_events", + "chunk-size": "TrainEnergyRegressor.chunk_size", "cv-output": "CrossValidator.output_path", } @@ -113,31 +122,56 @@ def start(self): self.log.info("done") def _read_table(self, telescope_type): - table = self.loader.read_telescope_events([telescope_type]) - self.log.info("Events read from input: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"Input file does not contain any events for telescope type {telescope_type}" - ) + chunk_iterator = self.loader.read_telescope_events_chunked( + self.chunk_size, + telescopes=[telescope_type], + ) + bar = tqdm( + chunk_iterator, + desc=f"Loading training events for {telescope_type}", + unit=" Array Events", + total=chunk_iterator.n_total, + ) + table = [] + n_events_in_file = 0 + n_valid_events_in_file = 0 + + with bar: + for chunk, (start, stop, table_chunk) in enumerate(chunk_iterator): + self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) + n_events_in_file += len(table_chunk) + + mask = self.regressor.quality_query.get_table_mask(table_chunk) + table_chunk = table_chunk[mask] + self.log.debug( + "Events in chunk %d after applying quality_query: %d", + chunk, + len(table_chunk), + ) + n_valid_events_in_file += len(table_chunk) + + table_chunk = self.regressor.feature_generator( + table_chunk, subarray=self.loader.subarray + ) + feature_names = self.regressor.features + [self.regressor.target] + table_chunk = table_chunk[feature_names] + + valid = check_valid_rows(table_chunk) + if np.any(~valid): + self.log.warning("Dropping non-predictable events.") + table_chunk = table_chunk[valid] - mask = self.regressor.quality_query.get_table_mask(table) - table = table[mask] - self.log.info("Events after applying quality query: %d", len(table)) + table.append(table_chunk) + bar.update(stop - start) + + table = vstack(table) + self.log.info("Events read from input: %d", n_events_in_file) + self.log.info("Events after applying quality query: %d", n_valid_events_in_file) if len(table) == 0: raise TooFewEvents( f"No events after quality query for telescope type {telescope_type}" ) - table = self.regressor.feature_generator(table, subarray=self.loader.subarray) - - feature_names = self.regressor.features + [self.regressor.target] - table = table[feature_names] - - valid = check_valid_rows(table) - if np.any(~valid): - self.log.warning("Dropping non-predictable events.") - table = table[valid] - n_events = self.n_events.tel[telescope_type] if n_events is not None: if n_events > len(table): From 08495766a277441e23a6f1efbf6b84c1382076fe Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Wed, 25 Oct 2023 17:27:03 +0200 Subject: [PATCH 02/13] Check event validity after merging chunks and make it faster --- ctapipe/tools/train_disp_reconstructor.py | 2 +- ctapipe/tools/train_energy_regressor.py | 12 ++++++------ ctapipe/tools/train_particle_classifier.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index 80352a0bb29..1d53c588ae3 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -152,7 +152,7 @@ def _read_table(self, telescope_type): table = table[columns] valid = check_valid_rows(table) - if np.any(~valid): + if not np.all(valid): self.log.warning("Dropping non-predicable events.") table = table[valid] diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py index e0f412f169b..11ad2ce548d 100644 --- a/ctapipe/tools/train_energy_regressor.py +++ b/ctapipe/tools/train_energy_regressor.py @@ -129,7 +129,7 @@ def _read_table(self, telescope_type): bar = tqdm( chunk_iterator, desc=f"Loading training events for {telescope_type}", - unit=" Array Events", + unit=" Telescope Events", total=chunk_iterator.n_total, ) table = [] @@ -156,11 +156,6 @@ def _read_table(self, telescope_type): feature_names = self.regressor.features + [self.regressor.target] table_chunk = table_chunk[feature_names] - valid = check_valid_rows(table_chunk) - if np.any(~valid): - self.log.warning("Dropping non-predictable events.") - table_chunk = table_chunk[valid] - table.append(table_chunk) bar.update(stop - start) @@ -172,6 +167,11 @@ def _read_table(self, telescope_type): f"No events after quality query for telescope type {telescope_type}" ) + valid = check_valid_rows(table) + if not np.all(valid): + self.log.warning("Dropping non-predictable events.") + table = table[valid] + n_events = self.n_events.tel[telescope_type] if n_events is not None: if n_events > len(table): diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index b8511c6fd1e..21c5e21fdd3 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -184,7 +184,7 @@ def _read_table(self, telescope_type, loader, n_events=None): table = table[columns] valid = check_valid_rows(table) - if np.any(~valid): + if not np.all(valid): self.log.warning("Dropping non-predictable events.") table = table[valid] From 8b60cf5799fdea790e8d674dbe3618e2c4e6ff16 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Thu, 26 Oct 2023 15:27:22 +0200 Subject: [PATCH 03/13] Do all event validation in chunk loop; remove loading bar - Since the total number of telescope events in a file is not known when loading in chunks (only the total number of subarray events), a progress bar for loading does not make sense. --- ctapipe/tools/train_energy_regressor.py | 58 ++++++++++++------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py index 11ad2ce548d..3032e9237b3 100644 --- a/ctapipe/tools/train_energy_regressor.py +++ b/ctapipe/tools/train_energy_regressor.py @@ -3,7 +3,6 @@ """ import numpy as np from astropy.table import vstack -from tqdm.auto import tqdm from ctapipe.core import Tool from ctapipe.core.traits import Int, IntTelescopeParameter, Path @@ -126,51 +125,48 @@ def _read_table(self, telescope_type): self.chunk_size, telescopes=[telescope_type], ) - bar = tqdm( - chunk_iterator, - desc=f"Loading training events for {telescope_type}", - unit=" Telescope Events", - total=chunk_iterator.n_total, - ) table = [] n_events_in_file = 0 n_valid_events_in_file = 0 + n_non_predictable = 0 + + for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): + self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) + n_events_in_file += len(table_chunk) + + mask = self.regressor.quality_query.get_table_mask(table_chunk) + table_chunk = table_chunk[mask] + self.log.debug( + "Events in chunk %d after applying quality_query: %d", + chunk, + len(table_chunk), + ) + n_valid_events_in_file += len(table_chunk) - with bar: - for chunk, (start, stop, table_chunk) in enumerate(chunk_iterator): - self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) - n_events_in_file += len(table_chunk) - - mask = self.regressor.quality_query.get_table_mask(table_chunk) - table_chunk = table_chunk[mask] - self.log.debug( - "Events in chunk %d after applying quality_query: %d", - chunk, - len(table_chunk), - ) - n_valid_events_in_file += len(table_chunk) + table_chunk = self.regressor.feature_generator( + table_chunk, subarray=self.loader.subarray + ) + feature_names = self.regressor.features + [self.regressor.target] + table_chunk = table_chunk[feature_names] - table_chunk = self.regressor.feature_generator( - table_chunk, subarray=self.loader.subarray - ) - feature_names = self.regressor.features + [self.regressor.target] - table_chunk = table_chunk[feature_names] + valid = check_valid_rows(table_chunk) + if not np.all(valid): + n_non_predictable += np.sum(valid) + table_chunk = table_chunk[valid] - table.append(table_chunk) - bar.update(stop - start) + table.append(table_chunk) table = vstack(table) self.log.info("Events read from input: %d", n_events_in_file) self.log.info("Events after applying quality query: %d", n_valid_events_in_file) + if len(table) == 0: raise TooFewEvents( f"No events after quality query for telescope type {telescope_type}" ) - valid = check_valid_rows(table) - if not np.all(valid): - self.log.warning("Dropping non-predictable events.") - table = table[valid] + if n_non_predictable > 0: + self.log.warning("Dropping %d non-predictable events.", n_non_predictable) n_events = self.n_events.tel[telescope_type] if n_events is not None: From 88f00b501ebcce6d7a7447018f38c60a4104e892 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Thu, 26 Oct 2023 16:14:35 +0200 Subject: [PATCH 04/13] Chunked loading for training particle clf and training disp reco --- ctapipe/tools/train_disp_reconstructor.py | 80 +++++++++++++++------- ctapipe/tools/train_particle_classifier.py | 67 +++++++++++++----- 2 files changed, 103 insertions(+), 44 deletions(-) diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index 1d53c588ae3..7ba282c9728 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -3,6 +3,7 @@ """ import astropy.units as u import numpy as np +from astropy.table import vstack from ctapipe.containers import CoordinateFrameType from ctapipe.core import Tool @@ -56,6 +57,12 @@ class TrainDispReconstructor(Tool): ), ).tag(config=True) + chunk_size = Int( + default_value=100000, + allow_none=True, + help="How many subarray events to load at once before training on n_events.", + ).tag(config=True) + random_seed = Int( default_value=0, help="Random seed for sampling and cross validation" ).tag(config=True) @@ -121,40 +128,61 @@ def start(self): self.log.info("done") def _read_table(self, telescope_type): - table = self.loader.read_telescope_events([telescope_type]) - self.log.info("Events read from input: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"Input file does not contain any events for telescope type {telescope_type}" + chunk_iterator = self.loader.read_telescope_events_chunked( + self.chunk_size, + telescopes=[telescope_type], + ) + table = [] + n_events_in_file = 0 + n_valid_events_in_file = 0 + n_non_predictable = 0 + + for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): + self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) + n_events_in_file += len(table_chunk) + + mask = self.models.quality_query.get_table_mask(table_chunk) + table_chunk = table_chunk[mask] + self.log.debug( + "Events in chunk %d after applying quality_query: %d", + chunk, + len(table_chunk), ) + n_valid_events_in_file += len(table_chunk) - mask = self.models.quality_query.get_table_mask(table) - table = table[mask] - self.log.info("Events after applying quality query: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"No events after quality query for telescope type {telescope_type}" - ) + if not np.all( + table_chunk["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value + ): + raise ValueError( + "Pointing information for training data has to be provided in horizontal coordinates" + ) - if not np.all( - table["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value - ): - raise ValueError( - "Pointing information for training data has to be provided in horizontal coordinates" + table_chunk = self.models.feature_generator( + table_chunk, subarray=self.loader.subarray ) + table_chunk[self.models.target] = self._get_true_disp(table_chunk) + # Add true energy for energy-dependent performance plots + columns = self.models.features + [self.models.target, "true_energy"] + table_chunk = table_chunk[columns] + + valid = check_valid_rows(table_chunk) + if not np.all(valid): + n_non_predictable += np.sum(valid) + table_chunk = table_chunk[valid] - table = self.models.feature_generator(table, subarray=self.loader.subarray) + table.append(table_chunk) - table[self.models.target] = self._get_true_disp(table) + table = vstack(table) + self.log.info("Events read from input: %d", n_events_in_file) + self.log.info("Events after applying quality query: %d", n_valid_events_in_file) - # Add true energy for energy-dependent performance plots - columns = self.models.features + [self.models.target, "true_energy"] - table = table[columns] + if len(table) == 0: + raise TooFewEvents( + f"No events after quality query for telescope type {telescope_type}" + ) - valid = check_valid_rows(table) - if not np.all(valid): - self.log.warning("Dropping non-predicable events.") - table = table[valid] + if n_non_predictable > 0: + self.log.warning("Dropping %d non-predictable events.", n_non_predictable) n_events = self.n_events.tel[telescope_type] if n_events is not None: diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index 21c5e21fdd3..d3d1d8074db 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -78,6 +78,15 @@ class TrainParticleClassifier(Tool): ), ).tag(config=True) + chunk_size = Int( + default_value=100000, + allow_none=True, + help=( + "How many subarray events to load at once before training on" + " n_signal and n_background events." + ), + ).tag(config=True) + random_seed = Int( default_value=0, help="Random number seed for sampling and the cross validation splitting", @@ -162,31 +171,53 @@ def start(self): self.log.info("done") def _read_table(self, telescope_type, loader, n_events=None): - table = loader.read_telescope_events([telescope_type]) - self.log.info("Events read from input: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"Input file does not contain any events for telescope type {telescope_type}" + chunk_iterator = loader.read_telescope_events_chunked( + self.chunk_size, + telescopes=[telescope_type], + ) + table = [] + n_events_in_file = 0 + n_valid_events_in_file = 0 + n_non_predictable = 0 + + for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): + self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) + n_events_in_file += len(table_chunk) + + mask = self.classifier.quality_query.get_table_mask(table_chunk) + table_chunk = table_chunk[mask] + self.log.debug( + "Events in chunk %d after applying quality_query: %d", + chunk, + len(table_chunk), ) + n_valid_events_in_file += len(table_chunk) + + table_chunk = self.classifier.feature_generator( + table_chunk, subarray=self.subarray + ) + # Add true energy for energy-dependent performance plots + columns = self.classifier.features + [self.classifier.target, "true_energy"] + table_chunk = table_chunk[columns] + + valid = check_valid_rows(table_chunk) + if not np.all(valid): + n_non_predictable += np.sum(valid) + table_chunk = table_chunk[valid] + + table.append(table_chunk) + + table = vstack(table) + self.log.info("Events read from input: %d", n_events_in_file) + self.log.info("Events after applying quality query: %d", n_valid_events_in_file) - mask = self.classifier.quality_query.get_table_mask(table) - table = table[mask] - self.log.info("Events after applying quality query: %d", len(table)) if len(table) == 0: raise TooFewEvents( f"No events after quality query for telescope type {telescope_type}" ) - table = self.classifier.feature_generator(table, subarray=self.subarray) - - # Add true energy for energy-dependent performance plots - columns = self.classifier.features + [self.classifier.target, "true_energy"] - table = table[columns] - - valid = check_valid_rows(table) - if not np.all(valid): - self.log.warning("Dropping non-predictable events.") - table = table[valid] + if n_non_predictable > 0: + self.log.warning("Dropping %d non-predictable events.", n_non_predictable) if n_events is not None: if n_events > len(table): From 525e3ba9a47cc7b0b32ec1ee566aa115806a0b26 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Thu, 26 Oct 2023 16:14:49 +0200 Subject: [PATCH 05/13] Add changelog --- docs/changes/2423.optimization.rst | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 docs/changes/2423.optimization.rst diff --git a/docs/changes/2423.optimization.rst b/docs/changes/2423.optimization.rst new file mode 100644 index 00000000000..b6e1567767a --- /dev/null +++ b/docs/changes/2423.optimization.rst @@ -0,0 +1,3 @@ +Load data and apply event and column selection in chunks in ``ctapipe-train-*`` +before merging afterwards. +This reduces memory usage. From 4ac7709a2139f9bc1fbb01bc28ce328428dcceba Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Mon, 6 Nov 2023 13:02:26 +0100 Subject: [PATCH 06/13] Fix line too long --- ctapipe/tools/train_disp_reconstructor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index 7ba282c9728..ec69a3d13b8 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -151,7 +151,8 @@ def _read_table(self, telescope_type): n_valid_events_in_file += len(table_chunk) if not np.all( - table_chunk["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value + table_chunk["subarray_pointing_frame"] + == CoordinateFrameType.ALTAZ.value ): raise ValueError( "Pointing information for training data has to be provided in horizontal coordinates" From 3fae0b0af16ea8da5033611c4090d587e8ce0dff Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Mon, 6 Nov 2023 14:16:59 +0100 Subject: [PATCH 07/13] Count invalid events as non-predictable, not valid events --- ctapipe/tools/train_disp_reconstructor.py | 2 +- ctapipe/tools/train_energy_regressor.py | 2 +- ctapipe/tools/train_particle_classifier.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index ec69a3d13b8..6b3ad6d0314 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -168,7 +168,7 @@ def _read_table(self, telescope_type): valid = check_valid_rows(table_chunk) if not np.all(valid): - n_non_predictable += np.sum(valid) + n_non_predictable += np.sum(~valid) table_chunk = table_chunk[valid] table.append(table_chunk) diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py index 3032e9237b3..7fe7986b380 100644 --- a/ctapipe/tools/train_energy_regressor.py +++ b/ctapipe/tools/train_energy_regressor.py @@ -151,7 +151,7 @@ def _read_table(self, telescope_type): valid = check_valid_rows(table_chunk) if not np.all(valid): - n_non_predictable += np.sum(valid) + n_non_predictable += np.sum(~valid) table_chunk = table_chunk[valid] table.append(table_chunk) diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index d3d1d8074db..76865cd5366 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -202,7 +202,7 @@ def _read_table(self, telescope_type, loader, n_events=None): valid = check_valid_rows(table_chunk) if not np.all(valid): - n_non_predictable += np.sum(valid) + n_non_predictable += np.sum(~valid) table_chunk = table_chunk[valid] table.append(table_chunk) From 8378645763482fac1f0239cc81fabc68eb0eaa75 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Tue, 7 Nov 2023 10:42:08 +0100 Subject: [PATCH 08/13] Refactor reading training data into standalone funtion --- ctapipe/tools/train_disp_reconstructor.py | 104 +++++---------------- ctapipe/tools/train_energy_regressor.py | 81 +++------------- ctapipe/tools/train_particle_classifier.py | 94 +++++-------------- ctapipe/tools/utils.py | 89 ++++++++++++++++++ 4 files changed, 151 insertions(+), 217 deletions(-) diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index 6b3ad6d0314..715e42a3cf5 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -3,15 +3,14 @@ """ import astropy.units as u import numpy as np -from astropy.table import vstack -from ctapipe.containers import CoordinateFrameType from ctapipe.core import Tool from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path -from ctapipe.exceptions import TooFewEvents from ctapipe.io import TableLoader from ctapipe.reco import CrossValidator, DispReconstructor -from ctapipe.reco.preprocessing import check_valid_rows, horizontal_to_telescope +from ctapipe.reco.preprocessing import horizontal_to_telescope + +from .utils import read_training_events __all__ = [ "TrainDispReconstructor", @@ -118,7 +117,28 @@ def start(self): self.log.info("Training models for %d types", len(types)) for tel_type in types: self.log.info("Loading events for %s", tel_type) - table = self._read_table(tel_type) + feature_names = self.models.features + [ + "true_energy", + "subarray_pointing_lat", + "subarray_pointing_lon", + "true_alt", + "true_az", + "hillas_fov_lat", + "hillas_fov_lon", + "hillas_psi", + ] + table = read_training_events( + self.loader, + self.chunk_size, + tel_type, + self.models, + feature_names, + self.log, + self.rng, + self.n_events.tel[tel_type], + ) + table[self.models.target] = self._get_true_disp(table) + table = table[self.models.features + [self.models.target, "true_energy"]] self.log.info("Train models on %s events", len(table)) self.cross_validate(tel_type, table) @@ -127,80 +147,6 @@ def start(self): self.models.fit(tel_type, table) self.log.info("done") - def _read_table(self, telescope_type): - chunk_iterator = self.loader.read_telescope_events_chunked( - self.chunk_size, - telescopes=[telescope_type], - ) - table = [] - n_events_in_file = 0 - n_valid_events_in_file = 0 - n_non_predictable = 0 - - for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): - self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) - n_events_in_file += len(table_chunk) - - mask = self.models.quality_query.get_table_mask(table_chunk) - table_chunk = table_chunk[mask] - self.log.debug( - "Events in chunk %d after applying quality_query: %d", - chunk, - len(table_chunk), - ) - n_valid_events_in_file += len(table_chunk) - - if not np.all( - table_chunk["subarray_pointing_frame"] - == CoordinateFrameType.ALTAZ.value - ): - raise ValueError( - "Pointing information for training data has to be provided in horizontal coordinates" - ) - - table_chunk = self.models.feature_generator( - table_chunk, subarray=self.loader.subarray - ) - table_chunk[self.models.target] = self._get_true_disp(table_chunk) - # Add true energy for energy-dependent performance plots - columns = self.models.features + [self.models.target, "true_energy"] - table_chunk = table_chunk[columns] - - valid = check_valid_rows(table_chunk) - if not np.all(valid): - n_non_predictable += np.sum(~valid) - table_chunk = table_chunk[valid] - - table.append(table_chunk) - - table = vstack(table) - self.log.info("Events read from input: %d", n_events_in_file) - self.log.info("Events after applying quality query: %d", n_valid_events_in_file) - - if len(table) == 0: - raise TooFewEvents( - f"No events after quality query for telescope type {telescope_type}" - ) - - if n_non_predictable > 0: - self.log.warning("Dropping %d non-predictable events.", n_non_predictable) - - n_events = self.n_events.tel[telescope_type] - if n_events is not None: - if n_events > len(table): - self.log.warning( - "Number of events in table (%d) is less than requested number of events %d", - len(table), - n_events, - ) - else: - self.log.info("Sampling %d events", n_events) - idx = self.rng.choice(len(table), n_events, replace=False) - idx.sort() - table = table[idx] - - return table - def _get_true_disp(self, table): fov_lon, fov_lat = horizontal_to_telescope( alt=table["true_alt"], diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py index 7fe7986b380..8aa9fd753da 100644 --- a/ctapipe/tools/train_energy_regressor.py +++ b/ctapipe/tools/train_energy_regressor.py @@ -2,14 +2,13 @@ Tool for training the EnergyRegressor """ import numpy as np -from astropy.table import vstack from ctapipe.core import Tool from ctapipe.core.traits import Int, IntTelescopeParameter, Path -from ctapipe.exceptions import TooFewEvents from ctapipe.io import TableLoader from ctapipe.reco import CrossValidator, EnergyRegressor -from ctapipe.reco.preprocessing import check_valid_rows + +from .utils import read_training_events __all__ = [ "TrainEnergyRegressor", @@ -111,7 +110,17 @@ def start(self): self.log.info("Training models for %d types", len(types)) for tel_type in types: self.log.info("Loading events for %s", tel_type) - table = self._read_table(tel_type) + feature_names = self.regressor.features + [self.regressor.target] + table = read_training_events( + self.loader, + self.chunk_size, + tel_type, + self.regressor, + feature_names, + self.log, + self.rng, + self.n_events.tel[tel_type], + ) self.log.info("Train on %s events", len(table)) self.cross_validate(tel_type, table) @@ -120,70 +129,6 @@ def start(self): self.regressor.fit(tel_type, table) self.log.info("done") - def _read_table(self, telescope_type): - chunk_iterator = self.loader.read_telescope_events_chunked( - self.chunk_size, - telescopes=[telescope_type], - ) - table = [] - n_events_in_file = 0 - n_valid_events_in_file = 0 - n_non_predictable = 0 - - for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): - self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) - n_events_in_file += len(table_chunk) - - mask = self.regressor.quality_query.get_table_mask(table_chunk) - table_chunk = table_chunk[mask] - self.log.debug( - "Events in chunk %d after applying quality_query: %d", - chunk, - len(table_chunk), - ) - n_valid_events_in_file += len(table_chunk) - - table_chunk = self.regressor.feature_generator( - table_chunk, subarray=self.loader.subarray - ) - feature_names = self.regressor.features + [self.regressor.target] - table_chunk = table_chunk[feature_names] - - valid = check_valid_rows(table_chunk) - if not np.all(valid): - n_non_predictable += np.sum(~valid) - table_chunk = table_chunk[valid] - - table.append(table_chunk) - - table = vstack(table) - self.log.info("Events read from input: %d", n_events_in_file) - self.log.info("Events after applying quality query: %d", n_valid_events_in_file) - - if len(table) == 0: - raise TooFewEvents( - f"No events after quality query for telescope type {telescope_type}" - ) - - if n_non_predictable > 0: - self.log.warning("Dropping %d non-predictable events.", n_non_predictable) - - n_events = self.n_events.tel[telescope_type] - if n_events is not None: - if n_events > len(table): - self.log.warning( - "Number of events in table (%d) is less than requested number of events %d", - len(table), - n_events, - ) - else: - self.log.info("Sampling %d events", n_events) - idx = self.rng.choice(len(table), n_events, replace=False) - idx.sort() - table = table[idx] - - return table - def finish(self): """ Write-out trained models and cross-validation results. diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index 76865cd5366..aa56d8ec723 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -6,10 +6,10 @@ from ctapipe.core.tool import Tool from ctapipe.core.traits import Int, IntTelescopeParameter, Path -from ctapipe.exceptions import TooFewEvents from ctapipe.io import TableLoader from ctapipe.reco import CrossValidator, ParticleClassifier -from ctapipe.reco.preprocessing import check_valid_rows + +from .utils import read_training_events __all__ = [ "TrainParticleClassifier", @@ -170,76 +170,30 @@ def start(self): self.classifier.fit(tel_type, table) self.log.info("done") - def _read_table(self, telescope_type, loader, n_events=None): - chunk_iterator = loader.read_telescope_events_chunked( - self.chunk_size, - telescopes=[telescope_type], - ) - table = [] - n_events_in_file = 0 - n_valid_events_in_file = 0 - n_non_predictable = 0 - - for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): - self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) - n_events_in_file += len(table_chunk) - - mask = self.classifier.quality_query.get_table_mask(table_chunk) - table_chunk = table_chunk[mask] - self.log.debug( - "Events in chunk %d after applying quality_query: %d", - chunk, - len(table_chunk), - ) - n_valid_events_in_file += len(table_chunk) - - table_chunk = self.classifier.feature_generator( - table_chunk, subarray=self.subarray - ) - # Add true energy for energy-dependent performance plots - columns = self.classifier.features + [self.classifier.target, "true_energy"] - table_chunk = table_chunk[columns] - - valid = check_valid_rows(table_chunk) - if not np.all(valid): - n_non_predictable += np.sum(~valid) - table_chunk = table_chunk[valid] - - table.append(table_chunk) - - table = vstack(table) - self.log.info("Events read from input: %d", n_events_in_file) - self.log.info("Events after applying quality query: %d", n_valid_events_in_file) - - if len(table) == 0: - raise TooFewEvents( - f"No events after quality query for telescope type {telescope_type}" - ) - - if n_non_predictable > 0: - self.log.warning("Dropping %d non-predictable events.", n_non_predictable) - - if n_events is not None: - if n_events > len(table): - self.log.warning( - "Number of events in table (%d) is less than requested number of events %d", - len(table), - n_events, - ) - else: - self.log.info("Sampling %d events", n_events) - idx = self.rng.choice(len(table), n_events, replace=False) - idx.sort() - table = table[idx] - - return table - def _read_input_data(self, tel_type): - signal = self._read_table( - tel_type, self.signal_loader, self.n_signal.tel[tel_type] + feature_names = self.classifier.features + [ + self.classifier.target, + "true_energy", + ] + signal = read_training_events( + self.signal_loader, + self.chunk_size, + tel_type, + self.classifier, + feature_names, + self.log, + self.rng, + self.n_signal.tel[tel_type], ) - background = self._read_table( - tel_type, self.background_loader, self.n_background.tel[tel_type] + background = read_training_events( + self.background_loader, + self.chunk_size, + tel_type, + self.classifier, + feature_names, + self.log, + self.rng, + self.n_signal.tel[tel_type], ) table = vstack([signal, background]) self.log.info( diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py index 131046fbc1c..b2bd0223953 100644 --- a/ctapipe/tools/utils.py +++ b/ctapipe/tools/utils.py @@ -5,6 +5,14 @@ import sys from collections import OrderedDict +import numpy as np +from astropy.table import vstack + +from ctapipe.containers import CoordinateFrameType +from ctapipe.exceptions import TooFewEvents +from ctapipe.reco.preprocessing import check_valid_rows +from ctapipe.reco.sklearn import DispReconstructor + if sys.version_info < (3, 10): from importlib_metadata import distribution else: @@ -71,3 +79,84 @@ def get_all_descriptions(): descriptions[name] = "[no documentation. Please add a docstring]" return descriptions + + +def read_training_events( + loader, + chunk_size, + telescope_type, + reconstructor, + feature_names, + logger, + rng, + n_events=None, +): + chunk_iterator = loader.read_telescope_events_chunked( + chunk_size, + telescopes=[telescope_type], + ) + table = [] + n_events_in_file = 0 + n_valid_events_in_file = 0 + n_non_predictable = 0 + + for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): + logger.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) + n_events_in_file += len(table_chunk) + + if isinstance(reconstructor, DispReconstructor): + if not np.all( + table_chunk["subarray_pointing_frame"] + == CoordinateFrameType.ALTAZ.value + ): + raise ValueError( + "Pointing information for training data has to be provided in horizontal coordinates" + ) + + mask = reconstructor.quality_query.get_table_mask(table_chunk) + table_chunk = table_chunk[mask] + logger.debug( + "Events in chunk %d after applying quality_query: %d", + chunk, + len(table_chunk), + ) + n_valid_events_in_file += len(table_chunk) + + table_chunk = reconstructor.feature_generator( + table_chunk, subarray=loader.subarray + ) + table_chunk = table_chunk[feature_names] + + valid = check_valid_rows(table_chunk) + if not np.all(valid): + n_non_predictable += np.sum(~valid) + table_chunk = table_chunk[valid] + + table.append(table_chunk) + + table = vstack(table) + logger.info("Events read from input: %d", n_events_in_file) + logger.info("Events after applying quality query: %d", n_valid_events_in_file) + + if len(table) == 0: + raise TooFewEvents( + f"No events after quality query for telescope type {telescope_type}" + ) + + if n_non_predictable > 0: + logger.warning("Dropping %d non-predictable events.", n_non_predictable) + + if n_events is not None: + if n_events > len(table): + logger.warning( + "Number of events in table (%d) is less than requested number of events %d", + len(table), + n_events, + ) + else: + logger.info("Sampling %d events", n_events) + idx = rng.choice(len(table), n_events, replace=False) + idx.sort() + table = table[idx] + + return table From a3891834a884db5b864b56491c4e099ce8d742fa Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Fri, 10 Nov 2023 12:20:59 +0100 Subject: [PATCH 09/13] Add docstring --- ctapipe/tools/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py index b2bd0223953..be946e6be78 100644 --- a/ctapipe/tools/utils.py +++ b/ctapipe/tools/utils.py @@ -91,6 +91,7 @@ def read_training_events( rng, n_events=None, ): + """Chunked loading of events for training ML models""" chunk_iterator = loader.read_telescope_events_chunked( chunk_size, telescopes=[telescope_type], From b52be8427321e65cf30e5400ff484359b59d7de5 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Tue, 14 Nov 2023 10:23:40 +0100 Subject: [PATCH 10/13] Add type hints --- ctapipe/tools/utils.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py index be946e6be78..379b6b0bde8 100644 --- a/ctapipe/tools/utils.py +++ b/ctapipe/tools/utils.py @@ -4,14 +4,18 @@ import importlib import sys from collections import OrderedDict +from logging import Logger import numpy as np from astropy.table import vstack from ctapipe.containers import CoordinateFrameType +from ctapipe.core.traits import Int from ctapipe.exceptions import TooFewEvents +from ctapipe.instrument.telescope import TelescopeDescription +from ctapipe.io import TableLoader from ctapipe.reco.preprocessing import check_valid_rows -from ctapipe.reco.sklearn import DispReconstructor +from ctapipe.reco.sklearn import DispReconstructor, SKLearnReconstructor if sys.version_info < (3, 10): from importlib_metadata import distribution @@ -82,13 +86,13 @@ def get_all_descriptions(): def read_training_events( - loader, - chunk_size, - telescope_type, - reconstructor, - feature_names, - logger, - rng, + loader: TableLoader, + chunk_size: Int, + telescope_type: TelescopeDescription, + reconstructor: SKLearnReconstructor, + feature_names: list, + logger: Logger, + rng: np.random.Generator, n_events=None, ): """Chunked loading of events for training ML models""" From f93e2459ce91c9211fae1043a3df5894b2a8bff7 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Tue, 14 Nov 2023 11:24:26 +0100 Subject: [PATCH 11/13] Use relative imports, default logger, and correct type hint --- ctapipe/tools/train_disp_reconstructor.py | 2 +- ctapipe/tools/train_energy_regressor.py | 2 +- ctapipe/tools/train_particle_classifier.py | 4 +- ctapipe/tools/utils.py | 43 ++++++++++++---------- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index 715e42a3cf5..72a551e8597 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -133,8 +133,8 @@ def start(self): tel_type, self.models, feature_names, - self.log, self.rng, + self.log, self.n_events.tel[tel_type], ) table[self.models.target] = self._get_true_disp(table) diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py index 8aa9fd753da..9db47e1161e 100644 --- a/ctapipe/tools/train_energy_regressor.py +++ b/ctapipe/tools/train_energy_regressor.py @@ -117,8 +117,8 @@ def start(self): tel_type, self.regressor, feature_names, - self.log, self.rng, + self.log, self.n_events.tel[tel_type], ) diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index aa56d8ec723..59656d8441c 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -181,8 +181,8 @@ def _read_input_data(self, tel_type): tel_type, self.classifier, feature_names, - self.log, self.rng, + self.log, self.n_signal.tel[tel_type], ) background = read_training_events( @@ -191,8 +191,8 @@ def _read_input_data(self, tel_type): tel_type, self.classifier, feature_names, - self.log, self.rng, + self.log, self.n_signal.tel[tel_type], ) table = vstack([signal, background]) diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py index 379b6b0bde8..9e7c76eb837 100644 --- a/ctapipe/tools/utils.py +++ b/ctapipe/tools/utils.py @@ -2,20 +2,23 @@ """Utils to create scripts and command-line tools""" import argparse import importlib +import logging import sys from collections import OrderedDict -from logging import Logger +from typing import Type import numpy as np from astropy.table import vstack -from ctapipe.containers import CoordinateFrameType -from ctapipe.core.traits import Int -from ctapipe.exceptions import TooFewEvents -from ctapipe.instrument.telescope import TelescopeDescription -from ctapipe.io import TableLoader -from ctapipe.reco.preprocessing import check_valid_rows -from ctapipe.reco.sklearn import DispReconstructor, SKLearnReconstructor +from ..containers import CoordinateFrameType +from ..core.traits import Int +from ..exceptions import TooFewEvents +from ..instrument.telescope import TelescopeDescription +from ..io import TableLoader +from ..reco.preprocessing import check_valid_rows +from ..reco.sklearn import DispReconstructor, SKLearnReconstructor + +LOG = logging.getLogger(__name__) if sys.version_info < (3, 10): from importlib_metadata import distribution @@ -89,10 +92,10 @@ def read_training_events( loader: TableLoader, chunk_size: Int, telescope_type: TelescopeDescription, - reconstructor: SKLearnReconstructor, + reconstructor: Type[SKLearnReconstructor], feature_names: list, - logger: Logger, rng: np.random.Generator, + log=LOG, n_events=None, ): """Chunked loading of events for training ML models""" @@ -106,7 +109,7 @@ def read_training_events( n_non_predictable = 0 for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): - logger.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) + log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) n_events_in_file += len(table_chunk) if isinstance(reconstructor, DispReconstructor): @@ -115,12 +118,13 @@ def read_training_events( == CoordinateFrameType.ALTAZ.value ): raise ValueError( - "Pointing information for training data has to be provided in horizontal coordinates" + "Pointing information for training data" + " has to be provided in horizontal coordinates" ) mask = reconstructor.quality_query.get_table_mask(table_chunk) table_chunk = table_chunk[mask] - logger.debug( + log.debug( "Events in chunk %d after applying quality_query: %d", chunk, len(table_chunk), @@ -140,8 +144,8 @@ def read_training_events( table.append(table_chunk) table = vstack(table) - logger.info("Events read from input: %d", n_events_in_file) - logger.info("Events after applying quality query: %d", n_valid_events_in_file) + log.info("Events read from input: %d", n_events_in_file) + log.info("Events after applying quality query: %d", n_valid_events_in_file) if len(table) == 0: raise TooFewEvents( @@ -149,17 +153,18 @@ def read_training_events( ) if n_non_predictable > 0: - logger.warning("Dropping %d non-predictable events.", n_non_predictable) + log.warning("Dropping %d non-predictable events.", n_non_predictable) if n_events is not None: if n_events > len(table): - logger.warning( - "Number of events in table (%d) is less than requested number of events %d", + log.warning( + "Number of events in table (%d) is less" + " than requested number of events %d", len(table), n_events, ) else: - logger.info("Sampling %d events", n_events) + log.info("Sampling %d events", n_events) idx = rng.choice(len(table), n_events, replace=False) idx.sort() table = table[idx] From b61576e2ac2104880355f4590e1b236d98d2b999 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Tue, 14 Nov 2023 12:15:23 +0100 Subject: [PATCH 12/13] Use keyword arguments --- ctapipe/tools/train_disp_reconstructor.py | 16 +++++------ ctapipe/tools/train_energy_regressor.py | 16 +++++------ ctapipe/tools/train_particle_classifier.py | 32 +++++++++++----------- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index 72a551e8597..a188b4260d0 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -128,14 +128,14 @@ def start(self): "hillas_psi", ] table = read_training_events( - self.loader, - self.chunk_size, - tel_type, - self.models, - feature_names, - self.rng, - self.log, - self.n_events.tel[tel_type], + loader=self.loader, + chunk_size=self.chunk_size, + telescope_type=tel_type, + reconstructor=self.models, + feature_names=feature_names, + rng=self.rng, + log=self.log, + n_events=self.n_events.tel[tel_type], ) table[self.models.target] = self._get_true_disp(table) table = table[self.models.features + [self.models.target, "true_energy"]] diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py index 9db47e1161e..2c9d411f594 100644 --- a/ctapipe/tools/train_energy_regressor.py +++ b/ctapipe/tools/train_energy_regressor.py @@ -112,14 +112,14 @@ def start(self): self.log.info("Loading events for %s", tel_type) feature_names = self.regressor.features + [self.regressor.target] table = read_training_events( - self.loader, - self.chunk_size, - tel_type, - self.regressor, - feature_names, - self.rng, - self.log, - self.n_events.tel[tel_type], + loader=self.loader, + chunk_size=self.chunk_size, + telescope_type=tel_type, + reconstructor=self.regressor, + feature_names=feature_names, + rng=self.rng, + log=self.log, + n_events=self.n_events.tel[tel_type], ) self.log.info("Train on %s events", len(table)) diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index 59656d8441c..90324b5d5b8 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -176,24 +176,24 @@ def _read_input_data(self, tel_type): "true_energy", ] signal = read_training_events( - self.signal_loader, - self.chunk_size, - tel_type, - self.classifier, - feature_names, - self.rng, - self.log, - self.n_signal.tel[tel_type], + loader=self.signal_loader, + chunk_size=self.chunk_size, + telescope_type=tel_type, + reconstructor=self.classifier, + feature_names=feature_names, + rng=self.rng, + log=self.log, + n_events=self.n_signal.tel[tel_type], ) background = read_training_events( - self.background_loader, - self.chunk_size, - tel_type, - self.classifier, - feature_names, - self.rng, - self.log, - self.n_signal.tel[tel_type], + loader=self.background_loader, + chunk_size=self.chunk_size, + telescope_type=tel_type, + reconstructor=self.classifier, + feature_names=feature_names, + rng=self.rng, + log=self.log, + n_events=self.n_signal.tel[tel_type], ) table = vstack([signal, background]) self.log.info( From bb01357995f58fc742ead3162eb2e1e10f98e087 Mon Sep 17 00:00:00 2001 From: Lukas Beiske <lukas.beiske@tu-dortmund.de> Date: Tue, 14 Nov 2023 14:08:59 +0100 Subject: [PATCH 13/13] Use n_background for background events --- ctapipe/tools/train_particle_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index 90324b5d5b8..70c53b63419 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -193,7 +193,7 @@ def _read_input_data(self, tel_type): feature_names=feature_names, rng=self.rng, log=self.log, - n_events=self.n_signal.tel[tel_type], + n_events=self.n_background.tel[tel_type], ) table = vstack([signal, background]) self.log.info(