From bb7cc84150e4a60cf0a3681aad93798c29f0c404 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Wed, 25 Oct 2023 16:19:19 +0200
Subject: [PATCH 01/13] Load training data in chunks when training an energy
 regressor

---
 ctapipe/tools/train_energy_regressor.py | 72 ++++++++++++++++++-------
 1 file changed, 53 insertions(+), 19 deletions(-)

diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py
index 89647f6c39b..e0f412f169b 100644
--- a/ctapipe/tools/train_energy_regressor.py
+++ b/ctapipe/tools/train_energy_regressor.py
@@ -2,6 +2,8 @@
 Tool for training the EnergyRegressor
 """
 import numpy as np
+from astropy.table import vstack
+from tqdm.auto import tqdm
 
 from ctapipe.core import Tool
 from ctapipe.core.traits import Int, IntTelescopeParameter, Path
@@ -53,6 +55,12 @@ class TrainEnergyRegressor(Tool):
         ),
     ).tag(config=True)
 
+    chunk_size = Int(
+        default_value=100000,
+        allow_none=True,
+        help="How many subarray events to load at once before training on n_events.",
+    ).tag(config=True)
+
     random_seed = Int(
         default_value=0, help="Random seed for sampling and cross validation"
     ).tag(config=True)
@@ -61,6 +69,7 @@ class TrainEnergyRegressor(Tool):
         ("i", "input"): "TableLoader.input_url",
         ("o", "output"): "TrainEnergyRegressor.output_path",
         "n-events": "TrainEnergyRegressor.n_events",
+        "chunk-size": "TrainEnergyRegressor.chunk_size",
         "cv-output": "CrossValidator.output_path",
     }
 
@@ -113,31 +122,56 @@ def start(self):
             self.log.info("done")
 
     def _read_table(self, telescope_type):
-        table = self.loader.read_telescope_events([telescope_type])
-        self.log.info("Events read from input: %d", len(table))
-        if len(table) == 0:
-            raise TooFewEvents(
-                f"Input file does not contain any events for telescope type {telescope_type}"
-            )
+        chunk_iterator = self.loader.read_telescope_events_chunked(
+            self.chunk_size,
+            telescopes=[telescope_type],
+        )
+        bar = tqdm(
+            chunk_iterator,
+            desc=f"Loading training events for {telescope_type}",
+            unit=" Array Events",
+            total=chunk_iterator.n_total,
+        )
+        table = []
+        n_events_in_file = 0
+        n_valid_events_in_file = 0
+
+        with bar:
+            for chunk, (start, stop, table_chunk) in enumerate(chunk_iterator):
+                self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
+                n_events_in_file += len(table_chunk)
+
+                mask = self.regressor.quality_query.get_table_mask(table_chunk)
+                table_chunk = table_chunk[mask]
+                self.log.debug(
+                    "Events in chunk %d after applying quality_query: %d",
+                    chunk,
+                    len(table_chunk),
+                )
+                n_valid_events_in_file += len(table_chunk)
+
+                table_chunk = self.regressor.feature_generator(
+                    table_chunk, subarray=self.loader.subarray
+                )
+                feature_names = self.regressor.features + [self.regressor.target]
+                table_chunk = table_chunk[feature_names]
+
+                valid = check_valid_rows(table_chunk)
+                if np.any(~valid):
+                    self.log.warning("Dropping non-predictable events.")
+                    table_chunk = table_chunk[valid]
 
-        mask = self.regressor.quality_query.get_table_mask(table)
-        table = table[mask]
-        self.log.info("Events after applying quality query: %d", len(table))
+                table.append(table_chunk)
+                bar.update(stop - start)
+
+        table = vstack(table)
+        self.log.info("Events read from input: %d", n_events_in_file)
+        self.log.info("Events after applying quality query: %d", n_valid_events_in_file)
         if len(table) == 0:
             raise TooFewEvents(
                 f"No events after quality query for telescope type {telescope_type}"
             )
 
-        table = self.regressor.feature_generator(table, subarray=self.loader.subarray)
-
-        feature_names = self.regressor.features + [self.regressor.target]
-        table = table[feature_names]
-
-        valid = check_valid_rows(table)
-        if np.any(~valid):
-            self.log.warning("Dropping non-predictable events.")
-            table = table[valid]
-
         n_events = self.n_events.tel[telescope_type]
         if n_events is not None:
             if n_events > len(table):

From 08495766a277441e23a6f1efbf6b84c1382076fe Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Wed, 25 Oct 2023 17:27:03 +0200
Subject: [PATCH 02/13] Check event validity after merging chunks and make it
 faster

---
 ctapipe/tools/train_disp_reconstructor.py  |  2 +-
 ctapipe/tools/train_energy_regressor.py    | 12 ++++++------
 ctapipe/tools/train_particle_classifier.py |  2 +-
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py
index 80352a0bb29..1d53c588ae3 100644
--- a/ctapipe/tools/train_disp_reconstructor.py
+++ b/ctapipe/tools/train_disp_reconstructor.py
@@ -152,7 +152,7 @@ def _read_table(self, telescope_type):
         table = table[columns]
 
         valid = check_valid_rows(table)
-        if np.any(~valid):
+        if not np.all(valid):
             self.log.warning("Dropping non-predicable events.")
             table = table[valid]
 
diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py
index e0f412f169b..11ad2ce548d 100644
--- a/ctapipe/tools/train_energy_regressor.py
+++ b/ctapipe/tools/train_energy_regressor.py
@@ -129,7 +129,7 @@ def _read_table(self, telescope_type):
         bar = tqdm(
             chunk_iterator,
             desc=f"Loading training events for {telescope_type}",
-            unit=" Array Events",
+            unit=" Telescope Events",
             total=chunk_iterator.n_total,
         )
         table = []
@@ -156,11 +156,6 @@ def _read_table(self, telescope_type):
                 feature_names = self.regressor.features + [self.regressor.target]
                 table_chunk = table_chunk[feature_names]
 
-                valid = check_valid_rows(table_chunk)
-                if np.any(~valid):
-                    self.log.warning("Dropping non-predictable events.")
-                    table_chunk = table_chunk[valid]
-
                 table.append(table_chunk)
                 bar.update(stop - start)
 
@@ -172,6 +167,11 @@ def _read_table(self, telescope_type):
                 f"No events after quality query for telescope type {telescope_type}"
             )
 
+        valid = check_valid_rows(table)
+        if not np.all(valid):
+            self.log.warning("Dropping non-predictable events.")
+            table = table[valid]
+
         n_events = self.n_events.tel[telescope_type]
         if n_events is not None:
             if n_events > len(table):
diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py
index b8511c6fd1e..21c5e21fdd3 100644
--- a/ctapipe/tools/train_particle_classifier.py
+++ b/ctapipe/tools/train_particle_classifier.py
@@ -184,7 +184,7 @@ def _read_table(self, telescope_type, loader, n_events=None):
         table = table[columns]
 
         valid = check_valid_rows(table)
-        if np.any(~valid):
+        if not np.all(valid):
             self.log.warning("Dropping non-predictable events.")
             table = table[valid]
 

From 8b60cf5799fdea790e8d674dbe3618e2c4e6ff16 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Thu, 26 Oct 2023 15:27:22 +0200
Subject: [PATCH 03/13] Do all event validation in chunk loop; remove loading
 bar

- Since the total number of telescope events in a file is not known when
  loading in chunks (only the total number of subarray events),
  a progress bar for loading does not make sense.
---
 ctapipe/tools/train_energy_regressor.py | 58 ++++++++++++-------------
 1 file changed, 27 insertions(+), 31 deletions(-)

diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py
index 11ad2ce548d..3032e9237b3 100644
--- a/ctapipe/tools/train_energy_regressor.py
+++ b/ctapipe/tools/train_energy_regressor.py
@@ -3,7 +3,6 @@
 """
 import numpy as np
 from astropy.table import vstack
-from tqdm.auto import tqdm
 
 from ctapipe.core import Tool
 from ctapipe.core.traits import Int, IntTelescopeParameter, Path
@@ -126,51 +125,48 @@ def _read_table(self, telescope_type):
             self.chunk_size,
             telescopes=[telescope_type],
         )
-        bar = tqdm(
-            chunk_iterator,
-            desc=f"Loading training events for {telescope_type}",
-            unit=" Telescope Events",
-            total=chunk_iterator.n_total,
-        )
         table = []
         n_events_in_file = 0
         n_valid_events_in_file = 0
+        n_non_predictable = 0
+
+        for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
+            self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
+            n_events_in_file += len(table_chunk)
+
+            mask = self.regressor.quality_query.get_table_mask(table_chunk)
+            table_chunk = table_chunk[mask]
+            self.log.debug(
+                "Events in chunk %d after applying quality_query: %d",
+                chunk,
+                len(table_chunk),
+            )
+            n_valid_events_in_file += len(table_chunk)
 
-        with bar:
-            for chunk, (start, stop, table_chunk) in enumerate(chunk_iterator):
-                self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
-                n_events_in_file += len(table_chunk)
-
-                mask = self.regressor.quality_query.get_table_mask(table_chunk)
-                table_chunk = table_chunk[mask]
-                self.log.debug(
-                    "Events in chunk %d after applying quality_query: %d",
-                    chunk,
-                    len(table_chunk),
-                )
-                n_valid_events_in_file += len(table_chunk)
+            table_chunk = self.regressor.feature_generator(
+                table_chunk, subarray=self.loader.subarray
+            )
+            feature_names = self.regressor.features + [self.regressor.target]
+            table_chunk = table_chunk[feature_names]
 
-                table_chunk = self.regressor.feature_generator(
-                    table_chunk, subarray=self.loader.subarray
-                )
-                feature_names = self.regressor.features + [self.regressor.target]
-                table_chunk = table_chunk[feature_names]
+            valid = check_valid_rows(table_chunk)
+            if not np.all(valid):
+                n_non_predictable += np.sum(valid)
+                table_chunk = table_chunk[valid]
 
-                table.append(table_chunk)
-                bar.update(stop - start)
+            table.append(table_chunk)
 
         table = vstack(table)
         self.log.info("Events read from input: %d", n_events_in_file)
         self.log.info("Events after applying quality query: %d", n_valid_events_in_file)
+
         if len(table) == 0:
             raise TooFewEvents(
                 f"No events after quality query for telescope type {telescope_type}"
             )
 
-        valid = check_valid_rows(table)
-        if not np.all(valid):
-            self.log.warning("Dropping non-predictable events.")
-            table = table[valid]
+        if n_non_predictable > 0:
+            self.log.warning("Dropping %d non-predictable events.", n_non_predictable)
 
         n_events = self.n_events.tel[telescope_type]
         if n_events is not None:

From 88f00b501ebcce6d7a7447018f38c60a4104e892 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Thu, 26 Oct 2023 16:14:35 +0200
Subject: [PATCH 04/13] Chunked loading for training particle clf and training
 disp reco

---
 ctapipe/tools/train_disp_reconstructor.py  | 80 +++++++++++++++-------
 ctapipe/tools/train_particle_classifier.py | 67 +++++++++++++-----
 2 files changed, 103 insertions(+), 44 deletions(-)

diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py
index 1d53c588ae3..7ba282c9728 100644
--- a/ctapipe/tools/train_disp_reconstructor.py
+++ b/ctapipe/tools/train_disp_reconstructor.py
@@ -3,6 +3,7 @@
 """
 import astropy.units as u
 import numpy as np
+from astropy.table import vstack
 
 from ctapipe.containers import CoordinateFrameType
 from ctapipe.core import Tool
@@ -56,6 +57,12 @@ class TrainDispReconstructor(Tool):
         ),
     ).tag(config=True)
 
+    chunk_size = Int(
+        default_value=100000,
+        allow_none=True,
+        help="How many subarray events to load at once before training on n_events.",
+    ).tag(config=True)
+
     random_seed = Int(
         default_value=0, help="Random seed for sampling and cross validation"
     ).tag(config=True)
@@ -121,40 +128,61 @@ def start(self):
             self.log.info("done")
 
     def _read_table(self, telescope_type):
-        table = self.loader.read_telescope_events([telescope_type])
-        self.log.info("Events read from input: %d", len(table))
-        if len(table) == 0:
-            raise TooFewEvents(
-                f"Input file does not contain any events for telescope type {telescope_type}"
+        chunk_iterator = self.loader.read_telescope_events_chunked(
+            self.chunk_size,
+            telescopes=[telescope_type],
+        )
+        table = []
+        n_events_in_file = 0
+        n_valid_events_in_file = 0
+        n_non_predictable = 0
+
+        for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
+            self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
+            n_events_in_file += len(table_chunk)
+
+            mask = self.models.quality_query.get_table_mask(table_chunk)
+            table_chunk = table_chunk[mask]
+            self.log.debug(
+                "Events in chunk %d after applying quality_query: %d",
+                chunk,
+                len(table_chunk),
             )
+            n_valid_events_in_file += len(table_chunk)
 
-        mask = self.models.quality_query.get_table_mask(table)
-        table = table[mask]
-        self.log.info("Events after applying quality query: %d", len(table))
-        if len(table) == 0:
-            raise TooFewEvents(
-                f"No events after quality query for telescope type {telescope_type}"
-            )
+            if not np.all(
+                table_chunk["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value
+            ):
+                raise ValueError(
+                    "Pointing information for training data has to be provided in horizontal coordinates"
+                )
 
-        if not np.all(
-            table["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value
-        ):
-            raise ValueError(
-                "Pointing information for training data has to be provided in horizontal coordinates"
+            table_chunk = self.models.feature_generator(
+                table_chunk, subarray=self.loader.subarray
             )
+            table_chunk[self.models.target] = self._get_true_disp(table_chunk)
+            # Add true energy for energy-dependent performance plots
+            columns = self.models.features + [self.models.target, "true_energy"]
+            table_chunk = table_chunk[columns]
+
+            valid = check_valid_rows(table_chunk)
+            if not np.all(valid):
+                n_non_predictable += np.sum(valid)
+                table_chunk = table_chunk[valid]
 
-        table = self.models.feature_generator(table, subarray=self.loader.subarray)
+            table.append(table_chunk)
 
-        table[self.models.target] = self._get_true_disp(table)
+        table = vstack(table)
+        self.log.info("Events read from input: %d", n_events_in_file)
+        self.log.info("Events after applying quality query: %d", n_valid_events_in_file)
 
-        # Add true energy for energy-dependent performance plots
-        columns = self.models.features + [self.models.target, "true_energy"]
-        table = table[columns]
+        if len(table) == 0:
+            raise TooFewEvents(
+                f"No events after quality query for telescope type {telescope_type}"
+            )
 
-        valid = check_valid_rows(table)
-        if not np.all(valid):
-            self.log.warning("Dropping non-predicable events.")
-            table = table[valid]
+        if n_non_predictable > 0:
+            self.log.warning("Dropping %d non-predictable events.", n_non_predictable)
 
         n_events = self.n_events.tel[telescope_type]
         if n_events is not None:
diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py
index 21c5e21fdd3..d3d1d8074db 100644
--- a/ctapipe/tools/train_particle_classifier.py
+++ b/ctapipe/tools/train_particle_classifier.py
@@ -78,6 +78,15 @@ class TrainParticleClassifier(Tool):
         ),
     ).tag(config=True)
 
+    chunk_size = Int(
+        default_value=100000,
+        allow_none=True,
+        help=(
+            "How many subarray events to load at once before training on"
+            " n_signal and n_background events."
+        ),
+    ).tag(config=True)
+
     random_seed = Int(
         default_value=0,
         help="Random number seed for sampling and the cross validation splitting",
@@ -162,31 +171,53 @@ def start(self):
             self.log.info("done")
 
     def _read_table(self, telescope_type, loader, n_events=None):
-        table = loader.read_telescope_events([telescope_type])
-        self.log.info("Events read from input: %d", len(table))
-        if len(table) == 0:
-            raise TooFewEvents(
-                f"Input file does not contain any events for telescope type {telescope_type}"
+        chunk_iterator = loader.read_telescope_events_chunked(
+            self.chunk_size,
+            telescopes=[telescope_type],
+        )
+        table = []
+        n_events_in_file = 0
+        n_valid_events_in_file = 0
+        n_non_predictable = 0
+
+        for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
+            self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
+            n_events_in_file += len(table_chunk)
+
+            mask = self.classifier.quality_query.get_table_mask(table_chunk)
+            table_chunk = table_chunk[mask]
+            self.log.debug(
+                "Events in chunk %d after applying quality_query: %d",
+                chunk,
+                len(table_chunk),
             )
+            n_valid_events_in_file += len(table_chunk)
+
+            table_chunk = self.classifier.feature_generator(
+                table_chunk, subarray=self.subarray
+            )
+            # Add true energy for energy-dependent performance plots
+            columns = self.classifier.features + [self.classifier.target, "true_energy"]
+            table_chunk = table_chunk[columns]
+
+            valid = check_valid_rows(table_chunk)
+            if not np.all(valid):
+                n_non_predictable += np.sum(valid)
+                table_chunk = table_chunk[valid]
+
+            table.append(table_chunk)
+
+        table = vstack(table)
+        self.log.info("Events read from input: %d", n_events_in_file)
+        self.log.info("Events after applying quality query: %d", n_valid_events_in_file)
 
-        mask = self.classifier.quality_query.get_table_mask(table)
-        table = table[mask]
-        self.log.info("Events after applying quality query: %d", len(table))
         if len(table) == 0:
             raise TooFewEvents(
                 f"No events after quality query for telescope type {telescope_type}"
             )
 
-        table = self.classifier.feature_generator(table, subarray=self.subarray)
-
-        # Add true energy for energy-dependent performance plots
-        columns = self.classifier.features + [self.classifier.target, "true_energy"]
-        table = table[columns]
-
-        valid = check_valid_rows(table)
-        if not np.all(valid):
-            self.log.warning("Dropping non-predictable events.")
-            table = table[valid]
+        if n_non_predictable > 0:
+            self.log.warning("Dropping %d non-predictable events.", n_non_predictable)
 
         if n_events is not None:
             if n_events > len(table):

From 525e3ba9a47cc7b0b32ec1ee566aa115806a0b26 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Thu, 26 Oct 2023 16:14:49 +0200
Subject: [PATCH 05/13] Add changelog

---
 docs/changes/2423.optimization.rst | 3 +++
 1 file changed, 3 insertions(+)
 create mode 100644 docs/changes/2423.optimization.rst

diff --git a/docs/changes/2423.optimization.rst b/docs/changes/2423.optimization.rst
new file mode 100644
index 00000000000..b6e1567767a
--- /dev/null
+++ b/docs/changes/2423.optimization.rst
@@ -0,0 +1,3 @@
+Load data and apply event and column selection in chunks in ``ctapipe-train-*``
+before merging afterwards.
+This reduces memory usage.

From 4ac7709a2139f9bc1fbb01bc28ce328428dcceba Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Mon, 6 Nov 2023 13:02:26 +0100
Subject: [PATCH 06/13] Fix line too long

---
 ctapipe/tools/train_disp_reconstructor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py
index 7ba282c9728..ec69a3d13b8 100644
--- a/ctapipe/tools/train_disp_reconstructor.py
+++ b/ctapipe/tools/train_disp_reconstructor.py
@@ -151,7 +151,8 @@ def _read_table(self, telescope_type):
             n_valid_events_in_file += len(table_chunk)
 
             if not np.all(
-                table_chunk["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value
+                table_chunk["subarray_pointing_frame"]
+                == CoordinateFrameType.ALTAZ.value
             ):
                 raise ValueError(
                     "Pointing information for training data has to be provided in horizontal coordinates"

From 3fae0b0af16ea8da5033611c4090d587e8ce0dff Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Mon, 6 Nov 2023 14:16:59 +0100
Subject: [PATCH 07/13] Count invalid events as non-predictable, not valid
 events

---
 ctapipe/tools/train_disp_reconstructor.py  | 2 +-
 ctapipe/tools/train_energy_regressor.py    | 2 +-
 ctapipe/tools/train_particle_classifier.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py
index ec69a3d13b8..6b3ad6d0314 100644
--- a/ctapipe/tools/train_disp_reconstructor.py
+++ b/ctapipe/tools/train_disp_reconstructor.py
@@ -168,7 +168,7 @@ def _read_table(self, telescope_type):
 
             valid = check_valid_rows(table_chunk)
             if not np.all(valid):
-                n_non_predictable += np.sum(valid)
+                n_non_predictable += np.sum(~valid)
                 table_chunk = table_chunk[valid]
 
             table.append(table_chunk)
diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py
index 3032e9237b3..7fe7986b380 100644
--- a/ctapipe/tools/train_energy_regressor.py
+++ b/ctapipe/tools/train_energy_regressor.py
@@ -151,7 +151,7 @@ def _read_table(self, telescope_type):
 
             valid = check_valid_rows(table_chunk)
             if not np.all(valid):
-                n_non_predictable += np.sum(valid)
+                n_non_predictable += np.sum(~valid)
                 table_chunk = table_chunk[valid]
 
             table.append(table_chunk)
diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py
index d3d1d8074db..76865cd5366 100644
--- a/ctapipe/tools/train_particle_classifier.py
+++ b/ctapipe/tools/train_particle_classifier.py
@@ -202,7 +202,7 @@ def _read_table(self, telescope_type, loader, n_events=None):
 
             valid = check_valid_rows(table_chunk)
             if not np.all(valid):
-                n_non_predictable += np.sum(valid)
+                n_non_predictable += np.sum(~valid)
                 table_chunk = table_chunk[valid]
 
             table.append(table_chunk)

From 8378645763482fac1f0239cc81fabc68eb0eaa75 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Tue, 7 Nov 2023 10:42:08 +0100
Subject: [PATCH 08/13] Refactor reading training data into standalone funtion

---
 ctapipe/tools/train_disp_reconstructor.py  | 104 +++++----------------
 ctapipe/tools/train_energy_regressor.py    |  81 +++-------------
 ctapipe/tools/train_particle_classifier.py |  94 +++++--------------
 ctapipe/tools/utils.py                     |  89 ++++++++++++++++++
 4 files changed, 151 insertions(+), 217 deletions(-)

diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py
index 6b3ad6d0314..715e42a3cf5 100644
--- a/ctapipe/tools/train_disp_reconstructor.py
+++ b/ctapipe/tools/train_disp_reconstructor.py
@@ -3,15 +3,14 @@
 """
 import astropy.units as u
 import numpy as np
-from astropy.table import vstack
 
-from ctapipe.containers import CoordinateFrameType
 from ctapipe.core import Tool
 from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path
-from ctapipe.exceptions import TooFewEvents
 from ctapipe.io import TableLoader
 from ctapipe.reco import CrossValidator, DispReconstructor
-from ctapipe.reco.preprocessing import check_valid_rows, horizontal_to_telescope
+from ctapipe.reco.preprocessing import horizontal_to_telescope
+
+from .utils import read_training_events
 
 __all__ = [
     "TrainDispReconstructor",
@@ -118,7 +117,28 @@ def start(self):
         self.log.info("Training models for %d types", len(types))
         for tel_type in types:
             self.log.info("Loading events for %s", tel_type)
-            table = self._read_table(tel_type)
+            feature_names = self.models.features + [
+                "true_energy",
+                "subarray_pointing_lat",
+                "subarray_pointing_lon",
+                "true_alt",
+                "true_az",
+                "hillas_fov_lat",
+                "hillas_fov_lon",
+                "hillas_psi",
+            ]
+            table = read_training_events(
+                self.loader,
+                self.chunk_size,
+                tel_type,
+                self.models,
+                feature_names,
+                self.log,
+                self.rng,
+                self.n_events.tel[tel_type],
+            )
+            table[self.models.target] = self._get_true_disp(table)
+            table = table[self.models.features + [self.models.target, "true_energy"]]
 
             self.log.info("Train models on %s events", len(table))
             self.cross_validate(tel_type, table)
@@ -127,80 +147,6 @@ def start(self):
             self.models.fit(tel_type, table)
             self.log.info("done")
 
-    def _read_table(self, telescope_type):
-        chunk_iterator = self.loader.read_telescope_events_chunked(
-            self.chunk_size,
-            telescopes=[telescope_type],
-        )
-        table = []
-        n_events_in_file = 0
-        n_valid_events_in_file = 0
-        n_non_predictable = 0
-
-        for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
-            self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
-            n_events_in_file += len(table_chunk)
-
-            mask = self.models.quality_query.get_table_mask(table_chunk)
-            table_chunk = table_chunk[mask]
-            self.log.debug(
-                "Events in chunk %d after applying quality_query: %d",
-                chunk,
-                len(table_chunk),
-            )
-            n_valid_events_in_file += len(table_chunk)
-
-            if not np.all(
-                table_chunk["subarray_pointing_frame"]
-                == CoordinateFrameType.ALTAZ.value
-            ):
-                raise ValueError(
-                    "Pointing information for training data has to be provided in horizontal coordinates"
-                )
-
-            table_chunk = self.models.feature_generator(
-                table_chunk, subarray=self.loader.subarray
-            )
-            table_chunk[self.models.target] = self._get_true_disp(table_chunk)
-            # Add true energy for energy-dependent performance plots
-            columns = self.models.features + [self.models.target, "true_energy"]
-            table_chunk = table_chunk[columns]
-
-            valid = check_valid_rows(table_chunk)
-            if not np.all(valid):
-                n_non_predictable += np.sum(~valid)
-                table_chunk = table_chunk[valid]
-
-            table.append(table_chunk)
-
-        table = vstack(table)
-        self.log.info("Events read from input: %d", n_events_in_file)
-        self.log.info("Events after applying quality query: %d", n_valid_events_in_file)
-
-        if len(table) == 0:
-            raise TooFewEvents(
-                f"No events after quality query for telescope type {telescope_type}"
-            )
-
-        if n_non_predictable > 0:
-            self.log.warning("Dropping %d non-predictable events.", n_non_predictable)
-
-        n_events = self.n_events.tel[telescope_type]
-        if n_events is not None:
-            if n_events > len(table):
-                self.log.warning(
-                    "Number of events in table (%d) is less than requested number of events %d",
-                    len(table),
-                    n_events,
-                )
-            else:
-                self.log.info("Sampling %d events", n_events)
-                idx = self.rng.choice(len(table), n_events, replace=False)
-                idx.sort()
-                table = table[idx]
-
-        return table
-
     def _get_true_disp(self, table):
         fov_lon, fov_lat = horizontal_to_telescope(
             alt=table["true_alt"],
diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py
index 7fe7986b380..8aa9fd753da 100644
--- a/ctapipe/tools/train_energy_regressor.py
+++ b/ctapipe/tools/train_energy_regressor.py
@@ -2,14 +2,13 @@
 Tool for training the EnergyRegressor
 """
 import numpy as np
-from astropy.table import vstack
 
 from ctapipe.core import Tool
 from ctapipe.core.traits import Int, IntTelescopeParameter, Path
-from ctapipe.exceptions import TooFewEvents
 from ctapipe.io import TableLoader
 from ctapipe.reco import CrossValidator, EnergyRegressor
-from ctapipe.reco.preprocessing import check_valid_rows
+
+from .utils import read_training_events
 
 __all__ = [
     "TrainEnergyRegressor",
@@ -111,7 +110,17 @@ def start(self):
         self.log.info("Training models for %d types", len(types))
         for tel_type in types:
             self.log.info("Loading events for %s", tel_type)
-            table = self._read_table(tel_type)
+            feature_names = self.regressor.features + [self.regressor.target]
+            table = read_training_events(
+                self.loader,
+                self.chunk_size,
+                tel_type,
+                self.regressor,
+                feature_names,
+                self.log,
+                self.rng,
+                self.n_events.tel[tel_type],
+            )
 
             self.log.info("Train on %s events", len(table))
             self.cross_validate(tel_type, table)
@@ -120,70 +129,6 @@ def start(self):
             self.regressor.fit(tel_type, table)
             self.log.info("done")
 
-    def _read_table(self, telescope_type):
-        chunk_iterator = self.loader.read_telescope_events_chunked(
-            self.chunk_size,
-            telescopes=[telescope_type],
-        )
-        table = []
-        n_events_in_file = 0
-        n_valid_events_in_file = 0
-        n_non_predictable = 0
-
-        for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
-            self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
-            n_events_in_file += len(table_chunk)
-
-            mask = self.regressor.quality_query.get_table_mask(table_chunk)
-            table_chunk = table_chunk[mask]
-            self.log.debug(
-                "Events in chunk %d after applying quality_query: %d",
-                chunk,
-                len(table_chunk),
-            )
-            n_valid_events_in_file += len(table_chunk)
-
-            table_chunk = self.regressor.feature_generator(
-                table_chunk, subarray=self.loader.subarray
-            )
-            feature_names = self.regressor.features + [self.regressor.target]
-            table_chunk = table_chunk[feature_names]
-
-            valid = check_valid_rows(table_chunk)
-            if not np.all(valid):
-                n_non_predictable += np.sum(~valid)
-                table_chunk = table_chunk[valid]
-
-            table.append(table_chunk)
-
-        table = vstack(table)
-        self.log.info("Events read from input: %d", n_events_in_file)
-        self.log.info("Events after applying quality query: %d", n_valid_events_in_file)
-
-        if len(table) == 0:
-            raise TooFewEvents(
-                f"No events after quality query for telescope type {telescope_type}"
-            )
-
-        if n_non_predictable > 0:
-            self.log.warning("Dropping %d non-predictable events.", n_non_predictable)
-
-        n_events = self.n_events.tel[telescope_type]
-        if n_events is not None:
-            if n_events > len(table):
-                self.log.warning(
-                    "Number of events in table (%d) is less than requested number of events %d",
-                    len(table),
-                    n_events,
-                )
-            else:
-                self.log.info("Sampling %d events", n_events)
-                idx = self.rng.choice(len(table), n_events, replace=False)
-                idx.sort()
-                table = table[idx]
-
-        return table
-
     def finish(self):
         """
         Write-out trained models and cross-validation results.
diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py
index 76865cd5366..aa56d8ec723 100644
--- a/ctapipe/tools/train_particle_classifier.py
+++ b/ctapipe/tools/train_particle_classifier.py
@@ -6,10 +6,10 @@
 
 from ctapipe.core.tool import Tool
 from ctapipe.core.traits import Int, IntTelescopeParameter, Path
-from ctapipe.exceptions import TooFewEvents
 from ctapipe.io import TableLoader
 from ctapipe.reco import CrossValidator, ParticleClassifier
-from ctapipe.reco.preprocessing import check_valid_rows
+
+from .utils import read_training_events
 
 __all__ = [
     "TrainParticleClassifier",
@@ -170,76 +170,30 @@ def start(self):
             self.classifier.fit(tel_type, table)
             self.log.info("done")
 
-    def _read_table(self, telescope_type, loader, n_events=None):
-        chunk_iterator = loader.read_telescope_events_chunked(
-            self.chunk_size,
-            telescopes=[telescope_type],
-        )
-        table = []
-        n_events_in_file = 0
-        n_valid_events_in_file = 0
-        n_non_predictable = 0
-
-        for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
-            self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
-            n_events_in_file += len(table_chunk)
-
-            mask = self.classifier.quality_query.get_table_mask(table_chunk)
-            table_chunk = table_chunk[mask]
-            self.log.debug(
-                "Events in chunk %d after applying quality_query: %d",
-                chunk,
-                len(table_chunk),
-            )
-            n_valid_events_in_file += len(table_chunk)
-
-            table_chunk = self.classifier.feature_generator(
-                table_chunk, subarray=self.subarray
-            )
-            # Add true energy for energy-dependent performance plots
-            columns = self.classifier.features + [self.classifier.target, "true_energy"]
-            table_chunk = table_chunk[columns]
-
-            valid = check_valid_rows(table_chunk)
-            if not np.all(valid):
-                n_non_predictable += np.sum(~valid)
-                table_chunk = table_chunk[valid]
-
-            table.append(table_chunk)
-
-        table = vstack(table)
-        self.log.info("Events read from input: %d", n_events_in_file)
-        self.log.info("Events after applying quality query: %d", n_valid_events_in_file)
-
-        if len(table) == 0:
-            raise TooFewEvents(
-                f"No events after quality query for telescope type {telescope_type}"
-            )
-
-        if n_non_predictable > 0:
-            self.log.warning("Dropping %d non-predictable events.", n_non_predictable)
-
-        if n_events is not None:
-            if n_events > len(table):
-                self.log.warning(
-                    "Number of events in table (%d) is less than requested number of events %d",
-                    len(table),
-                    n_events,
-                )
-            else:
-                self.log.info("Sampling %d events", n_events)
-                idx = self.rng.choice(len(table), n_events, replace=False)
-                idx.sort()
-                table = table[idx]
-
-        return table
-
     def _read_input_data(self, tel_type):
-        signal = self._read_table(
-            tel_type, self.signal_loader, self.n_signal.tel[tel_type]
+        feature_names = self.classifier.features + [
+            self.classifier.target,
+            "true_energy",
+        ]
+        signal = read_training_events(
+            self.signal_loader,
+            self.chunk_size,
+            tel_type,
+            self.classifier,
+            feature_names,
+            self.log,
+            self.rng,
+            self.n_signal.tel[tel_type],
         )
-        background = self._read_table(
-            tel_type, self.background_loader, self.n_background.tel[tel_type]
+        background = read_training_events(
+            self.background_loader,
+            self.chunk_size,
+            tel_type,
+            self.classifier,
+            feature_names,
+            self.log,
+            self.rng,
+            self.n_signal.tel[tel_type],
         )
         table = vstack([signal, background])
         self.log.info(
diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py
index 131046fbc1c..b2bd0223953 100644
--- a/ctapipe/tools/utils.py
+++ b/ctapipe/tools/utils.py
@@ -5,6 +5,14 @@
 import sys
 from collections import OrderedDict
 
+import numpy as np
+from astropy.table import vstack
+
+from ctapipe.containers import CoordinateFrameType
+from ctapipe.exceptions import TooFewEvents
+from ctapipe.reco.preprocessing import check_valid_rows
+from ctapipe.reco.sklearn import DispReconstructor
+
 if sys.version_info < (3, 10):
     from importlib_metadata import distribution
 else:
@@ -71,3 +79,84 @@ def get_all_descriptions():
             descriptions[name] = "[no documentation. Please add a docstring]"
 
     return descriptions
+
+
+def read_training_events(
+    loader,
+    chunk_size,
+    telescope_type,
+    reconstructor,
+    feature_names,
+    logger,
+    rng,
+    n_events=None,
+):
+    chunk_iterator = loader.read_telescope_events_chunked(
+        chunk_size,
+        telescopes=[telescope_type],
+    )
+    table = []
+    n_events_in_file = 0
+    n_valid_events_in_file = 0
+    n_non_predictable = 0
+
+    for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
+        logger.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
+        n_events_in_file += len(table_chunk)
+
+        if isinstance(reconstructor, DispReconstructor):
+            if not np.all(
+                table_chunk["subarray_pointing_frame"]
+                == CoordinateFrameType.ALTAZ.value
+            ):
+                raise ValueError(
+                    "Pointing information for training data has to be provided in horizontal coordinates"
+                )
+
+        mask = reconstructor.quality_query.get_table_mask(table_chunk)
+        table_chunk = table_chunk[mask]
+        logger.debug(
+            "Events in chunk %d after applying quality_query: %d",
+            chunk,
+            len(table_chunk),
+        )
+        n_valid_events_in_file += len(table_chunk)
+
+        table_chunk = reconstructor.feature_generator(
+            table_chunk, subarray=loader.subarray
+        )
+        table_chunk = table_chunk[feature_names]
+
+        valid = check_valid_rows(table_chunk)
+        if not np.all(valid):
+            n_non_predictable += np.sum(~valid)
+            table_chunk = table_chunk[valid]
+
+        table.append(table_chunk)
+
+    table = vstack(table)
+    logger.info("Events read from input: %d", n_events_in_file)
+    logger.info("Events after applying quality query: %d", n_valid_events_in_file)
+
+    if len(table) == 0:
+        raise TooFewEvents(
+            f"No events after quality query for telescope type {telescope_type}"
+        )
+
+    if n_non_predictable > 0:
+        logger.warning("Dropping %d non-predictable events.", n_non_predictable)
+
+    if n_events is not None:
+        if n_events > len(table):
+            logger.warning(
+                "Number of events in table (%d) is less than requested number of events %d",
+                len(table),
+                n_events,
+            )
+        else:
+            logger.info("Sampling %d events", n_events)
+            idx = rng.choice(len(table), n_events, replace=False)
+            idx.sort()
+            table = table[idx]
+
+    return table

From a3891834a884db5b864b56491c4e099ce8d742fa Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Fri, 10 Nov 2023 12:20:59 +0100
Subject: [PATCH 09/13] Add docstring

---
 ctapipe/tools/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py
index b2bd0223953..be946e6be78 100644
--- a/ctapipe/tools/utils.py
+++ b/ctapipe/tools/utils.py
@@ -91,6 +91,7 @@ def read_training_events(
     rng,
     n_events=None,
 ):
+    """Chunked loading of events for training ML models"""
     chunk_iterator = loader.read_telescope_events_chunked(
         chunk_size,
         telescopes=[telescope_type],

From b52be8427321e65cf30e5400ff484359b59d7de5 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Tue, 14 Nov 2023 10:23:40 +0100
Subject: [PATCH 10/13] Add type hints

---
 ctapipe/tools/utils.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py
index be946e6be78..379b6b0bde8 100644
--- a/ctapipe/tools/utils.py
+++ b/ctapipe/tools/utils.py
@@ -4,14 +4,18 @@
 import importlib
 import sys
 from collections import OrderedDict
+from logging import Logger
 
 import numpy as np
 from astropy.table import vstack
 
 from ctapipe.containers import CoordinateFrameType
+from ctapipe.core.traits import Int
 from ctapipe.exceptions import TooFewEvents
+from ctapipe.instrument.telescope import TelescopeDescription
+from ctapipe.io import TableLoader
 from ctapipe.reco.preprocessing import check_valid_rows
-from ctapipe.reco.sklearn import DispReconstructor
+from ctapipe.reco.sklearn import DispReconstructor, SKLearnReconstructor
 
 if sys.version_info < (3, 10):
     from importlib_metadata import distribution
@@ -82,13 +86,13 @@ def get_all_descriptions():
 
 
 def read_training_events(
-    loader,
-    chunk_size,
-    telescope_type,
-    reconstructor,
-    feature_names,
-    logger,
-    rng,
+    loader: TableLoader,
+    chunk_size: Int,
+    telescope_type: TelescopeDescription,
+    reconstructor: SKLearnReconstructor,
+    feature_names: list,
+    logger: Logger,
+    rng: np.random.Generator,
     n_events=None,
 ):
     """Chunked loading of events for training ML models"""

From f93e2459ce91c9211fae1043a3df5894b2a8bff7 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Tue, 14 Nov 2023 11:24:26 +0100
Subject: [PATCH 11/13] Use relative imports, default logger, and correct type
 hint

---
 ctapipe/tools/train_disp_reconstructor.py  |  2 +-
 ctapipe/tools/train_energy_regressor.py    |  2 +-
 ctapipe/tools/train_particle_classifier.py |  4 +-
 ctapipe/tools/utils.py                     | 43 ++++++++++++----------
 4 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py
index 715e42a3cf5..72a551e8597 100644
--- a/ctapipe/tools/train_disp_reconstructor.py
+++ b/ctapipe/tools/train_disp_reconstructor.py
@@ -133,8 +133,8 @@ def start(self):
                 tel_type,
                 self.models,
                 feature_names,
-                self.log,
                 self.rng,
+                self.log,
                 self.n_events.tel[tel_type],
             )
             table[self.models.target] = self._get_true_disp(table)
diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py
index 8aa9fd753da..9db47e1161e 100644
--- a/ctapipe/tools/train_energy_regressor.py
+++ b/ctapipe/tools/train_energy_regressor.py
@@ -117,8 +117,8 @@ def start(self):
                 tel_type,
                 self.regressor,
                 feature_names,
-                self.log,
                 self.rng,
+                self.log,
                 self.n_events.tel[tel_type],
             )
 
diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py
index aa56d8ec723..59656d8441c 100644
--- a/ctapipe/tools/train_particle_classifier.py
+++ b/ctapipe/tools/train_particle_classifier.py
@@ -181,8 +181,8 @@ def _read_input_data(self, tel_type):
             tel_type,
             self.classifier,
             feature_names,
-            self.log,
             self.rng,
+            self.log,
             self.n_signal.tel[tel_type],
         )
         background = read_training_events(
@@ -191,8 +191,8 @@ def _read_input_data(self, tel_type):
             tel_type,
             self.classifier,
             feature_names,
-            self.log,
             self.rng,
+            self.log,
             self.n_signal.tel[tel_type],
         )
         table = vstack([signal, background])
diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py
index 379b6b0bde8..9e7c76eb837 100644
--- a/ctapipe/tools/utils.py
+++ b/ctapipe/tools/utils.py
@@ -2,20 +2,23 @@
 """Utils to create scripts and command-line tools"""
 import argparse
 import importlib
+import logging
 import sys
 from collections import OrderedDict
-from logging import Logger
+from typing import Type
 
 import numpy as np
 from astropy.table import vstack
 
-from ctapipe.containers import CoordinateFrameType
-from ctapipe.core.traits import Int
-from ctapipe.exceptions import TooFewEvents
-from ctapipe.instrument.telescope import TelescopeDescription
-from ctapipe.io import TableLoader
-from ctapipe.reco.preprocessing import check_valid_rows
-from ctapipe.reco.sklearn import DispReconstructor, SKLearnReconstructor
+from ..containers import CoordinateFrameType
+from ..core.traits import Int
+from ..exceptions import TooFewEvents
+from ..instrument.telescope import TelescopeDescription
+from ..io import TableLoader
+from ..reco.preprocessing import check_valid_rows
+from ..reco.sklearn import DispReconstructor, SKLearnReconstructor
+
+LOG = logging.getLogger(__name__)
 
 if sys.version_info < (3, 10):
     from importlib_metadata import distribution
@@ -89,10 +92,10 @@ def read_training_events(
     loader: TableLoader,
     chunk_size: Int,
     telescope_type: TelescopeDescription,
-    reconstructor: SKLearnReconstructor,
+    reconstructor: Type[SKLearnReconstructor],
     feature_names: list,
-    logger: Logger,
     rng: np.random.Generator,
+    log=LOG,
     n_events=None,
 ):
     """Chunked loading of events for training ML models"""
@@ -106,7 +109,7 @@ def read_training_events(
     n_non_predictable = 0
 
     for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
-        logger.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
+        log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
         n_events_in_file += len(table_chunk)
 
         if isinstance(reconstructor, DispReconstructor):
@@ -115,12 +118,13 @@ def read_training_events(
                 == CoordinateFrameType.ALTAZ.value
             ):
                 raise ValueError(
-                    "Pointing information for training data has to be provided in horizontal coordinates"
+                    "Pointing information for training data"
+                    " has to be provided in horizontal coordinates"
                 )
 
         mask = reconstructor.quality_query.get_table_mask(table_chunk)
         table_chunk = table_chunk[mask]
-        logger.debug(
+        log.debug(
             "Events in chunk %d after applying quality_query: %d",
             chunk,
             len(table_chunk),
@@ -140,8 +144,8 @@ def read_training_events(
         table.append(table_chunk)
 
     table = vstack(table)
-    logger.info("Events read from input: %d", n_events_in_file)
-    logger.info("Events after applying quality query: %d", n_valid_events_in_file)
+    log.info("Events read from input: %d", n_events_in_file)
+    log.info("Events after applying quality query: %d", n_valid_events_in_file)
 
     if len(table) == 0:
         raise TooFewEvents(
@@ -149,17 +153,18 @@ def read_training_events(
         )
 
     if n_non_predictable > 0:
-        logger.warning("Dropping %d non-predictable events.", n_non_predictable)
+        log.warning("Dropping %d non-predictable events.", n_non_predictable)
 
     if n_events is not None:
         if n_events > len(table):
-            logger.warning(
-                "Number of events in table (%d) is less than requested number of events %d",
+            log.warning(
+                "Number of events in table (%d) is less"
+                " than requested number of events %d",
                 len(table),
                 n_events,
             )
         else:
-            logger.info("Sampling %d events", n_events)
+            log.info("Sampling %d events", n_events)
             idx = rng.choice(len(table), n_events, replace=False)
             idx.sort()
             table = table[idx]

From b61576e2ac2104880355f4590e1b236d98d2b999 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Tue, 14 Nov 2023 12:15:23 +0100
Subject: [PATCH 12/13] Use keyword arguments

---
 ctapipe/tools/train_disp_reconstructor.py  | 16 +++++------
 ctapipe/tools/train_energy_regressor.py    | 16 +++++------
 ctapipe/tools/train_particle_classifier.py | 32 +++++++++++-----------
 3 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py
index 72a551e8597..a188b4260d0 100644
--- a/ctapipe/tools/train_disp_reconstructor.py
+++ b/ctapipe/tools/train_disp_reconstructor.py
@@ -128,14 +128,14 @@ def start(self):
                 "hillas_psi",
             ]
             table = read_training_events(
-                self.loader,
-                self.chunk_size,
-                tel_type,
-                self.models,
-                feature_names,
-                self.rng,
-                self.log,
-                self.n_events.tel[tel_type],
+                loader=self.loader,
+                chunk_size=self.chunk_size,
+                telescope_type=tel_type,
+                reconstructor=self.models,
+                feature_names=feature_names,
+                rng=self.rng,
+                log=self.log,
+                n_events=self.n_events.tel[tel_type],
             )
             table[self.models.target] = self._get_true_disp(table)
             table = table[self.models.features + [self.models.target, "true_energy"]]
diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py
index 9db47e1161e..2c9d411f594 100644
--- a/ctapipe/tools/train_energy_regressor.py
+++ b/ctapipe/tools/train_energy_regressor.py
@@ -112,14 +112,14 @@ def start(self):
             self.log.info("Loading events for %s", tel_type)
             feature_names = self.regressor.features + [self.regressor.target]
             table = read_training_events(
-                self.loader,
-                self.chunk_size,
-                tel_type,
-                self.regressor,
-                feature_names,
-                self.rng,
-                self.log,
-                self.n_events.tel[tel_type],
+                loader=self.loader,
+                chunk_size=self.chunk_size,
+                telescope_type=tel_type,
+                reconstructor=self.regressor,
+                feature_names=feature_names,
+                rng=self.rng,
+                log=self.log,
+                n_events=self.n_events.tel[tel_type],
             )
 
             self.log.info("Train on %s events", len(table))
diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py
index 59656d8441c..90324b5d5b8 100644
--- a/ctapipe/tools/train_particle_classifier.py
+++ b/ctapipe/tools/train_particle_classifier.py
@@ -176,24 +176,24 @@ def _read_input_data(self, tel_type):
             "true_energy",
         ]
         signal = read_training_events(
-            self.signal_loader,
-            self.chunk_size,
-            tel_type,
-            self.classifier,
-            feature_names,
-            self.rng,
-            self.log,
-            self.n_signal.tel[tel_type],
+            loader=self.signal_loader,
+            chunk_size=self.chunk_size,
+            telescope_type=tel_type,
+            reconstructor=self.classifier,
+            feature_names=feature_names,
+            rng=self.rng,
+            log=self.log,
+            n_events=self.n_signal.tel[tel_type],
         )
         background = read_training_events(
-            self.background_loader,
-            self.chunk_size,
-            tel_type,
-            self.classifier,
-            feature_names,
-            self.rng,
-            self.log,
-            self.n_signal.tel[tel_type],
+            loader=self.background_loader,
+            chunk_size=self.chunk_size,
+            telescope_type=tel_type,
+            reconstructor=self.classifier,
+            feature_names=feature_names,
+            rng=self.rng,
+            log=self.log,
+            n_events=self.n_signal.tel[tel_type],
         )
         table = vstack([signal, background])
         self.log.info(

From bb01357995f58fc742ead3162eb2e1e10f98e087 Mon Sep 17 00:00:00 2001
From: Lukas Beiske <lukas.beiske@tu-dortmund.de>
Date: Tue, 14 Nov 2023 14:08:59 +0100
Subject: [PATCH 13/13] Use n_background for background events

---
 ctapipe/tools/train_particle_classifier.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py
index 90324b5d5b8..70c53b63419 100644
--- a/ctapipe/tools/train_particle_classifier.py
+++ b/ctapipe/tools/train_particle_classifier.py
@@ -193,7 +193,7 @@ def _read_input_data(self, tel_type):
             feature_names=feature_names,
             rng=self.rng,
             log=self.log,
-            n_events=self.n_signal.tel[tel_type],
+            n_events=self.n_background.tel[tel_type],
         )
         table = vstack([signal, background])
         self.log.info(