Skip to content

Commit

Permalink
Speed up background_bin_from_string (#4475)
Browse files Browse the repository at this point in the history
* Cache duration code, reduce duplication

* Make actually work

* Titos comment

* CC
  • Loading branch information
spxiwh authored Sep 6, 2023
1 parent 0c9a8c3 commit c1ae0d5
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions pycbc/events/coinc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
from .eventmgr_cython import timecoincidence_findidxlen
from .eventmgr_cython import timecluster_cython

# Mapping used in background_bin_from_string to select approximant for
# duration function, if duration-based binning is used.
_APPROXIMANT_DURATION_MAP = {
'SEOBNRv2duration': 'SEOBNRv2',
'SEOBNRv4duration': 'SEOBNRv4',
'SEOBNRv5duration': 'SEOBNRv5_ROM'
}


def background_bin_from_string(background_bins, data):
""" Return template ids for each bin as defined by the format string
Expand All @@ -55,6 +63,9 @@ def background_bin_from_string(background_bins, data):
"""
used = numpy.array([], dtype=numpy.uint32)
bins = {}
# Some duration/peak frequency functions are expensive.
# Do not want to recompute many times, if using lots of bins.
cached_values = {}
for mbin in background_bins:
locs = None
name, bin_type_list, boundary_list = tuple(mbin.split(':'))
Expand All @@ -71,7 +82,9 @@ def background_bin_from_string(background_bins, data):
raise RuntimeError("Can't parse boundary condition! Must begin "
"with 'lt' or 'gt'")

if bin_type == 'component' and boundary[0:2] == 'lt':
if bin_type in cached_values:
vals = cached_values[bin_type]
elif bin_type == 'component' and boundary[0:2] == 'lt':
# maximum component mass is less than boundary value
vals = numpy.maximum(data['mass1'], data['mass2'])
elif bin_type == 'component' and boundary[0:2] == 'gt':
Expand All @@ -91,34 +104,29 @@ def background_bin_from_string(background_bins, data):
elif bin_type == 'chi_eff':
vals = pycbc.conversions.chi_eff(data['mass1'], data['mass2'],
data['spin1z'], data['spin2z'])
elif bin_type == 'SEOBNRv2Peak':
vals = pycbc.pnutils.get_freq('fSEOBNRv2Peak',
data['mass1'], data['mass2'],
data['spin1z'], data['spin2z'])
elif bin_type == 'SEOBNRv4Peak':
vals = pycbc.pnutils.get_freq('fSEOBNRv4Peak', data['mass1'],
data['mass2'], data['spin1z'],
data['spin2z'])
elif bin_type == 'SEOBNRv2duration':
vals = pycbc.pnutils.get_imr_duration(
data['mass1'], data['mass2'],
data['spin1z'], data['spin2z'],
data['f_lower'], approximant='SEOBNRv2')
elif bin_type == 'SEOBNRv4duration':
vals = pycbc.pnutils.get_imr_duration(
data['mass1'][:], data['mass2'][:],
data['spin1z'][:], data['spin2z'][:],
data['f_lower'][:], approximant='SEOBNRv4')
elif bin_type == 'SEOBNRv5duration':
elif bin_type in ['SEOBNRv2Peak', 'SEOBNRv4Peak']:
vals = pycbc.pnutils.get_freq(
'f' + bin_type,
data['mass1'],
data['mass2'],
data['spin1z'],
data['spin2z']
)
cached_values[bin_type] = vals
elif bin_type in _APPROXIMANT_DURATION_MAP:
vals = pycbc.pnutils.get_imr_duration(
data['mass1'][:], data['mass2'][:],
data['spin1z'][:], data['spin2z'][:],
data['f_lower'][:], approximant='SEOBNRv5_ROM')
data['mass1'],
data['mass2'],
data['spin1z'],
data['spin2z'],
data['f_lower'],
approximant=_APPROXIMANT_DURATION_MAP[bin_type]
)
cached_values[bin_type] = vals
else:
raise ValueError('Invalid bin type %s' % bin_type)

sub_locs = member_func(vals)
del vals
sub_locs = numpy.where(sub_locs)[0]
if locs is not None:
# find intersection of boundary conditions
Expand Down

0 comments on commit c1ae0d5

Please sign in to comment.