Skip to content

Commit

Permalink
emcee solution plotting: support for passing c for trace/lnprobability
Browse files Browse the repository at this point in the history
* if style='lnprobability', c can be a twig pointing to any of the sampled parameters (whether within adopt_parameters or not)
* if style='trace', c can be a twig or 'lnprobability'
* NOTE: adding color to lineplots with many walkers/iterations can become expensive
  • Loading branch information
kecnry committed May 11, 2021
1 parent 05dbb0b commit f313775
Showing 1 changed file with 63 additions and 6 deletions.
69 changes: 63 additions & 6 deletions phoebe/parameters/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@
'vw': 'vws', 'nu': 'nus', 'nv': 'nvs', 'nw': 'nws',
'cosbeta': 'cosbetas', 'logg': 'loggs', 'teff': 'teffs',
'r': 'rs', 'rproj': 'rprojs', 'mu': 'mus',
'visibility': 'visibilities'}
'visibility': 'visibilities',
'lnprobability': 'lnprobabilities',
'lnlikelihood': 'lnlikelihoods'}
_plural_to_singular = {v:k for k,v in _singular_to_plural.items()}

def _singular_to_plural_get(k):
Expand Down Expand Up @@ -4291,7 +4293,10 @@ def _kwargs_fill_dimension(kwargs, direction, ps):

psff = psf.filter(twig=current_value, **_skip_filter_checks)
if len(psff)==1:
array_value = psff.get_quantity(**_skip_filter_checks)
if hasattr(psff.get_parameter(**_skip_filter_checks), 'get_quantity'):
array_value = psff.get_quantity(**_skip_filter_checks)
else:
array_value = psff.get_value(**_skip_filter_checks)
elif len(psff.times) > 1 and psff.get_value(time=psff.times[0], **_skip_filter_checks):
# then we'll assume we have something like volume vs times. If not, then there may be a length mismatch issue later
unit = psff.get_quantity(time=psff.times[0], **_skip_filter_checks).unit
Expand Down Expand Up @@ -4947,18 +4952,43 @@ def _phase_wrap(phase):
lnprobabilities_proc = _deepcopy(lnprobabilities_proc)
lnprobabilities_proc[lnprobabilities_proc < lnprob_cutoff] = np.nan

for lnp in lnprobabilities_proc.T:
c = kwargs.get('c', None)
if c is not None:
fitted_uniqueids = self._bundle.get_value(qualifier='fitted_uniqueids', context='solution', solution=ps.solution, **_skip_filter_checks)
fitted_ps = self._bundle.filter(uniqueid=list(fitted_uniqueids), **_skip_filter_checks)
_, samples_proc_all = _helpers.process_mcmc_chains(lnprobabilities, samples, burnin, thin, -np.inf, flatten=False)

kwargs_orig = _deepcopy(kwargs)
for walker_ind, lnp in enumerate(lnprobabilities_proc.T):
if not np.any(np.isfinite(lnp)):
continue

if np.all(np.isnan(lnp)):
continue

kwargs = _deepcopy(kwargs)
kwargs = _deepcopy(kwargs_orig)
kwargs['x'] = np.arange(len(lnp), dtype=float)*thin+burnin
kwargs['xlabel'] = 'iteration (burnin={}, thin={})'.format(burnin, thin)
kwargs['y'] = lnp
kwargs['ylabel'] = 'lnprobability' if lnprob_cutoff==-np.inf else 'lnprobability (lnprob_cutoff={})'.format(lnprob_cutoff)

if c is None:
pass
elif len(fitted_ps.filter(twig=c, **_skip_filter_checks).to_list()):
match_params = fitted_ps.filter(twig=c, **_skip_filter_checks)
if len(match_params) > 1:
raise ValueError("c={} matches more than one valid parameter ({})".format(c, match_params.twigs))
match_param = match_params.get_parameter()
# TODO: allow to plot from outside adopt_parameters?
match_ind = list(fitted_uniqueids).index(match_param.uniqueid)
kwargs['c'] = samples_proc_all[:, walker_ind, match_ind]
kwargs['clabel'] = match_param.twig
kwargs['cqualifier'] = match_param.qualifier
else:
# assume named color?
kwargs['c'] = c

# TODO: support for c = twig
return_ += [kwargs]

elif style in ['trace', 'walks']:
Expand All @@ -4972,7 +5002,7 @@ def _phase_wrap(phase):
fitted_uniqueids = self._bundle.get_value(qualifier='fitted_uniqueids', context='solution', solution=ps.solution, **_skip_filter_checks)
# fitted_twigs = self._bundle.get_value(qualifier='fitted_twigs', context='solution', solution=ps.solution, **_skip_filter_checks)
fitted_units = self._bundle.get_value(qualifier='fitted_units', context='solution', solution=ps.solution, **_skip_filter_checks)
fitted_ps = self._bundle.filter(uniqueid=list(adopt_uniqueids), **_skip_filter_checks)
fitted_ps = self._bundle.filter(uniqueid=list(fitted_uniqueids), **_skip_filter_checks)
lnprobabilities_proc, samples_proc = _helpers.process_mcmc_chains(lnprobabilities, samples, burnin, thin, lnprob_cutoff, adopt_inds, flatten=False)

# samples [niters, nwalkers, parameter]
Expand Down Expand Up @@ -5004,13 +5034,38 @@ def _uniqueids_for_y(fitted_ps, twig=None):
else:
plot_uniqueids = adopt_uniqueids

c = kwargs.get('c', None)
if c is not None:
_, samples_proc_all = _helpers.process_mcmc_chains(lnprobabilities, samples, burnin, thin, lnprob_cutoff, flatten=False)
kwargs_orig = _deepcopy(kwargs)
for plot_uniqueid in plot_uniqueids:
parameter_ind = list(adopt_uniqueids).index(plot_uniqueid)
_, index = _extract_index_from_string(plot_uniqueid)
yparam = fitted_ps.get_parameter(uniqueid=plot_uniqueid, **_skip_filter_checks)

for walker_ind in range(samples_proc.shape[1]):
kwargs = _deepcopy(kwargs)
kwargs = _deepcopy(kwargs_orig)

if c is None:
pass
elif c == 'lnprobabilities':
# we only need to get this once and can re-use it per-parameter/walker
kwargs['c'] = lnprobabilities_proc[:, walker_ind]
kwargs['clabel'] = _plural_to_singular_get(c)
kwargs['cqualifier'] = c
elif len(fitted_ps.filter(twig=c, **_skip_filter_checks).to_list()):
match_params = fitted_ps.filter(twig=c, **_skip_filter_checks)
if len(match_params) > 1:
raise ValueError("c={} matches more than one valid parameter ({})".format(c, match_params.twigs))
match_param = match_params.get_parameter()
# TODO: allow to plot from outside adopt_parameters?
match_ind = list(fitted_uniqueids).index(match_param.uniqueid)
kwargs['c'] = samples_proc_all[:, walker_ind, match_ind]
kwargs['clabel'] = match_param.twig
kwargs['cqualifier'] = match_param.qualifier
else:
# assume named color?
kwargs['c'] = c

# this needs to be the unflattened version
samples_y = samples_proc[:, walker_ind, parameter_ind]
Expand All @@ -5022,6 +5077,8 @@ def _uniqueids_for_y(fitted_ps, twig=None):
kwargs['ylabel'] = _corner_label(yparam, index=index)
# TODO: use fitted_units instead?
kwargs['yunit'] = fitted_units[parameter_ind]

# TODO: support for c = twig/lnprobabilities
return_ += [kwargs]
else:
raise NotImplementedError()
Expand Down

0 comments on commit f313775

Please sign in to comment.