add acceleration option to JointPrimaryMarginalizedModel likelihood #4688
@@ -607,26 +607,14 @@ def _loglikelihood(self):
 class JointPrimaryMarginalizedModel(HierarchicalModel):
-    """ Hierarchical heterodyne likelihood for coherent multiband
-    parameter estimation which combines data from space-borne and
-    ground-based GW detectors coherently. Currently, this only
-    supports LISA as the space-borne GW detector.
-
-    Sub-models are treated as if the same GW source (such as a GW
-    from a stellar-mass BBH) is observed in different frequency bands
-    by space-borne and ground-based GW detectors; all the parameters
-    are transformed into the same frame at the sub-model level,
-    `HierarchicalModel` is used to get the joint likelihood, and all
-    the extrinsic parameters supported by `RelativeTimeDom` or its
-    variants are marginalized over. Note that the LISA submodel only
-    supports `Relative` for now; for ground-based detectors, please
-    use `RelativeTimeDom` or its variants.
-
-    Although this likelihood model is used for multiband parameter
-    estimation, users can still use it for other purposes, such as
-    GW + EM parameter estimation; in that case, please use
-    `RelativeTimeDom` or its variants for the GW data, while there
-    are no restrictions on the likelihood of the EM data.
+    """This likelihood model can be used for cases when one of the
+    submodels can be marginalized to accelerate the total likelihood.
+    It also allows for further acceleration of the other models
+    during marginalization, if some extrinsic parameters can be
+    tightly constrained by the primary model. For example, in EM + GW
+    parameter estimation the sky localization can be well measured,
+    and for LISA + 3G multiband observation, the (tc, ra, dec) of
+    SOBHB signals can be tightly constrained by the 3G network, so
+    this model is also useful for that case.
     """
     name = 'joint_primary_marginalized'
@@ -640,6 +628,11 @@ def __init__(self, variable_params, submodels, **kwargs):
         self.other_models.pop(kwargs['primary_lbl'][0])
         self.other_models = list(self.other_models.values())

+        # determine whether to accelerate total_loglr
+        from .tools import str_to_bool
+        self.static_margin_params_in_other_models = \
+            str_to_bool(kwargs['static_margin_params_in_other_models'][0])
+
     def write_metadata(self, fp, group=None):
         """Adds metadata to the output files
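For context, the sketch below shows how this flag is expected to travel from the config file into the constructor, mirroring the shlex.split / str_to_bool round trip in this diff. It is a minimal stand-alone approximation, not PyCBC code: the stdlib ConfigParser stands in for PyCBC's config parser, and the local str_to_bool mimics the helper imported from `.tools`.

```python
# Minimal sketch (assumptions: stdlib ConfigParser instead of PyCBC's
# config parser; a local str_to_bool mimicking the .tools helper).
import shlex
from configparser import ConfigParser

def str_to_bool(s):
    # stand-in for the str_to_bool imported from .tools in the diff
    if s == 'True':
        return True
    if s == 'False':
        return False
    raise ValueError("cannot interpret %r as a bool" % s)

cp = ConfigParser()
cp.read_string("""
[model]
name = joint_primary_marginalized
static_margin_params_in_other_models = True
""")

# from_config does: kwargs[...] = shlex.split(cp.get('model', ...))
raw = shlex.split(cp.get('model', 'static_margin_params_in_other_models'))
# __init__ then does: str_to_bool(kwargs[...][0])
flag = str_to_bool(raw[0])
print(flag)  # True
```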
@@ -686,34 +679,43 @@ def total_loglr(self):
         """
         # calculate <d-h|d-h> = <h|h> - 2<h|d> + <d|d> up to a constant

+        # note that for SOBHB signals, ground-based detectors dominate
+        # the SNR and the accuracy of (tc, ra, dec)
         self.primary_model.return_sh_hh = True
         sh_primary, hh_primary = self.primary_model.loglr
         self.primary_model.return_sh_hh = False

-        margin_names_vector = list(
-            self.primary_model.marginalize_vector_params.keys())
-        if 'logw_partial' in margin_names_vector:
-            margin_names_vector.remove('logw_partial')
+        # set loglr, otherwise it will store (sh, hh)
+        setattr(self.primary_model._current_stats, 'loglr',
+                self.primary_model.marginalize_loglr(sh_primary, hh_primary))
+        if isinstance(sh_primary, numpy.ndarray):
+            nums = len(sh_primary)
+        else:
+            nums = 1

         margin_params = {}
-        nums = 1
-        for key, value in self.primary_model.current_params.items():
-            # add marginalize_vector_params
-            if key in margin_names_vector:
-                margin_params[key] = value
-                if isinstance(value, numpy.ndarray):
-                    nums = len(value)
-        # add distance if it has been marginalized; use a numpy array
-        # for it so that it has the same shape as the
-        # marginalize_vector_params (here we assume
-        # self.primary_model.current_params['distance'] is a number)
-        if self.primary_model.distance_marginalization:
-            margin_params['distance'] = numpy.full(
-                nums, self.primary_model.current_params['distance'])
-
-        # add likelihood contribution from space-borne detectors
+        if self.static_margin_params_in_other_models:
+            # Due to the high precision of the extrinsic parameters
+            # constrained by the primary model, the waveform mismatch
+            # in the other models caused by varying those parameters
+            # is pretty small, so we can keep them static to
+            # accelerate total_loglr. Here, we use the matched-filtering
+            # SNR instead of the likelihood, because luminosity distance
+            # and inclination have a very strong degeneracy: a change of
+            # inclination changes the best-match distance and hence the
+            # amplitude of the waveform. Using the SNR cancels out the
+            # effect of the amplitude.
+            i_max_extrinsic = numpy.argmax(
+                numpy.abs(sh_primary) / hh_primary**0.5)
+            for p in self.primary_model.marginalized_params_name:
+                if isinstance(self.primary_model.current_params[p],
+                              numpy.ndarray):
+                    margin_params[p] = \
+                        self.primary_model.current_params[p][i_max_extrinsic]
+                else:
+                    margin_params[p] = self.primary_model.current_params[p]

Review comment: @WuShichao This logic should already take care of the distance case, except that later on you assume that if any parameter is a scalar they all are. That's the part you should stop assuming. Don't assume they are any particular mix of scalar or vector.

+        else:
+            for key, value in self.primary_model.current_params.items():
+                # add marginalize_vector_params
+                if key in self.primary_model.marginalized_params_name:
+                    margin_params[key] = value

+        # add likelihood contribution from other_models; we calculate
+        # sh/hh for each marginalized parameter point
         sh_others = numpy.full(nums, 0 + 0.0j)
         hh_others = numpy.zeros(nums)
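For readers skimming the diff, the quantity being accumulated here is the usual matched-filter log likelihood ratio. A sketch in the notation of the comment above (the real-part/absolute-value handling of the complex sh, and any phase or distance marginalization, happen inside marginalize_loglr rather than here):

```latex
% sh = <d|h>, hh = <h|h>; up to the constant <d|d> term:
\log\mathcal{L}_r
  = -\tfrac{1}{2}\langle d-h \mid d-h\rangle + \tfrac{1}{2}\langle d \mid d\rangle
  = \langle d \mid h\rangle - \tfrac{1}{2}\langle h \mid h\rangle
% and the inner products are additive over submodels/detectors:
\mathrm{sh}_{\mathrm{total}} = \mathrm{sh}_{\mathrm{primary}}
  + \textstyle\sum_{m} \mathrm{sh}_{m}, \qquad
\mathrm{hh}_{\mathrm{total}} = \mathrm{hh}_{\mathrm{primary}}
  + \textstyle\sum_{m} \mathrm{hh}_{m}
```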
@@ -723,24 +725,46 @@ def total_loglr(self):
             # not using self.primary_model.current_params, because
             # others_model may have its own static parameters
             current_params_other = other_model.current_params.copy()
-            for i in range(nums):
-                current_params_other.update(
-                    {key: value[i] if isinstance(value, numpy.ndarray) else
-                     value for key, value in margin_params.items()})
-                other_model.update(**current_params_other)
-                other_model.return_sh_hh = True
-                sh_others[i], hh_others[i] = other_model.loglr
-                other_model.return_sh_hh = False
+            if not self.static_margin_params_in_other_models:
+                for i in range(nums):
+                    current_params_other.update(
+                        {key: value[i] if isinstance(value, numpy.ndarray)
+                         else value for key, value in margin_params.items()})
+                    other_model.update(**current_params_other)
+                    other_model.return_sh_hh = True
+                    sh_other, hh_other = other_model.loglr
+                    sh_others[i] += sh_other
+                    hh_others[i] += hh_other
+                    other_model.return_sh_hh = False
+                    # set loglr, otherwise it will store (sh, hh)
+                    setattr(other_model._current_stats, 'loglr',
+                            other_model.marginalize_loglr(sh_other, hh_other))

Review comment: Do you need to store this? It's not necessarily a problem, but it will slow down the code slightly (maybe not important at the moment). Why not think about why it was being set to a vector (and from where)? Do you even want this stored in the case of a submodel? Maybe the solution was simply not to store this when it's not actually a marginalized loglr anyway, no?

Reply: My understanding is that when [...]

Reply: @WuShichao OK, good. So now the question is what is the right thing to do in this case?

+            else:
+                # use one margin point set to approximate all the others
+                current_params_other.update(
+                    {key: value[0] if isinstance(value, numpy.ndarray)
+                     else value for key, value in margin_params.items()})
+                other_model.update(**current_params_other)
+                other_model.return_sh_hh = True
+                sh_other, hh_other = other_model.loglr
+                other_model.return_sh_hh = False
+                # set loglr, otherwise it will store (sh, hh)
+                setattr(other_model._current_stats, 'loglr',
+                        other_model.marginalize_loglr(sh_other, hh_other))

Review comment: Same as above. It's not necessarily a problem, as it might be useful to have the separate loglrs, but it's not clear that it will always make sense.

+                sh_others += sh_other
+                hh_others += hh_other

+        if nums == 1:
+            # the original sh/hh_others are numpy arrays, which might
+            # not match the type of sh/hh_primary during reconstruction;
+            # when reconstructing distance, sh/hh_others need to be
+            # scalars
+            sh_others = sh_others[0]
+            hh_others = hh_others[0]
         sh_total = sh_primary + sh_others
         hh_total = hh_primary + hh_others

         # calculate marginalize_vector_weights
         self.primary_model.marginalize_vector_weights = \
             - numpy.log(self.primary_model.vsamples)
         loglr = self.primary_model.marginalize_loglr(sh_total, hh_total)

         return loglr

     def others_lognl(self):
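To make the "one margin point set" shortcut concrete, here is a self-contained toy sketch (made-up numbers, not PyCBC code) of what the accelerated branch effectively does: pick the extrinsic sample where the primary model's matched-filter SNR peaks, evaluate the expensive other model once there, and reuse that single (sh, hh) pair for every marginalization sample.

```python
# Toy illustration of the accelerated path; all values are invented.
import numpy

# per-extrinsic-sample inner products from the (cheap) primary model
sh_primary = numpy.array([3 + 1j, 5 + 2j, 4 + 0j])   # <d|h> samples
hh_primary = numpy.array([4.0, 4.0, 4.0])            # <h|h> samples

# pick the sample with the highest matched-filter SNR |sh| / sqrt(hh)
i_max = numpy.argmax(numpy.abs(sh_primary) / hh_primary**0.5)

# evaluate the (expensive) other model only at that one point;
# this pair stands in for other_model.loglr with return_sh_hh = True
sh_other, hh_other = 2 + 1j, 1.5

# reuse the single evaluation for every marginalization sample
nums = len(sh_primary)
sh_others = numpy.zeros(nums, dtype=complex) + sh_other
hh_others = numpy.zeros(nums) + hh_other

sh_total = sh_primary + sh_others
hh_total = hh_primary + hh_others
print(sh_total, hh_total)
```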
@@ -805,6 +829,10 @@ def from_config(cls, cp, **kwargs):
         sparam_map = map_params(hpiter(cp.options('static_params'),
                                        submodel_lbls))

+        # get the acceleration label
+        kwargs['static_margin_params_in_other_models'] = shlex.split(
+            cp.get('model', 'static_margin_params_in_other_models'))
+
         # we'll need any waveform transforms for the initializing
         # sub-models, as the underlying models will receive the output
         # of those transforms
@@ -856,32 +884,28 @@ def from_config(cls, cp, **kwargs):
                           cp.get('static_params', param.fullname))

-            # set the variable params: different from the standard
-            # hierarchical model, in this multiband model, all
-            # sub-models have the same variable parameters, so we
-            # don't need to worry about the unique variable issue.
-            # Besides, the primary model needs to do marginalization,
-            # so we must set variable_params and the prior section
-            # before initializing it.
+            # set the variable params: different from the standard
+            # hierarchical model, in this JointPrimaryMarginalizedModel
+            # model, all sub-models have the same variable parameters,
+            # so we don't need to worry about the unique variable
+            # issue. Besides, the primary model needs to do
+            # marginalization, so we must set variable_params and the
+            # prior section before initializing it.

             subcp.add_section('variable_params')
             for param in vparam_map[lbl]:
                 if lbl in kwargs['primary_lbl']:
                     # set variable_params for the primary model
                     subcp.set('variable_params', param.subname,
                               cp.get('variable_params', param.fullname))
                 else:
                     # all variable_params in other models will come
                     # from the primary model during sampling
                     subcp.set('static_params', param.subname, 'REPLACE')

             for section in cp.sections():
                 # the primary model needs the prior of the marginalized
                 # parameters
                 if 'prior-' in section and lbl in kwargs['primary_lbl']:
                     prior_section = '%s' % section
                     subcp[prior_section] = cp[prior_section]
                 # if `waveform_transforms` has a prefix,
                 # add it into the sub-model's config
                 elif '%s_waveform_transforms' % lbl in section:
                     transforms_section = '%s' % section
                     subcp[transforms_section] = cp[transforms_section]
                 else:
                     pass

             # similar to the standard hierarchical model,
             # add the outputs from the waveform transforms if sub-model
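A side note on the param.fullname / param.subname pair used above (and the '%s__%s' formatting later in this diff): top-level parameter names are namespaced by submodel label. A toy illustration, with 'lisa' and 'tc' as hypothetical names and the split('__') being illustrative rather than PyCBC's actual implementation:

```python
# Hypothetical names; illustrates the label__param namespacing only.
label, param = 'lisa', 'tc'
fullname = '%s__%s' % (label, param)    # 'lisa__tc' in the top-level config
subname = fullname.split('__', 1)[1]    # 'tc' as seen inside the submodel
assert (fullname, subname) == ('lisa__tc', 'tc')
```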
@@ -918,15 +942,8 @@ def from_config(cls, cp, **kwargs):
         # here we ignore `coa_phase`, because if it's been marginalized,
         # it will not be listed in the `variable_params` and `prior`
         # sections
         primary_model = submodels[kwargs['primary_lbl'][0]]
-        marginalized_params = primary_model.marginalize_vector_params.copy()
-        if 'logw_partial' in marginalized_params:
-            marginalized_params.pop('logw_partial')
-            marginalized_params = list(marginalized_params.keys())
-        else:
-            marginalized_params = []
-        # this may also include 'f_ref', 'f_lower', 'approximant',
-        # but that doesn't matter
-        marginalized_params += list(primary_model.static_params.keys())
+        marginalized_params = primary_model.marginalized_params_name.copy()

         for p in primary_model.static_params.keys():
             p_full = '%s__%s' % (kwargs['primary_lbl'][0], p)
             if p_full not in cp['static_params']:
@@ -940,6 +957,10 @@ def from_config(cls, cp, **kwargs):
                 cp['variable_params'].pop(p)
                 cp.pop(section)

+        # save the virtual config file to disk for later checking
+        with open('internal_top.ini', 'w', encoding='utf-8') as file:
+            cp.write(file)
+
         # now load the model
         logging.info("Loading joint_primary_marginalized model")
         return super(HierarchicalModel, cls).from_config(
@@ -956,6 +977,12 @@ def reconstruct(self, rec=None, seed=None):
             rec = {}

         def get_loglr():
+            # make sure waveform transforms have been applied in
+            # the top-level model
+            if self.waveform_transforms is not None:
+                self._current_params = transforms.apply_transforms(
+                    self._current_params, self.waveform_transforms,
+                    inverse=False)
             self.update_all_models(**rec)
             return self.total_loglr()
@@ -596,10 +596,11 @@ def _loglr(self):
             filt += filter_i
             norm += norm_i

+        loglr = self.marginalize_loglr(filt, norm)
         if self.return_sh_hh:

Review comment: Why do we need this if statement? Shouldn't the existing flag already used for demarginalization take care of this? E.g., why not use the reconstruct_phase flag? https://github.com/gwastro/pycbc/blob/master/pycbc/inference/models/tools.py#L241

Reply: @ahnitz This is used here: https://github.com/WuShichao/pycbc/blob/accelerate_multiband/pycbc/inference/models/hierarchical.py#L752 We need to get LISA's sh and hh.

             results = (filt, norm)
         else:
-            results = self.marginalize_loglr(filt, norm)
+            results = loglr
         return results

     def write_metadata(self, fp, group=None):
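To see how the two return modes interact, here is a toy sketch (not the real relative-binning class; marginalize_loglr is replaced by a stand-in) of the pattern this hunk establishes: compute the marginalized loglr unconditionally, but hand back the raw inner products when a parent model such as JointPrimaryMarginalizedModel wants to combine submodels coherently.

```python
class ToyRelbinLike:
    """Toy sketch of the return_sh_hh switch; not PyCBC's model."""
    return_sh_hh = False

    def marginalize_loglr(self, filt, norm):
        # stand-in for the real marginalization over extrinsic params
        return filt.real - 0.5 * norm

    def _loglr(self, filt, norm):
        # the marginalized loglr is always computed (so the current
        # stats stay meaningful), but raw <d|h>, <h|h> are returned
        # when a parent model combines submodels coherently
        loglr = self.marginalize_loglr(filt, norm)
        if self.return_sh_hh:
            return (filt, norm)
        return loglr

m = ToyRelbinLike()
print(m._loglr(3 + 1j, 4.0))   # 1.0 : marginalized log likelihood ratio
m.return_sh_hh = True
print(m._loglr(3 + 1j, 4.0))   # ((3+1j), 4.0) : raw inner products
```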

Review comment: This still has the old name in the config file. Also, why not just do
`self.static_margin_params = 'static_margin_params' in kwargs`

Reply: @ahnitz OK, I have updated.