Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convenient function to access inference methods and kwargs #795

Merged
merged 9 commits into from
Apr 15, 2024

Conversation

GStechschulte
Copy link
Collaborator

@GStechschulte GStechschulte commented Apr 6, 2024

Closes #791. This PR adds a convenient class InferenceMethods that allows users to access the available inference methods and kwargs.

For example, the inference methods

bmb.inference_methods.names
{'pymc': {'mcmc': ['mcmc'], 'vi': ['vi']},
 'bayeux': {'mcmc': ['tfp_hmc',
   'tfp_nuts',
   'tfp_snaper_hmc',
   'blackjax_hmc',
   'blackjax_chees_hmc',
   'blackjax_meads_hmc',
   'blackjax_nuts',
   'blackjax_hmc_pathfinder',
   'blackjax_nuts_pathfinder',
   'flowmc_rqspline_hmc',
   'flowmc_rqspline_mala',
   'flowmc_realnvp_hmc',
   'flowmc_realnvp_mala',
   'numpyro_hmc',
   'numpyro_nuts']}}

and the default kwargs for a given inference method

bmb.inference_methods.get_kwargs("tfp_nuts")
{'extra_parameters': {'num_draws': 1000,
  'num_chains': 8,
  'num_adaptation_steps': 500,
  'return_pytree': False},
 'dual_averaging_kwargs': {'target_accept_prob': 0.8,
  'exploration_shrinkage': 0.05,
  'shrinkage_target': None,
  'step_count_smoothing': 10,
  'decay_rate': 0.75,
  'step_size_setter_fn': <function tensorflow_probability.substrates.jax.mcmc.simple_step_size_adaptation.hmc_like_step_size_setter_fn(kernel_results, new_step_size)>,
  'step_size_getter_fn': <function tensorflow_probability.substrates.jax.mcmc.simple_step_size_adaptation.hmc_like_step_size_getter_fn(kernel_results)>,
  'log_accept_prob_getter_fn': <function tensorflow_probability.substrates.jax.mcmc.simple_step_size_adaptation.hmc_like_log_accept_prob_getter_fn(kernel_results)>,
  'reduce_fn': <function tensorflow_probability.substrates.jax.math.generic.reduce_log_harmonic_mean_exp(input_tensor, axis=None, keepdims=False, experimental_named_axis=None, experimental_allow_all_gather=False, name=None)>,
  'experimental_reduce_chain_axis_names': None,
  'validate_args': False,
  'name': None,
  'num_adaptation_steps': 500},
 'proposal_kernel_kwargs': {'max_tree_depth': 10,
  'max_energy_diff': 1000.0,
  'unrolled_leapfrog_steps': 1,
  'parallel_iterations': 10,
  'experimental_shard_axis_names': None,
  'name': None,
  'step_size': 0.5}}

Additionally, this convenience class is now imported and used in backend/pymc.py to obtain the bayeux and pymc inference methods. I have updated relevant doc strings and the alternative samplers notebook as well.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@ColCarroll
Copy link
Collaborator

Note https://github.com/jax-ml/bayeux/blob/main/bayeux/_src/shared.py#L115 is what I use in bayeux, though I do a fair amount of manual work cleaning things up, or removing arguments that are supplied elsewhere.

If you run the slightly modified

def get_default_signature(fn):
  defaults = {}
  for key, val in inspect.signature(fn).parameters.items():
    if val.default is not inspect.Signature.empty:
      defaults[key] = val.default
  return defaults

on pm.sample, you get the pleasant

{'draws': 1000,
 'tune': 1000,
 'chains': None,
 'cores': None,
 'random_seed': None,
 'progressbar': True,
 'step': None,
 'nuts_sampler': 'pymc',
 'initvals': None,
 'init': 'auto',
 'jitter_max_retries': 10,
 'n_init': 200000,
 'trace': None,
 'discard_tuned_samples': True,
 'compute_convergence_checks': True,
 'keep_warning_stat': False,
 'return_inferencedata': True,
 'idata_kwargs': None,
 'nuts_sampler_kwargs': None,
 'callback': None,
 'mp_ctx': None,
 'model': None}

@tomicapretto
Copy link
Collaborator

As far as I know the signature for pm.sample() has arguments for many different things. Maybe we can hard-code the subset of parameters we want to query from it and only report those?

@GStechschulte GStechschulte marked this pull request as ready for review April 13, 2024 06:04
@GStechschulte
Copy link
Collaborator Author

GStechschulte commented Apr 14, 2024

I am not so sure tests should be added for this? The .get_kwargs method already raises an error if the user passes an inference method that is not in the list of available methods.

Then, for bmb.inference_methods.name I suppose a test could be added to assert specific key names (mcmc, vi) exist in the dict?

@tomicapretto
Copy link
Collaborator

@GStechschulte I see what you mean. I don't have a strong opinion here. The only thing I can add is that if we leave it untested it'll decrease coverage. I know high coverage does not mean our test suite is perfect, but I do think that in general lower coverage is worse. We could omit the inference_methods.py module from coverage but I'm not sure if it is a good thing or not.

Another option would be to merge as it is and open an issue so someone tests this in the future (as it's not critical).

Copy link
Collaborator

@ColCarroll ColCarroll left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i mostly reviewed just the bayeux code, which LGTM!

@GStechschulte
Copy link
Collaborator Author

Many thanks for the review @ColCarroll and @tomicapretto

I added a small test to check that the keys (mcmc, vi) exist when calling bmb.inference_methods.names. As well as a test to ensure that a ValueError is raised if a user passes an unsupported inference method name.

@GStechschulte GStechschulte merged commit b5aefcf into bambinos:main Apr 15, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add convenient function to access list of inference methods
3 participants