Skip to content

Commit

Permalink
add acceleration option to JointPrimaryMarginalizedModel likelihood (g…
Browse files Browse the repository at this point in the history
…wastro#4688)

* Update hierarchical.py

* Update hierarchical.py

* Update hierarchical.py

* Update hierarchical.py

* Update hierarchical.py

* fix cc issues

* Update hierarchical.py

* Update relbin.py

* add complex phase correction for sh_others

* Update hierarchical.py

* Update relbin.py

* fix cc issues

* make code more general

* update

* fix

* rename

* update

* WIP

* fix a bug in frame transform

* fix overwritten issues

* update

* update

* fix reconstruct

* make this PR general

* update

* update

* fix cc issues

* rename

* rename

* add multiband description

* fix

* add comments

* fix hdf's config

* fix

* fix

* fix

* fix

* remove print

* update for Alex's comments

* wip

* update

* fix

* update

* seems work

* fix CC issue

* fix

* fix demargin
  • Loading branch information
WuShichao authored and prayush committed Nov 21, 2024
1 parent b762cd8 commit 5774fb1
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 75 deletions.
8 changes: 6 additions & 2 deletions bin/inference/pycbc_inference
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ fft.from_cli(opts)
with ctx:

# read configuration file
cp = configuration.WorkflowConfigParser.from_cli(opts)
cp_original = configuration.WorkflowConfigParser.from_cli(opts)
# some models, such as joint_primary_marginalized, will internally
# modify the original cp for sampling, so we need to save the original
# and let them modify the copy
cp = cp_original.__deepcopy__(cp_original)

# create an empty checkpoint file, if needed
condor_ckpt = cp.has_option('sampler', 'checkpoint-signal')
Expand Down Expand Up @@ -138,7 +142,7 @@ with ctx:
if pool.is_main_process():
for fn in [sampler.checkpoint_file, sampler.backup_file]:
with loadfile(fn, 'a') as fp:
fp.write_config_file(cp)
fp.write_config_file(cp_original)

# Run the sampler
sampler.run()
Expand Down
169 changes: 98 additions & 71 deletions pycbc/inference/models/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,26 +607,14 @@ def _loglikelihood(self):


class JointPrimaryMarginalizedModel(HierarchicalModel):
""" Hierarchical heterodyne likelihood for coherent multiband
parameter estimation which combines data from space-borne and
ground-based GW detectors coherently. Currently, this only
supports LISA as the space-borne GW detector.
Sub models are treated as if the same GW source (such as a GW
from stellar-mass BBH) is observed in different frequency bands by
space-borne and ground-based GW detectors, then transform all
the parameters into the same frame in the sub model level, use
`HierarchicalModel` to get the joint likelihood, and marginalize
over all the extrinsic parameters supported by `RelativeTimeDom`
or its variants. Note that LISA submodel only supports the `Relative`
for now, for ground-based detectors, please use `RelativeTimeDom`
or its variants.
Although this likelihood model is used for multiband parameter
estimation, users can still use it for other purposes, such as
GW + EM parameter estimation, in this case, please use `RelativeTimeDom`
or its variants for the GW data, for the likelihood of EM data,
there is no restrictions.
"""This likelihood model can be used for cases when one of the submodels
can be marginalized to accelerate the total likelihood. This likelihood
model also allows for further acceleration of other models during
marginalization, if some extrinsic parameters can be tightly constrained
by the primary model. For example, in EM + GW parameter estimation,
the sky localization can be well measured. For LISA + 3G multiband
observation, SOBHB signals' (tc, ra, dec) can be tightly constrained
by the 3G network, so this model is also useful for this case.
"""
name = 'joint_primary_marginalized'

Expand All @@ -640,6 +628,11 @@ def __init__(self, variable_params, submodels, **kwargs):
self.other_models.pop(kwargs['primary_lbl'][0])
self.other_models = list(self.other_models.values())

# determine whether to accelerate total_loglr
from .tools import str_to_bool
self.static_margin_params_in_other_models = \
str_to_bool(kwargs['static_margin_params_in_other_models'][0])

def write_metadata(self, fp, group=None):
"""Adds metadata to the output files
Expand Down Expand Up @@ -686,34 +679,43 @@ def total_loglr(self):
"""
# calculate <d-h|d-h> = <h|h> - 2<h|d> + <d|d> up to a constant

# note that for SOBHB signals, ground-based detectors dominate the SNR
# and the accuracy of (tc, ra, dec)
self.primary_model.return_sh_hh = True
sh_primary, hh_primary = self.primary_model.loglr
self.primary_model.return_sh_hh = False

margin_names_vector = list(
self.primary_model.marginalize_vector_params.keys())
if 'logw_partial' in margin_names_vector:
margin_names_vector.remove('logw_partial')
# set logr, otherwise it will store (sh, hh)
setattr(self.primary_model._current_stats, 'loglr',
self.primary_model.marginalize_loglr(sh_primary, hh_primary))
if isinstance(sh_primary, numpy.ndarray):
nums = len(sh_primary)
else:
nums = 1

margin_params = {}
nums = 1
for key, value in self.primary_model.current_params.items():
# add marginalize_vector_params
if key in margin_names_vector:
margin_params[key] = value
if isinstance(value, numpy.ndarray):
nums = len(value)
# add distance if it has been marginalized; a numpy array is used
# only so that it has the same shape as the
# marginalize_vector_params. Here we assume that
# self.primary_model.current_params['distance'] is a number
if self.primary_model.distance_marginalization:
margin_params['distance'] = numpy.full(
nums, self.primary_model.current_params['distance'])

# add likelihood contribution from space-borne detectors, we
if self.static_margin_params_in_other_models:
# Due to the high precision of extrinsic parameters constrained
# by the primary model, the mismatch of waveforms in others by
# varying those parameters is pretty small, so we can keep them
# static to accelerate total_loglr. Here, we use matched-filtering
# SNR instead of likelihood, because luminosity distance and
# inclination have a very strong degeneracy: a change of inclination
# will change the best-match distance, and so change the amplitude of
# the waveform. Using SNR will cancel out the effect of amplitude.
i_max_extrinsic = numpy.argmax(
numpy.abs(sh_primary) / hh_primary**0.5)
for p in self.primary_model.marginalized_params_name:
if isinstance(self.primary_model.current_params[p],
numpy.ndarray):
margin_params[p] = \
self.primary_model.current_params[p][i_max_extrinsic]
else:
margin_params[p] = self.primary_model.current_params[p]
else:
for key, value in self.primary_model.current_params.items():
# add marginalize_vector_params
if key in self.primary_model.marginalized_params_name:
margin_params[key] = value

# add likelihood contribution from other_models, we
# calculate sh/hh for each marginalized parameter point
sh_others = numpy.full(nums, 0 + 0.0j)
hh_others = numpy.zeros(nums)
Expand All @@ -723,24 +725,46 @@ def total_loglr(self):
# not using self.primary_model.current_params, because others_model
# may have its own static parameters
current_params_other = other_model.current_params.copy()
for i in range(nums):
if not self.static_margin_params_in_other_models:
for i in range(nums):
current_params_other.update(
{key: value[i] if isinstance(value, numpy.ndarray)
else value for key, value in margin_params.items()})
other_model.update(**current_params_other)
other_model.return_sh_hh = True
sh_other, hh_other = other_model.loglr
sh_others[i] += sh_other
hh_others[i] += hh_other
other_model.return_sh_hh = False
# set logr, otherwise it will store (sh, hh)
setattr(other_model._current_stats, 'loglr',
other_model.marginalize_loglr(sh_other, hh_other))
else:
# use one margin point set to approximate all the others
current_params_other.update(
{key: value[i] if isinstance(value, numpy.ndarray) else
value for key, value in margin_params.items()})
{key: value[0] if isinstance(value, numpy.ndarray)
else value for key, value in margin_params.items()})
other_model.update(**current_params_other)
other_model.return_sh_hh = True
sh_others[i], hh_others[i] = other_model.loglr
sh_other, hh_other = other_model.loglr
other_model.return_sh_hh = False
# set logr, otherwise it will store (sh, hh)
setattr(other_model._current_stats, 'loglr',
other_model.marginalize_loglr(sh_other, hh_other))
sh_others += sh_other
hh_others += hh_other

if nums == 1:
# the original sh/hh_others are numpy arrays, which might not be
# the same type as sh/hh_primary during reconstruct; during
# reconstruct of distance, sh/hh_others need to be scalars
sh_others = sh_others[0]
hh_others = hh_others[0]
sh_total = sh_primary + sh_others
hh_total = hh_primary + hh_others

# calculate marginalize_vector_weights
self.primary_model.marginalize_vector_weights = \
- numpy.log(self.primary_model.vsamples)
loglr = self.primary_model.marginalize_loglr(sh_total, hh_total)

return loglr

def others_lognl(self):
Expand Down Expand Up @@ -805,6 +829,10 @@ def from_config(cls, cp, **kwargs):
sparam_map = map_params(hpiter(cp.options('static_params'),
submodel_lbls))

# get the acceleration label
kwargs['static_margin_params_in_other_models'] = shlex.split(
cp.get('model', 'static_margin_params_in_other_models'))

# we'll need any waveform transforms for the initializing sub-models,
# as the underlying models will receive the output of those transforms

Expand Down Expand Up @@ -856,32 +884,28 @@ def from_config(cls, cp, **kwargs):
cp.get('static_params', param.fullname))

# set the variable params: different from the standard
# hierarchical model, in this multiband model, all sub-models
# has the same variable parameters, so we don't need to worry
# about the unique variable issue. Besides, the primary model
# needs to do marginalization, so we must set variable_params
# and prior section before initializing it.
# hierarchical model, in the JointPrimaryMarginalizedModel,
# all sub-models have the same variable parameters, so we don't
# need to worry about the unique variable issue. Besides,
# the primary model needs to do marginalization, so we must set
# variable_params and the prior section before initializing it.

subcp.add_section('variable_params')
for param in vparam_map[lbl]:
if lbl in kwargs['primary_lbl']:
# set variable_params for the primary model
subcp.set('variable_params', param.subname,
cp.get('variable_params', param.fullname))
else:
# all variable_params in other models will come
# from the primary model during sampling
subcp.set('static_params', param.subname, 'REPLACE')

for section in cp.sections():
# the primary model needs the priors of the marginalized parameters
if 'prior-' in section and lbl in kwargs['primary_lbl']:
prior_section = '%s' % section
subcp[prior_section] = cp[prior_section]
# if `waveform_transforms` has a prefix,
# add it into sub-models' config
elif '%s_waveform_transforms' % lbl in section:
transforms_section = '%s' % section
subcp[transforms_section] = cp[transforms_section]
else:
pass

# similar to the standard hierarchical model,
# add the outputs from the waveform transforms if sub-model
Expand Down Expand Up @@ -918,15 +942,8 @@ def from_config(cls, cp, **kwargs):
# here we ignore `coa_phase`, because if it's been marginalized,
# it will not be listed in `variable_params` and `prior` sections
primary_model = submodels[kwargs['primary_lbl'][0]]
marginalized_params = primary_model.marginalize_vector_params.copy()
if 'logw_partial' in marginalized_params:
marginalized_params.pop('logw_partial')
marginalized_params = list(marginalized_params.keys())
else:
marginalized_params = []
# this may also include 'f_ref', 'f_lower', 'approximant',
# but doesn't matter
marginalized_params += list(primary_model.static_params.keys())
marginalized_params = primary_model.marginalized_params_name.copy()

for p in primary_model.static_params.keys():
p_full = '%s__%s' % (kwargs['primary_lbl'][0], p)
if p_full not in cp['static_params']:
Expand All @@ -940,6 +957,10 @@ def from_config(cls, cp, **kwargs):
cp['variable_params'].pop(p)
cp.pop(section)

# save the virtual config file to disk for later checking
with open('internal_top.ini', 'w', encoding='utf-8') as file:
cp.write(file)

# now load the model
logging.info("Loading joint_primary_marginalized model")
return super(HierarchicalModel, cls).from_config(
Expand All @@ -956,6 +977,12 @@ def reconstruct(self, rec=None, seed=None):
rec = {}

def get_loglr():
# make sure waveform transforms have been applied in
# the top-level model
if self.waveform_transforms is not None:
self._current_params = transforms.apply_transforms(
self._current_params, self.waveform_transforms,
inverse=False)
self.update_all_models(**rec)
return self.total_loglr()

Expand Down
9 changes: 8 additions & 1 deletion pycbc/inference/models/marginalized_gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def __init__(self, variable_params,
sample_rate=None,
**kwargs):

# the flag used in `_loglr`
self.return_sh_hh = False
self.sample_rate = float(sample_rate)
self.kwargs = kwargs
variable_params, kwargs = self.setup_marginalization(
Expand Down Expand Up @@ -259,6 +261,7 @@ def _nowaveform_loglr(self):

def _loglr(self):
r"""Computes the log likelihood ratio,
or inner product <s|h> and <h|h> if `self.return_sh_hh` is True.
.. math::
Expand Down Expand Up @@ -369,7 +372,11 @@ def _loglr(self):
hh_total += hh

loglr = self.marginalize_loglr(sh_total, hh_total)
return loglr
if self.return_sh_hh:
results = (sh_total, hh_total)
else:
results = loglr
return results


class MarginalizedPolarization(DistMarg, BaseGaussianNoise):
Expand Down
3 changes: 2 additions & 1 deletion pycbc/inference/models/relbin.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,11 @@ def _loglr(self):
filt += filter_i
norm += norm_i

loglr = self.marginalize_loglr(filt, norm)
if self.return_sh_hh:
results = (filt, norm)
else:
results = self.marginalize_loglr(filt, norm)
results = loglr
return results

def write_metadata(self, fp, group=None):
Expand Down
8 changes: 8 additions & 0 deletions pycbc/inference/models/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ def pop_prior(param):
self.distance_interpolator = i

kwargs['static_params']['distance'] = dist_ref

# Save the marginalized parameters' names in one place;
# coa_phase will be a static param if it has been marginalized
if marginalize_distance:
self.marginalized_params_name =\
list(self.marginalize_vector_params.keys()) +\
[marginalize_distance_param]

return variable_params, kwargs

def reset_vector_params(self):
Expand Down

0 comments on commit 5774fb1

Please sign in to comment.