Bugfixes and Improvements for MAP and Laplace #422

Merged · merged 4 commits on Feb 7, 2025
4 changes: 2 additions & 2 deletions .github/workflows/pypi.yml
@@ -41,7 +41,7 @@ jobs:
echo "Checking import and version number (on release)"
venv-bdist/bin/python -c "import pymc_extras as pmx; assert pmx.__version__ == '${{ github.ref_name }}'[1:] if '${{ github.ref_type }}' == 'tag' else pmx.__version__; print(pmx.__version__)"
cd ..
- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4
with:
name: artifact
path: dist/*
@@ -58,7 +58,7 @@ jobs:
# write id-token is necessary for trusted publishing (OIDC)
id-token: write
steps:
- uses: actions/download-artifact@v3
- uses: actions/download-artifact@v4
with:
name: artifact
path: dist
5 changes: 3 additions & 2 deletions conda-envs/environment-test.yml
@@ -3,14 +3,15 @@ channels:
- conda-forge
- nodefaults
dependencies:
- pymc>=5.19.1
- pymc>=5.20
- pytest-cov>=2.5
- pytest>=3.0
- dask
- xhistogram
- statsmodels
- numba<=0.60.0
- pip
- pip:
- blackjax
- scikit-learn
- better_optimize>=0.0.10
- better_optimize
5 changes: 3 additions & 2 deletions conda-envs/windows-environment-test.yml
@@ -9,8 +9,9 @@ dependencies:
- dask
- xhistogram
- statsmodels
- numba<=0.60.0
- pymc>=5.20
- pip:
- pymc>=5.19.1 # CI was failing to resolve
- blackjax
- scikit-learn
- better_optimize>=0.0.10
- better_optimize
104 changes: 52 additions & 52 deletions notebooks/Exponential Trend Smoothing.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pymc_extras/__init__.py
@@ -15,7 +15,9 @@

from pymc_extras import gp, statespace, utils
from pymc_extras.distributions import *
from pymc_extras.inference.find_map import find_MAP
from pymc_extras.inference.fit import fit
from pymc_extras.inference.laplace import fit_laplace
from pymc_extras.model.marginal.marginal_model import (
MarginalModel,
marginalize,
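With the two imports added above, `find_MAP` and `fit_laplace` become part of the package's public namespace. A minimal sketch of what that enables (assuming the package is installed as `pymc_extras`):

```python
# Both entry points are now importable directly from the package root.
from pymc_extras import find_MAP, fit_laplace
```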
52 changes: 36 additions & 16 deletions pymc_extras/inference/find_map.py
@@ -1,9 +1,9 @@
import logging

from collections.abc import Callable
from importlib.util import find_spec
from typing import Literal, cast, get_args

import jax
import numpy as np
import pymc as pm
import pytensor
@@ -30,13 +30,29 @@
def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
method_info = MINIMIZE_MODE_KWARGS[method].copy()

use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

if use_hess and use_hessp:
_log.warning(
'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
'Setting "use_hess" to False.'
)
use_hess = False

use_grad = use_grad if use_grad is not None else method_info["uses_grad"]

if use_hessp is not None and use_hess is None:
use_hess = not use_hessp

elif use_hess is not None and use_hessp is None:
use_hessp = not use_hess

elif use_hessp is None and use_hess is None:
use_hessp = method_info["uses_hessp"]
use_hess = method_info["uses_hess"]
if use_hessp and use_hess:
# If a method could use either hess or hessp, we default to using hessp
use_hess = False

return use_grad, use_hess, use_hessp
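The rewritten defaults change how `use_grad`, `use_hess`, and `use_hessp` are resolved: supplying only one of `use_hess`/`use_hessp` now implies the complement for the other, and when neither is supplied the optimizer's capabilities decide, preferring the Hessian-vector product. A minimal standalone sketch of that rule (the `METHOD_INFO` dict is an illustrative stand-in for `MINIMIZE_MODE_KWARGS`, not the library's table):

```python
# Illustrative re-implementation of the default-resolution rule, not library code.
METHOD_INFO = {"trust-ncg": {"uses_grad": True, "uses_hess": True, "uses_hessp": True}}

def resolve_defaults(method, use_grad=None, use_hess=None, use_hessp=None):
    info = METHOD_INFO[method]
    use_grad = use_grad if use_grad is not None else info["uses_grad"]
    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hess is None and use_hessp is None:
        use_hessp, use_hess = info["uses_hessp"], info["uses_hess"]
        if use_hessp and use_hess:
            use_hess = False  # prefer the cheaper Hessian-vector product
    return use_grad, use_hess, use_hessp

print(resolve_defaults("trust-ncg"))                 # (True, False, True)
print(resolve_defaults("trust-ncg", use_hess=True))  # (True, True, False)
```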


@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
The nearest positive semi-definite matrix to the input matrix.
"""
C = (A + A.T) / 2
eigval, eigvec = np.linalg.eig(C)
eigval, eigvec = np.linalg.eigh(C)
eigval[eigval < 0] = 0

return eigvec @ np.diag(eigval) @ eigvec.T
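The switch from `np.linalg.eig` to `np.linalg.eigh` is sound because `C = (A + A.T) / 2` is symmetric by construction: `eigh` guarantees a real spectrum and orthonormal eigenvectors, while `eig` can return complex values with spurious imaginary parts. A small self-contained check of the projection, written independently of the library code:

```python
# Standalone sketch of the eigh-based PSD projection; names are local to this example.
import numpy as np

def nearest_psd(A):
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # symmetric input: real spectrum guaranteed
    eigval[eigval < 0] = 0              # clip negative eigenvalues
    return eigvec @ np.diag(eigval) @ eigvec.T

A = np.array([[1.0, 2.0], [0.0, -3.0]])         # neither symmetric nor PSD
B = nearest_psd(A)
assert np.all(np.linalg.eigvalsh(B) >= -1e-12)  # projection is PSD up to float error
```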
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
return f_untransform(posterior_draws)


def _compile_jax_gradients(
def _compile_grad_and_hess_to_jax(
f_loss: Function, use_hess: bool, use_hessp: bool
) -> tuple[Callable | None, Callable | None]:
"""
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
f_hessp: Callable | None
The compiled hessian-vector product function, or None if use_hessp is False.
"""
import jax

f_hess = None
f_hessp = None

@@ -152,7 +170,7 @@ def f_hess_jax(x):
return f_loss_and_grad, f_hess, f_hessp


def _compile_functions(
def _compile_functions_for_scipy_optimize(
loss: TensorVariable,
inputs: list[TensorVariable],
compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
compute_hessp: bool
Whether to compile a function that computes the Hessian-vector product of the loss function.
compile_kwargs: dict, optional
Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
Additional keyword arguments to pass to the ``pm.compile`` function.

Returns
-------
@@ -193,19 +211,19 @@ def _compile_functions(
if compute_grad:
grads = pytensor.gradient.grad(loss, inputs)
grad = pt.concatenate([grad.ravel() for grad in grads])
f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
else:
f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
f_loss = pm.compile(inputs, loss, **compile_kwargs)
return [f_loss]

if compute_hess:
hess = pytensor.gradient.jacobian(grad, inputs)[0]
f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
f_hess = pm.compile(inputs, hess, **compile_kwargs)

if compute_hessp:
p = pt.tensor("p", shape=inputs[0].type.shape)
hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

return [f_loss_and_grad, f_hess, f_hessp]
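For context, the three compiled callables are shaped for `scipy.optimize.minimize`: the first returns `(loss, grad)` together and is therefore passed with `jac=True`, while the Hessian or Hessian-vector product goes to `hess`/`hessp`. A hedged sketch of that wiring, with a toy quadratic standing in for the compiled PyTensor functions:

```python
# Sketch only: a quadratic objective stands in for the compiled PyTensor functions.
import numpy as np
from scipy.optimize import minimize

def f_loss_and_grad(x):
    return 0.5 * x @ x, x  # value and gradient returned together

def f_hessp(x, p):
    return p               # Hessian of 0.5 * ||x||^2 is the identity

res = minimize(
    f_loss_and_grad,
    x0=np.ones(3),
    jac=True,        # the callable returns (value, gradient)
    hessp=f_hessp,
    method="trust-ncg",
)
assert np.allclose(res.x, 0.0, atol=1e-6)
```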

@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
gradient_backend: str, default "pytensor"
Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
compile_kwargs:
Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
Additional keyword arguments to pass to the ``pm.compile`` function.

Returns
-------
@@ -265,6 +283,8 @@
)

use_jax_gradients = (gradient_backend == "jax") and use_grad
if use_jax_gradients and not find_spec("jax"):
raise ImportError("JAX must be installed to use JAX gradients")

mode = compile_kwargs.get("mode", None)
if mode is None and use_jax_gradients:
@@ -285,7 +305,7 @@
compute_hess = use_hess and not use_jax_gradients
compute_hessp = use_hessp and not use_jax_gradients

funcs = _compile_functions(
funcs = _compile_functions_for_scipy_optimize(
loss=loss,
inputs=[flat_input],
compute_grad=compute_grad,
@@ -301,7 +321,7 @@

if use_jax_gradients:
# f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

return f_loss, f_hess, f_hessp
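The body of `_compile_grad_and_hess_to_jax` is collapsed in this view, so the following is only a generic sketch of the technique the name suggests: given a scalar loss in JAX, a Hessian-vector product can be formed by pushing a tangent through the gradient with `jax.jvp`, without materializing the full Hessian.

```python
# Generic JAX Hessian-vector-product pattern; not the PR's actual implementation.
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(jnp.sin(x) ** 2)

grad_fn = jax.grad(loss)

def hessp(x, p):
    # The jvp of the gradient equals Hessian @ p
    return jax.jvp(grad_fn, (x,), (p,))[1]

x = jnp.arange(3.0)
print(hessp(x, jnp.ones(3)))
```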

27 changes: 17 additions & 10 deletions pymc_extras/inference/laplace.py
@@ -16,6 +16,7 @@
import logging

from functools import reduce
from importlib.util import find_spec
from itertools import product
from typing import Literal

@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
return idata


def fit_mvn_to_MAP(
def fit_mvn_at_MAP(
optimized_point: dict[str, np.ndarray],
model: pm.Model | None = None,
on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
inverse_hessian: np.ndarray
The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
"""
if gradient_backend == "jax" and not find_spec("jax"):
raise ImportError("JAX must be installed to use JAX gradients")

model = pm.modelcontext(model)
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(

Parameters
----------
mu
H_inv
mu: RaveledVars
The MAP estimate of the model parameters.
H_inv: np.ndarray
The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
model : Model
A PyMC model
chains : int
@@ -384,9 +390,7 @@
constrained_rvs, replace={unconstrained_vector: batched_values}
)

f_constrain = pm.compile_pymc(
inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
)
f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
posterior_draws = f_constrain(posterior_draws)

else:
@@ -472,15 +476,17 @@ def fit_laplace(
and 1).

.. warning::
This argumnet should be considered highly experimental. It has not been verified if this method produces
This argument should be considered highly experimental. It has not been verified if this method produces
valid draws from the posterior. **Use at your own risk**.

gradient_backend: str, default "pytensor"
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
chains: int, default: 2
The number of sampling chains running in parallel.
The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
compatible with the ArviZ library.
draws: int, default: 500
The number of samples to draw from the approximated posterior.
The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
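To illustrate the clarified `chains`/`draws` semantics: the Laplace fit draws `chains * draws` independent samples from a single Gaussian approximation and only reshapes them so the result plays well with ArviZ. A hedged usage sketch (model and keyword values are illustrative, not taken from this PR):

```python
# Usage sketch: `chains` reshapes output for ArviZ; it does not run parallel samplers.
import pymc as pm
import pymc_extras as pmx

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, 1, observed=[0.1, -0.3, 0.2])

    idata = pmx.fit_laplace(chains=2, draws=500)

# 2 * 500 = 1000 total draws from the Laplace approximation
print(idata.posterior["mu"].shape)  # expected: (2, 500)
```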
@@ -547,11 +553,12 @@ def fit_laplace(
**optimizer_kwargs,
)

mu, H_inv = fit_mvn_to_MAP(
mu, H_inv = fit_mvn_at_MAP(
optimized_point=optimized_point,
model=model,
on_bad_cov=on_bad_cov,
transform_samples=fit_in_unconstrained_space,
gradient_backend=gradient_backend,
zero_tol=zero_tol,
diag_jitter=diag_jitter,
compile_kwargs=compile_kwargs,
3 changes: 2 additions & 1 deletion pymc_extras/model/marginal/marginal_model.py
@@ -19,7 +19,8 @@
model_free_rv,
model_from_fgraph,
)
from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
from pymc.pytensorf import compile as compile_pymc
from pymc.util import RandomState, _get_seeds_per_chain
from pytensor import In, Out
from pytensor.compile import SharedVariable
2 changes: 1 addition & 1 deletion pymc_extras/statespace/core/compile.py
@@ -30,7 +30,7 @@ def compile_statespace(

inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))

_f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
_f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)

def f(*, draws=1, **params):
if isinstance(steps, pt.Variable):
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,2 +1,3 @@
pymc>=5.19.1
pymc>=5.20
scikit-learn
better-optimize
33 changes: 19 additions & 14 deletions tests/test_find_map.py
@@ -54,24 +54,28 @@ def compute_z(x):


@pytest.mark.parametrize(
"method, use_grad, use_hess",
"method, use_grad, use_hess, use_hessp",
[
("nelder-mead", False, False),
("powell", False, False),
("CG", True, False),
("BFGS", True, False),
("L-BFGS-B", True, False),
("TNC", True, False),
("SLSQP", True, False),
("dogleg", True, True),
("trust-ncg", True, True),
("trust-exact", True, True),
("trust-krylov", True, True),
("trust-constr", True, True),
("nelder-mead", False, False, False),
("powell", False, False, False),
("CG", True, False, False),
("BFGS", True, False, False),
("L-BFGS-B", True, False, False),
("TNC", True, False, False),
("SLSQP", True, False, False),
("dogleg", True, True, False),
("Newton-CG", True, True, False),
("Newton-CG", True, False, True),
("trust-ncg", True, True, False),
("trust-ncg", True, False, True),
("trust-exact", True, True, False),
("trust-krylov", True, True, False),
("trust-krylov", True, False, True),
("trust-constr", True, True, False),
],
)
@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
extra_kwargs = {}
if method == "dogleg":
# HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
**extra_kwargs,
use_grad=use_grad,
use_hess=use_hess,
use_hessp=use_hessp,
progressbar=False,
gradient_backend=gradient_backend,
compile_kwargs={"mode": "JAX"},