Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pathfinder w pytensor symbolic #387

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

aphc14
Copy link

@aphc14 aphc14 commented Oct 31, 2024

Another version to draft PR #386 which uses more of PyTensor's symbolic variables and compiling functions.

Questions for Review

  1. Which implementations should I continue for future improvements?
  2. Are there additional PyTensor optimisations we could leverage?

`fit_pathfinder`
- Edited `fit_pathfinder` to produce `pathfinder_state`, `pathfinder_info`, `pathfinder_samples` and `pathfinder_idata` for closer examination of the outputs.
- Changed the `num_samples` argument name to `num_draws` to avoid `TypeError` got multiple values for keyword argument 'num_samples'.
- Initial points are automatically set to jitter as jitter is required for pathfinder.

Extras
- New function 'get_jaxified_logp_ravel_inputs' to simplify previous code structure in fit_pathfinder.

Tests
- Added extra test for pathfinder to test pathfinder_info variables and pathfinder_idata  are consistent for a given random seed.
Add a new PyMC-based implementation of Pathfinder VI that uses PyTensor operations which provides support for both PyMC and BlackJAX backends in fit_pathfinder.
Add a PyMC/PyTensor implementation of Pathfinder VI as an alternative to the existing BlackJAX backend. Key changes include:

- Implement core Pathfinder components using PyTensor with batched operations
- Add inference_backend parameter to select between PyMC and BlackJAX implementations
- Enable jittering of initial points for Pathfinder
@aphc14
Copy link
Author

aphc14 commented Nov 4, 2024

Suppose the preferred approach is to stick with symbolic variables in PyTensor than the other non-symbolic approach in #386. In that case, I'd be happy to refactor the Multipath Pathfinder implementation in #386 to use symbolic variables and pytensor.function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant