From 5774fb1ffa2380aa5d9885b151bd045770469eba Mon Sep 17 00:00:00 2001
From: Shichao Wu
Date: Mon, 9 Sep 2024 18:14:56 +0200
Subject: [PATCH] add acceleration option to JointPrimaryMarginalizedModel
 likelihood (#4688)

* Update hierarchical.py
* Update hierarchical.py
* Update hierarchical.py
* Update hierarchical.py
* Update hierarchical.py
* fix cc issues
* Update hierarchical.py
* Update relbin.py
* add complex phase correction for sh_others
* Update hierarchical.py
* Update relbin.py
* fix cc issues
* make code more general
* update
* fix
* rename
* update
* WIP
* fix a bug in frame transform
* fix overwritten issues
* update
* update
* fix reconstruct
* make this PR general
* update
* update
* fix cc issues
* rename
* rename
* add multiband description
* fix
* add comments
* fix hdf's config
* fix
* fix
* fix
* fix
* remove print
* update for Alex's comments
* wip
* update
* fix
* update
* seems work
* fix CC issue
* fix
* fix demargin
---
 bin/inference/pycbc_inference                 |   8 +-
 pycbc/inference/models/hierarchical.py        | 169 ++++++++++--------
 .../models/marginalized_gaussian_noise.py     |   9 +-
 pycbc/inference/models/relbin.py              |   3 +-
 pycbc/inference/models/tools.py               |   8 +
 5 files changed, 122 insertions(+), 75 deletions(-)

diff --git a/bin/inference/pycbc_inference b/bin/inference/pycbc_inference
index 701663a3550..442172eb967 100755
--- a/bin/inference/pycbc_inference
+++ b/bin/inference/pycbc_inference
@@ -108,7 +108,11 @@ fft.from_cli(opts)
 with ctx:
     # read configuration file
-    cp = configuration.WorkflowConfigParser.from_cli(opts)
+    cp_original = configuration.WorkflowConfigParser.from_cli(opts)
+    # some models (such as joint_primary_marginalized) will internally
+    # modify the original cp for sampling, so we save the original
+    # and let the model modify a copy
+    cp = cp_original.__deepcopy__(cp_original)

     # create an empty checkpoint file, if needed
     condor_ckpt = cp.has_option('sampler', 'checkpoint-signal')
@@ -138,7 +142,7 @@ with ctx:
     if pool.is_main_process():
         for fn in [sampler.checkpoint_file, sampler.backup_file]:
             with loadfile(fn, 'a') as fp:
-                fp.write_config_file(cp)
+                fp.write_config_file(cp_original)

     # Run the sampler
     sampler.run()
diff --git a/pycbc/inference/models/hierarchical.py b/pycbc/inference/models/hierarchical.py
index 81db28fa79f..d0507efa359 100644
--- a/pycbc/inference/models/hierarchical.py
+++ b/pycbc/inference/models/hierarchical.py
@@ -607,26 +607,14 @@ def _loglikelihood(self):

 class JointPrimaryMarginalizedModel(HierarchicalModel):
-    """ Hierarchical heterodyne likelihood for coherent multiband
-    parameter estimation which combines data from space-borne and
-    ground-based GW detectors coherently. Currently, this only
-    supports LISA as the space-borne GW detector.
-
-    Sub models are treated as if the same GW source (such as a GW
-    from stellar-mass BBH) is observed in different frequency bands by
-    space-borne and ground-based GW detectors, then transform all
-    the parameters into the same frame in the sub model level, use
-    `HierarchicalModel` to get the joint likelihood, and marginalize
-    over all the extrinsic parameters supported by `RelativeTimeDom`
-    or its variants. Note that LISA submodel only supports the `Relative`
-    for now, for ground-based detectors, please use `RelativeTimeDom`
-    or its variants.
-
-    Although this likelihood model is used for multiband parameter
-    estimation, users can still use it for other purposes, such as
-    GW + EM parameter estimation, in this case, please use `RelativeTimeDom`
-    or its variants for the GW data, for the likelihood of EM data,
-    there is no restrictions.
+    """This likelihood model can be used for cases in which one of the
+    submodels can be marginalized to accelerate the total likelihood. It
+    also allows further acceleration of the other models during
+    marginalization, if some extrinsic parameters can be tightly constrained
+    by the primary model. For example, in joint GW + EM parameter
+    estimation the sky localization is well measured, and for LISA + 3G
+    multiband observation the (tc, ra, dec) of SOBHB signals are tightly
+    constrained by the 3G network, so this model is useful in both cases.
     """
     name = 'joint_primary_marginalized'

@@ -640,6 +628,11 @@ def __init__(self, variable_params, submodels, **kwargs):
         self.other_models.pop(kwargs['primary_lbl'][0])
         self.other_models = list(self.other_models.values())

+        # determine whether to accelerate total_loglr
+        from .tools import str_to_bool
+        self.static_margin_params_in_other_models = \
+            str_to_bool(kwargs['static_margin_params_in_other_models'][0])
+
     def write_metadata(self, fp, group=None):
         """Adds metadata to the output files

@@ -686,34 +679,43 @@ def total_loglr(self):
         """
         # calculate <d-h|d-h> = <h|h> - 2<d|h> + <d|d> up to a constant
-        # note that for SOBHB signals, ground-based detectors dominant SNR
-        # and accuracy of (tc, ra, dec)
         self.primary_model.return_sh_hh = True
         sh_primary, hh_primary = self.primary_model.loglr
         self.primary_model.return_sh_hh = False
-
-        margin_names_vector = list(
-            self.primary_model.marginalize_vector_params.keys())
-        if 'logw_partial' in margin_names_vector:
-            margin_names_vector.remove('logw_partial')
+        # set loglr, otherwise it will store (sh, hh)
+        setattr(self.primary_model._current_stats, 'loglr',
+                self.primary_model.marginalize_loglr(sh_primary, hh_primary))
+        if isinstance(sh_primary, numpy.ndarray):
+            nums = len(sh_primary)
+        else:
+            nums = 1

         margin_params = {}
-        nums = 1
-        for key, value in self.primary_model.current_params.items():
-            # add marginalize_vector_params
-            if key in margin_names_vector:
-                margin_params[key] = value
-                if isinstance(value, numpy.ndarray):
-                    nums = len(value)
-        # add distance if it has been marginalized,
-        # use numpy array for it is just let it has the same
-        # shape as marginalize_vector_params, here we assume
-        # self.primary_model.current_params['distance'] is a number
-        if self.primary_model.distance_marginalization:
-            margin_params['distance'] = numpy.full(
-                nums, self.primary_model.current_params['distance'])
-
-        # add likelihood contribution from space-borne detectors, we
+        if self.static_margin_params_in_other_models:
+            # Because the primary model constrains the extrinsic parameters
+            # to high precision, the waveform mismatch in the other models
+            # from varying those parameters is very small, so we can keep
+            # them static to accelerate total_loglr. Here we use the
+            # matched-filter SNR instead of the likelihood: luminosity
+            # distance and inclination are strongly degenerate, and a change
+            # of inclination changes the best-match distance and thus the
+            # amplitude of the waveform; using the SNR cancels this out.
+            i_max_extrinsic = numpy.argmax(
+                numpy.abs(sh_primary) / hh_primary**0.5)
+            for p in self.primary_model.marginalized_params_name:
+                if isinstance(self.primary_model.current_params[p],
+                              numpy.ndarray):
+                    margin_params[p] = \
+                        self.primary_model.current_params[p][i_max_extrinsic]
+                else:
+                    margin_params[p] = self.primary_model.current_params[p]
+        else:
+            for key, value in self.primary_model.current_params.items():
+                # add marginalize_vector_params
+                if key in self.primary_model.marginalized_params_name:
+                    margin_params[key] = value
+
+        # add likelihood contribution from other_models, we
         # calculate sh/hh for each marginalized parameter point
         sh_others = numpy.full(nums, 0 + 0.0j)
         hh_others = numpy.zeros(nums)
@@ -723,24 +725,46 @@ def total_loglr(self):
             # not using self.primary_model.current_params, because others_model
             # may have its own static parameters
             current_params_other = other_model.current_params.copy()
-            for i in range(nums):
+            if not self.static_margin_params_in_other_models:
+                for i in range(nums):
+                    current_params_other.update(
+                        {key: value[i] if isinstance(value, numpy.ndarray)
+                         else value for key, value in margin_params.items()})
+                    other_model.update(**current_params_other)
+                    other_model.return_sh_hh = True
+                    sh_other, hh_other = other_model.loglr
+                    sh_others[i] += sh_other
+                    hh_others[i] += hh_other
+                    other_model.return_sh_hh = False
+                # set loglr, otherwise it will store (sh, hh)
+                setattr(other_model._current_stats, 'loglr',
+                        other_model.marginalize_loglr(sh_other, hh_other))
+            else:
+                # use one marginalization point to approximate all the others
                 current_params_other.update(
-                    {key: value[i] if isinstance(value, numpy.ndarray) else
-                     value for key, value in margin_params.items()})
+                    {key: value[0] if isinstance(value, numpy.ndarray)
+                     else value for key, value in margin_params.items()})
                 other_model.update(**current_params_other)
                 other_model.return_sh_hh = True
-                sh_others[i], hh_others[i] = other_model.loglr
+                sh_other, hh_other = other_model.loglr
                 other_model.return_sh_hh = False
+                # set loglr, otherwise it will store (sh, hh)
+                setattr(other_model._current_stats, 'loglr',
+                        other_model.marginalize_loglr(sh_other, hh_other))
+                sh_others += sh_other
+                hh_others += hh_other

         if nums == 1:
+            # sh/hh_others are numpy arrays, whose type may not match
+            # sh/hh_primary during reconstruction; when reconstructing
+            # distance, sh/hh_others need to be scalars
             sh_others = sh_others[0]
+            hh_others = hh_others[0]
         sh_total = sh_primary + sh_others
         hh_total = hh_primary + hh_others

-        # calculate marginalize_vector_weights
-        self.primary_model.marginalize_vector_weights = \
-            - numpy.log(self.primary_model.vsamples)
         loglr = self.primary_model.marginalize_loglr(sh_total, hh_total)
+
         return loglr

     def others_lognl(self):
@@ -805,6 +829,10 @@ def from_config(cls, cp, **kwargs):
         sparam_map = map_params(hpiter(cp.options('static_params'),
                                        submodel_lbls))

+        # get the acceleration flag
+        kwargs['static_margin_params_in_other_models'] = shlex.split(
+            cp.get('model', 'static_margin_params_in_other_models'))
+
         # we'll need any waveform transforms for the initializing sub-models,
         # as the underlying models will receive the output of those transforms

@@ -856,18 +884,21 @@ def from_config(cls, cp, **kwargs):
                               cp.get('static_params', param.fullname))

         # set the variable params: different from the standard
-        # hierarchical model, in this multiband model, all sub-models
-        # has the same variable parameters, so we don't need to worry
- # about the unique variable issue. Besides, the primary model - # needs to do marginalization, so we must set variable_params - # and prior section before initializing it. + # hierarchical model, in this JointPrimaryMarginalizedModel model, + # all sub-models has the same variable parameters, so we don't + # need to worry about the unique variable issue. Besides, + # the primary model needs to do marginalization, so we must set + # variable_params and prior section before initializing it. subcp.add_section('variable_params') for param in vparam_map[lbl]: if lbl in kwargs['primary_lbl']: + # set variable_params for the primary model subcp.set('variable_params', param.subname, cp.get('variable_params', param.fullname)) else: + # all variable_params in other models will come + # from the primary model during sampling subcp.set('static_params', param.subname, 'REPLACE') for section in cp.sections(): @@ -875,13 +906,6 @@ def from_config(cls, cp, **kwargs): if 'prior-' in section and lbl in kwargs['primary_lbl']: prior_section = '%s' % section subcp[prior_section] = cp[prior_section] - # if `waveform_transforms` has a prefix, - # add it into sub-models' config - elif '%s_waveform_transforms' % lbl in section: - transforms_section = '%s' % section - subcp[transforms_section] = cp[transforms_section] - else: - pass # similar to the standard hierarchical model, # add the outputs from the waveform transforms if sub-model @@ -918,15 +942,8 @@ def from_config(cls, cp, **kwargs): # here we ignore `coa_phase`, because if it's been marginalized, # it will not be listed in `variable_params` and `prior` sections primary_model = submodels[kwargs['primary_lbl'][0]] - marginalized_params = primary_model.marginalize_vector_params.copy() - if 'logw_partial' in marginalized_params: - marginalized_params.pop('logw_partial') - marginalized_params = list(marginalized_params.keys()) - else: - marginalized_params = [] - # this may also include 'f_ref', 'f_lower', 'approximant', - # but doesn't matter - marginalized_params += list(primary_model.static_params.keys()) + marginalized_params = primary_model.marginalized_params_name.copy() + for p in primary_model.static_params.keys(): p_full = '%s__%s' % (kwargs['primary_lbl'][0], p) if p_full not in cp['static_params']: @@ -940,6 +957,10 @@ def from_config(cls, cp, **kwargs): cp['variable_params'].pop(p) cp.pop(section) + # save the vitual config file to disk for later check + with open('internal_top.ini', 'w', encoding='utf-8') as file: + cp.write(file) + # now load the model logging.info("Loading joint_primary_marginalized model") return super(HierarchicalModel, cls).from_config( @@ -956,6 +977,12 @@ def reconstruct(self, rec=None, seed=None): rec = {} def get_loglr(): + # make sure waveform transforms have been applied in + # the top-level model + if self.waveform_transforms is not None: + self._current_params = transforms.apply_transforms( + self._current_params, self.waveform_transforms, + inverse=False) self.update_all_models(**rec) return self.total_loglr() diff --git a/pycbc/inference/models/marginalized_gaussian_noise.py b/pycbc/inference/models/marginalized_gaussian_noise.py index 9052a5018ed..05a402aa8cf 100644 --- a/pycbc/inference/models/marginalized_gaussian_noise.py +++ b/pycbc/inference/models/marginalized_gaussian_noise.py @@ -211,6 +211,8 @@ def __init__(self, variable_params, sample_rate=None, **kwargs): + # the flag used in `_loglr` + self.return_sh_hh = False self.sample_rate = float(sample_rate) self.kwargs = kwargs variable_params, kwargs = 
@@ -259,6 +261,7 @@ def _nowaveform_loglr(self):

     def _loglr(self):
         r"""Computes the log likelihood ratio,
+        or the inner products <s|h> and <h|h> if `self.return_sh_hh` is True.

         .. math::
@@ -369,7 +372,11 @@ def _loglr(self):
             hh_total += hh

         loglr = self.marginalize_loglr(sh_total, hh_total)
-        return loglr
+        if self.return_sh_hh:
+            results = (sh_total, hh_total)
+        else:
+            results = loglr
+        return results


 class MarginalizedPolarization(DistMarg, BaseGaussianNoise):
diff --git a/pycbc/inference/models/relbin.py b/pycbc/inference/models/relbin.py
index 4f574992fa3..8c6be79a1ec 100644
--- a/pycbc/inference/models/relbin.py
+++ b/pycbc/inference/models/relbin.py
@@ -596,10 +596,11 @@ def _loglr(self):
             filt += filter_i
             norm += norm_i

+        loglr = self.marginalize_loglr(filt, norm)
         if self.return_sh_hh:
             results = (filt, norm)
         else:
-            results = self.marginalize_loglr(filt, norm)
+            results = loglr
         return results

     def write_metadata(self, fp, group=None):
diff --git a/pycbc/inference/models/tools.py b/pycbc/inference/models/tools.py
index 044c716440f..325a8913658 100644
--- a/pycbc/inference/models/tools.py
+++ b/pycbc/inference/models/tools.py
@@ -228,6 +228,14 @@ def pop_prior(param):
             self.distance_interpolator = i

         kwargs['static_params']['distance'] = dist_ref
+
+        # save the marginalized parameters' names in one place;
+        # coa_phase will be a static param if it has been marginalized
+        if marginalize_distance:
+            self.marginalized_params_name =\
+                list(self.marginalize_vector_params.keys()) +\
+                [marginalize_distance_param]
+
         return variable_params, kwargs

     def reset_vector_params(self):
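--
Editor's usage note (appendix, not part of the patch): the new acceleration
option is read from the [model] section via
cp.get('model', 'static_margin_params_in_other_models') and parsed with
str_to_bool. A minimal config sketch follows; the `submodels` option follows
the existing HierarchicalModel convention, and the `primary_model` option
name and the sub-model labels here are assumptions for illustration. Only
`name` and `static_margin_params_in_other_models` are confirmed by this
patch.

    [model]
    name = joint_primary_marginalized
    ; sub-model labels (HierarchicalModel convention)
    submodels = gw em
    ; assumed option selecting the sub-model whose marginalization is reused
    primary_model = gw
    ; new flag added by this patch: if True, hold the marginalized
    ; (extrinsic) parameters static in the other sub-models, evaluated at
    ; the single best-SNR point chosen by the primary model
    static_margin_params_in_other_models = True

The sh/hh plumbing the patch adds to the relbin and marginalized Gaussian
noise models can be driven as sketched below (assumed calling pattern,
mirroring what total_loglr does internally):

    model.return_sh_hh = True   # ask _loglr for the inner products
    sh, hh = model.loglr        # returns (<s|h>, <h|h>) instead of loglr
    model.return_sh_hh = False
    loglr = model.marginalize_loglr(sh, hh)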