Skip to content

Commit

Permalink
911 Generalize correction application in cities
Browse files Browse the repository at this point in the history
#911

[author: gonzaponte]

This PR modifies the cities to apply corrections in a more general
way. Currently, the correction strategy is hard-wired to
`NormStrategy.kr`, which limits our capabilities. This PR extends the
method to allow different correction strategies.

[reviewer: jwaiton]

This PR introduces methods for correction beyond the standard
`NormStrategy.kr`, including custom correction methods. The code is
well documented and works as intended. Good job!
  • Loading branch information
jwaiton authored and carhc committed Nov 6, 2024
2 parents 237ab2e + 96444fc commit bec62bd
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 51 deletions.
30 changes: 17 additions & 13 deletions invisible_cities/cities/beersheba.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@
from . components import hits_corrector
from . components import hits_thresholder
from . components import hits_and_kdst_from_files
from . components import identity

from .. core.configure import EventRangeType
from .. core.configure import OneOrManyFiles
from .. core.configure import check_annotations

from .. core import tbl_functions as tbl
from .. dataflow import dataflow as fl
from .. core import system_of_units as units
from .. core import tbl_functions as tbl
from .. dataflow import dataflow as fl

from .. dataflow.dataflow import push
from .. dataflow.dataflow import pipe
Expand All @@ -86,7 +86,6 @@
from .. types.symbols import CutType
from .. types.symbols import DeconvolutionMode

from .. core import system_of_units as units

from typing import Tuple
from typing import List
Expand Down Expand Up @@ -395,8 +394,7 @@ def beersheba( files_in : OneOrManyFiles
, same_peak : bool
, deconv_params : dict
, satellite_params : Union[dict, NoneType]
, corrections_file : Union[ str, NoneType]
, apply_temp : Union[bool, NoneType]
, corrections : dict
):
"""
The city corrects Penthesilea hits energy and extracts topology information.
Expand Down Expand Up @@ -470,6 +468,16 @@ def beersheba( files_in : OneOrManyFiles
`abs`: cut on the absolute value of the hits.
`rel`: cut on the relative value (to the max) of the hits.
corrections : dict
filename : str
Path to the file holding the correction maps
apply_temp : bool
Whether to apply temporal corrections
norm_strat : NormStrategy
Normalization strategy
norm_value : float, optional
Normalization value in case of `norm_strat = NormStrategy.custom`
----------
Input
----------
Expand All @@ -480,13 +488,9 @@ def beersheba( files_in : OneOrManyFiles
DECO : Deconvolved hits table
MC info : (if run number <=0)
"""

if corrections_file is None: correct_hits = identity
else : correct_hits = hits_corrector(corrections_file, apply_temp)
correct_hits = fl.map( correct_hits, item="hits")

threshold_hits = fl.map(hits_thresholder(threshold, same_peak), item="hits")
hitc_to_df = fl.map(hitc_to_df_, item="hits")
correct_hits = fl.map(hits_corrector(**corrections), item="hits")
threshold_hits = fl.map(hits_thresholder(threshold, same_peak), item="hits")
hitc_to_df = fl.map(hitc_to_df_, item="hits")

deconv_params['psf_fname' ] = expandvars(deconv_params['psf_fname'])
deconv_params['satellite_params'] = satellite_params
Expand Down
26 changes: 20 additions & 6 deletions invisible_cities/cities/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Tuple
from typing import Union
from typing import Any
from typing import Optional

import tables as tb
import numpy as np
Expand Down Expand Up @@ -1514,7 +1515,11 @@ def threshold_hits(hitc : HitCollection) -> HitCollection:


@check_annotations
def hits_corrector(map_fname : str, apply_temp : bool) -> Callable:
def hits_corrector( filename : str
, apply_temp : bool
, norm_strat : NormStrategy
, norm_value : Optional[Union[float, NoneType]] = None
) -> Callable:
"""
Applies energy correction map and converts drift time to z.
Expand All @@ -1531,11 +1536,20 @@ def hits_corrector(map_fname : str, apply_temp : bool) -> Callable:
A function that takes a HitCollection as input and returns
the same object with modified Ec and Z fields.
"""
map_fname = os.path.expandvars(map_fname)
maps = read_maps(map_fname)
get_coef = apply_all_correction(maps, apply_temp = apply_temp, norm_strat = NormStrategy.kr)
time_to_Z = (get_df_to_z_converter(maps) if maps.t_evol is not None else
lambda x: x)

if ( ((norm_strat is not NormStrategy.custom) ^ (norm_value is None)) or
(norm_strat is NormStrategy.custom) and (norm_value<= 0)):
raise ValueError(
"`NormStrategy.custom` requires `norm_value` to be greater than 0. "
"For all other `NormStrategy` options, `norm_value` must not be provided."
)

maps = read_maps(os.path.expandvars(filename))
get_coef = apply_all_correction( maps
, apply_temp = apply_temp
, norm_strat = norm_strat
, norm_value = norm_value)
time_to_Z = get_df_to_z_converter(maps) if maps.t_evol is not None else identity

def correct(hitc : HitCollection) -> HitCollection:
for hit in hitc.hits:
Expand Down
60 changes: 60 additions & 0 deletions invisible_cities/cities/components_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
from .. core.exceptions import NoInputFiles
from .. core.testing_utils import assert_tables_equality
from .. core import system_of_units as units
from .. evm.event_model import Cluster
from .. evm.event_model import Hit
from .. evm.event_model import HitCollection
from .. types.ic_types import xy
from .. types.symbols import WfType
from .. types.symbols import EventRange as ER
from .. types.symbols import NormStrategy
from .. types.symbols import XYReco

from . components import event_range
Expand All @@ -32,6 +37,7 @@
from . components import mcsensors_from_file
from . components import create_timestamp
from . components import check_max_time
from . components import hits_corrector
from . components import write_city_configuration
from . components import copy_cities_configuration

Expand Down Expand Up @@ -465,6 +471,60 @@ def test_read_wrong_pmt_ids(ICDATADIR):
next(sns_gen)


@mark.parametrize( "norm_strat, norm_value"
                 , ( (NormStrategy.kr    , None) # `None` stands for "use the default"
                   , (NormStrategy.max   , None)
                   , (NormStrategy.mean  , None)
                   , (NormStrategy.custom, 1e3 )
                   ))
@mark.parametrize("apply_temp", (False, True))
def test_hits_corrector_valid_normalization_options( correction_map_filename
                                                   , norm_strat
                                                   , norm_value
                                                   , apply_temp ):
    """
    Smoke test: every accepted (norm_strat, norm_value) combination,
    with and without temporal corrections, must yield corrected
    energies that are sensible, i.e. not nan and strictly
    positive. Detailed numerical checks are left to the tests of the
    core correction functions.
    """
    n_hits = 50
    xs     = np.random.uniform(-10, 10, n_hits)
    ys     = np.random.uniform(-10, 10, n_hits)
    zs     = np.random.uniform( 10, 50, n_hits)

    # One hit per random (x, y, z) position, numbered consecutively from 0
    hits = [ Hit(index, Cluster(0, xy(x, y), xy.zero(), 1), z, 1, xy.zero(), 0)
             for index, (x, y, z) in enumerate(zip(xs, ys, zs)) ]
    hc   = HitCollection(0, 1, hits)

    correct     = hits_corrector(correction_map_filename, apply_temp, norm_strat, norm_value)
    corrected_e = np.array([hit.Ec for hit in correct(hc).hits])

    assert np.all(~np.isnan(corrected_e))
    assert np.all(corrected_e > 0)


@mark.parametrize( "norm_strat, norm_value"
                 , ( (NormStrategy.kr    , 0   ) # 0 doesn't count as "not given"
                   , (NormStrategy.max   , 0   )
                   , (NormStrategy.mean  , 0   )
                   , (NormStrategy.kr    , 1   ) # any other value must not be given either
                   , (NormStrategy.max   , 1   )
                   , (NormStrategy.mean  , 1   )
                   , (NormStrategy.custom, None) # with custom, `norm_value` must be given ...
                   , (NormStrategy.custom, 0   ) # ... but not 0
                   ))
def test_hits_corrector_invalid_normalization_options_raises( correction_map_filename
                                                            , norm_strat
                                                            , norm_value):
    """
    Building a corrector with an inconsistent (norm_strat, norm_value)
    pair must be rejected with a ValueError.
    """
    with raises(ValueError):
        hits_corrector( filename   = correction_map_filename
                      , apply_temp = False
                      , norm_strat = norm_strat
                      , norm_value = norm_value)


def test_write_city_configuration(config_tmpdir):
filename = os.path.join(config_tmpdir, "test_write_configuration.h5")
city_name = "acity"
Expand Down
27 changes: 13 additions & 14 deletions invisible_cities/cities/esmeralda.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,11 @@
from . components import hits_corrector
from . components import hits_thresholder
from . components import compute_and_write_tracks_info
from . components import identity

from .. io. hits_io import hits_writer
from .. io. kdst_io import kdst_from_df_writer
from .. io.run_and_event_io import run_and_event_writer

from .. types.ic_types import NoneType

from typing import Union


def hit_dropper(radius : float):
def in_fiducial(hit : evm.Hit) -> bool:
Expand All @@ -84,8 +79,7 @@ def esmeralda( files_in : OneOrManyFiles
, same_peak : bool
, fiducial_r : float
, paolina_params : dict
, corrections_file : Union[ str, NoneType]
, apply_temp : Union[bool, NoneType]
, corrections : dict
):
"""
The city applies a threshold to sipm hits and extracts
Expand Down Expand Up @@ -125,10 +119,16 @@ def esmeralda( files_in : OneOrManyFiles
radius of blob
max_num_hits : int
maximum number of hits allowed per event to run paolina functions.
corrections_file : str
path to the corrections file
apply_temp : bool
whether to apply temporal corrections
corrections : dict
filename : str
Path to the file holding the correction maps
apply_temp : bool
Whether to apply temporal corrections
norm_strat : NormStrategy
Normalization strategy
norm_value : float, optional
Normalization value in case of `norm_strat = NormStrategy.custom`
Input
----------
Expand All @@ -147,9 +147,8 @@ def esmeralda( files_in : OneOrManyFiles
- Summary/events - summary of per event information
- DST/Events - kdst information
"""
if corrections_file is None: correct_hits = identity
else : correct_hits = hits_corrector(corrections_file, apply_temp)
correct_hits = fl.map( correct_hits, item="hits")
correct_hits = hits_corrector(**corrections)
correct_hits = fl.map(correct_hits, item="hits")
drop_external_hits = fl.map(hit_dropper(fiducial_r), item="hits")
threshold_hits = fl.map(hits_thresholder(threshold, same_peak), item="hits")
event_count_in = fl.spy_count()
Expand Down
51 changes: 46 additions & 5 deletions invisible_cities/cities/sophronia.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@

from .. dataflow import dataflow as df


from .. types.symbols import RebinMethod
from .. types.symbols import SiPMCharge
from .. types.symbols import XYReco
Expand Down Expand Up @@ -92,10 +91,52 @@ def sophronia( files_in : OneOrManyFiles
, q_thr : float
, sipm_charge_type : SiPMCharge
, same_peak : bool
, corrections_file : Optional[str] = None
, apply_temp : Optional[bool] = None
, corrections : Optional[dict] = None
):

"""
drift_v : float
Drift velocity
s1_params : dict
Selection criteria for S1 peaks
s2_params : dict
Selection criteria for S2 peaks
global_reco_algo : XYReco
Reconstruction algorithm to use
global_reco_params : dict
Configuration parameters of the given reconstruction algorithm
rebin : int, float
If `rebin_method` is `stride`, it is interpreted as the number
of consecutive slices to accumulate. Otherwise, if
`rebin_method` is `threshold`, it is interpreted as the amount
of accumulated charge necessary to stop the resampling.
rebin_method : RebinMethod
Resampling method to use: `stride` or `threshold`
q_thr : float
Threshold to be applied to each (resampled) slice of every SiPM.
sipm_charge_type : SiPMCharge
Interpretation of the SiPM charge: `raw` or `signal_to_noise`
same_peak : bool
Whether to reassign NN hits' energy only to the hits from the same peak
corrections : dict
filename : str
Path to the file holding the correction maps
apply_temp : bool
Whether to apply temporal corrections
norm_strat : NormStrategy
Normalization strategy
norm_value : float, optional
Normalization value in case of `norm_strat = NormStrategy.custom`
"""
global_reco = compute_xy_position( detector_db
, run_number
, global_reco_algo
Expand Down Expand Up @@ -133,7 +174,7 @@ def sophronia( files_in : OneOrManyFiles
merge_nn_hits = df.map( hits_merger(same_peak)
, item = "hits")

correct_hits = df.map( hits_corrector(corrections_file, apply_temp) if corrections_file is not None else identity
correct_hits = df.map( hits_corrector(**corrections) if corrections is not None else identity
, item = "hits")

build_pointlike_event = df.map( pointlike_event_builder( detector_db
Expand Down
8 changes: 5 additions & 3 deletions invisible_cities/config/beersheba.conf
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ deconv_params = dict(
cut_type = abs,
inter_method = cubic)

corrections_file = "$ICDIR/database/test_data/kr_emap_xy_100_100_r_6573_time.h5"
apply_temp = False

# satellite_params = dict(satellite_start_iter = 75,
# satellite_max_size = 3,
# e_cut = 0.2,
# cut_type = CutType.abs)
satellite_params = None

corrections = dict(
filename = "$ICDIR/database/test_data/kr_emap_xy_100_100_r_6573_time.h5",
apply_temp = False,
norm_strat = kr)
6 changes: 4 additions & 2 deletions invisible_cities/config/esmeralda.conf
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,7 @@ paolina_params = dict(
blob_radius = 21 * mm,
max_num_hits = 10000)

corrections_file = "$ICDIR/database/test_data/kr_emap_xy_100_100_r_6573_time.h5"
apply_temp = False
corrections = dict(
filename = "$ICDIR/database/test_data/kr_emap_xy_100_100_r_6573_time.h5",
apply_temp = False,
norm_strat = kr)
6 changes: 4 additions & 2 deletions invisible_cities/config/sophronia.conf
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,7 @@ global_reco_params = dict(Qthr = 1 * pes)

same_peak = True

corrections_file = "$ICDIR/database/test_data/kr_emap_xy_100_100_r_6573_time.h5"
apply_temp = True
corrections = dict(
filename = "$ICDIR/database/test_data/kr_emap_xy_100_100_r_6573_time.h5",
apply_temp = True,
norm_strat = kr)
Loading

0 comments on commit bec62bd

Please sign in to comment.