Skip to content

Commit

Permalink
Add back in the changes from gwastro#4603
Browse files Browse the repository at this point in the history
  • Loading branch information
GarethCabournDavies committed Nov 1, 2024
1 parent 7b95712 commit 1e0bd5c
Showing 1 changed file with 89 additions and 61 deletions.
150 changes: 89 additions & 61 deletions pycbc/events/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,22 +923,16 @@ def __init__(self, sngl_ranking, files=None, ifos=None, **kwargs):

if self.kwargs['dq']:
# Reweight the noise rate by the dq reweighting factor
self.dq_val_by_time = {}
self.dq_bin_by_id = {}
for k in self.files.keys():
parsed_attrs = k.split('-')
if len(parsed_attrs) < 3:
continue
if parsed_attrs[2] == 'dq_ts_reference':
ifo = parsed_attrs[0]
dq_type = parsed_attrs[1]
dq_vals = self.assign_dq_val(k)
dq_bins = self.assign_bin_id(k)
if ifo not in self.dq_val_by_time:
self.dq_val_by_time[ifo] = {}
self.dq_bin_by_id[ifo] = {}
self.dq_val_by_time[ifo][dq_type] = dq_vals
self.dq_bin_by_id[ifo][dq_type] = dq_bins
self.dq_rates_by_state = {}
self.dq_bin_by_tid = {}
self.dq_state_segments = {}

for ifo in self.ifos:
key = f'{ifo}-dq_stat_info'
if key in self.files.keys():
self.dq_rates_by_state[ifo] = self.assign_dq_rates(key)
self.dq_bin_by_tid[ifo] = self.assign_template_bins(key)
self.dq_state_segments[ifo] = self.setup_segments(key)

if self.kwargs['chirp_mass']:
# Reweight the signal rate by the chirp mass of the template
Expand All @@ -956,10 +950,9 @@ def __init__(self, sngl_ranking, files=None, ifos=None, **kwargs):
for kname in self.kde_names:
self.assign_kdes(kname)

def assign_bin_id(self, key):
def assign_template_bins(self, key):
"""
Assign bin ID values for DQ reweighting
Assign bin ID values
Assign each template id to a bin name based on a
referenced statistic file.
Expand All @@ -975,18 +968,18 @@ def assign_bin_id(self, key):
"""
ifo = key.split('-')[0]
with h5py.File(self.files[key], 'r') as dq_file:
bin_names = dq_file.attrs['names'][:]
locs = []
names = []
for bin_name in bin_names:
bin_locs = dq_file[ifo + '/locs/' + bin_name][:]
locs = list(locs) + list(bin_locs.astype(int))
names = list(names) + list([bin_name] * len(bin_locs))

bin_dict = dict(zip(locs, names))
tids = []
bin_nums = []
bin_grp = dq_file[f'{ifo}/bins']
for bin_name in bin_grp.keys():
bin_tids = bin_grp[f'{bin_name}/tids'][:]
tids = list(tids) + list(bin_tids.astype(int))
bin_nums = list(bin_nums) + list([bin_name] * len(bin_tids))

bin_dict = dict(zip(tids, bin_nums))
return bin_dict

def assign_dq_val(self, key):
def assign_dq_rates(self, key):
"""
Assign dq values to each time for every bin based on a
referenced statistic file.
Expand All @@ -1005,50 +998,73 @@ def assign_dq_val(self, key):
"""
ifo = key.split('-')[0]
with h5py.File(self.files[key], 'r') as dq_file:
times = dq_file[ifo + '/times'][:]
bin_names = dq_file.attrs['names'][:]
bin_grp = dq_file[f'{ifo}/bins']
dq_dict = {}
for bin_name in bin_names:
dq_vals = dq_file[ifo + '/dq_vals/' + bin_name][:]
dq_dict[bin_name] = dict(zip(times, dq_vals))
for bin_name in bin_grp.keys():
dq_dict[bin_name] = bin_grp[f'{bin_name}/dq_rates'][:]

return dq_dict

def find_dq_val(self, trigs):
"""Get dq values for a specific ifo and times
def setup_segments(self, key):
"""
Check if segments definitions are in stat file
If they are, we are running offline and need to store them
If they aren't, we are running online
"""
ifo = key.split('-')[0]
with h5py.File(self.files[key], 'r') as dq_file:
ifo_grp = dq_file[ifo]
dq_state_segs_dict = {}
for k in ifo_grp['dq_segments'].keys():
seg_dict = {}
seg_dict['start'] = \
ifo_grp[f'dq_segments/{k}/segment_starts'][:]
seg_dict['end'] = \
ifo_grp[f'dq_segments/{k}/segment_ends'][:]
dq_state_segs_dict[k] = seg_dict

Parameters
----------
trigs: ReadByTempate or SingleDetTriggers object
Object containing information about the triggers to be
checked
return dq_state_segs_dict

def find_dq_noise_rate(self, trigs, dq_state):
"""Get dq values for a specific ifo and dq states"""

Returns
-------
dq_val: numpy array
The value of the dq reweighting factor for each trigger
"""
time = trigs['end_time'].astype(int)
try:
tnum = trigs.template_num
ifo = trigs.ifo
except AttributeError:
tnum = trigs['template_id']
assert len(self.ifos) == 1

try:
ifo = trigs.ifo
except AttributeError:
ifo = trigs['ifo']
assert len(numpy.unique(ifo)) == 1
# Should be exactly one ifo provided
ifo = self.ifos[0]
dq_val = numpy.zeros(len(time))
if ifo in self.dq_val_by_time:
for (i, t) in enumerate(time):
for k in self.dq_val_by_time[ifo].keys():
if isinstance(tnum, numpy.ndarray):
bin_name = self.dq_bin_by_id[ifo][k][tnum[i]]
else:
bin_name = self.dq_bin_by_id[ifo][k][tnum]
val = self.dq_val_by_time[ifo][k][bin_name][int(t)]
dq_val[i] = max(dq_val[i], val)
ifo = ifo[0]

dq_val = numpy.zeros(len(dq_state))

if ifo in self.dq_rates_by_state:
for (i, st) in enumerate(dq_state):
if isinstance(tnum, numpy.ndarray):
bin_name = self.dq_bin_by_tid[ifo][tnum[i]]
else:
bin_name = self.dq_bin_by_tid[ifo][tnum]
dq_val[i] = self.dq_rates_by_state[ifo][bin_name][st]
return dq_val

def find_dq_state_by_time(self, ifo, times):
"""Get the dq state for an ifo at times"""
dq_state = numpy.zeros(len(times), dtype=numpy.uint8)
if ifo in self.dq_state_segments:
from pycbc.events.veto import indices_within_times
for k in self.dq_state_segments[ifo]:
starts = self.dq_state_segments[ifo][k]['start']
ends = self.dq_state_segments[ifo][k]['end']
inds = indices_within_times(times, starts, ends)
# states are named in file as 'dq_state_N', need to extract N
dq_state[inds] = int(k[9:])
return dq_state

def assign_fits(self, ifo):
"""
Extract fits from single-detector rate fit files
Expand Down Expand Up @@ -1252,7 +1268,19 @@ def lognoiserate(self, trigs):

if self.kwargs['dq']:
# Reweight the lognoiserate things by the dq reweighting factor
lognoisel += self.find_dq_val(trigs)
# make sure every trig has a dq state
try:
ifo = trigs.ifo
except AttributeError:
ifo = trigs['ifo']
assert len(numpy.unique(ifo)) == 1
# Should be exactly one ifo provided
ifo = ifo[0]

dq_state = self.find_dq_state_by_time(ifo, trigs['end_time'][:])
dq_rate = self.find_dq_noise_rate(trigs, dq_state)
dq_rate = numpy.maximum(dq_rate, 1)
lognoisel += numpy.log(dq_rate)

return numpy.array(lognoisel, ndmin=1, dtype=numpy.float32)

Expand Down

0 comments on commit 1e0bd5c

Please sign in to comment.