Investigate alternative JAX-based samplers available since Bambi 0.15.0 #618

Open · AlexanderFengler opened this issue Dec 31, 2024 · 5 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

@AlexanderFengler (Collaborator)

Version 0.15.0 of Bambi adds support for many extra samplers via deeper integration with the bayeux package.

We need to properly test that we support them (essentially a robustification exercise on the behavior of our likelihoods).

To close this issue, try to replicate this notebook via HSSM.
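
For orientation, the notebook drives these samplers through Bambi's fit() wrapper, roughly as in the sketch below (toy data; the API names follow the Bambi alternative-backends docs and should be treated as assumptions to verify against the installed version):

# Sketch of the Bambi-side path the notebook exercises; HSSM needs the
# same path to work. Based on the Bambi alternative-backends docs.
import bambi as bmb
import numpy as np
import pandas as pd

# Toy data standing in for the notebook's dataset
rng = np.random.default_rng(0)
data = pd.DataFrame({"x": rng.normal(size=100)})
data["y"] = 0.5 * data["x"] + rng.normal(scale=0.3, size=100)

model = bmb.Model("y ~ x", data)
print(bmb.inference_methods.names)  # enumerates the bayeux-backed samplers
idata = model.fit(inference_method="blackjax_nuts")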

@cpaniaguam (Collaborator)

I am using the following script to reproduce the workflow from the notebook with an HSSM model. Below is the output using the blackjax_nuts, tfp_nuts, numpyro_nuts, flowmc_realnvp_hmc, and nutpie backends. While everything seems to run successfully, I am uncertain how to verify correctness and interpret the output. Could you provide any insights or takeaways? What more robust tests could be performed? If everything checks out, what would the next steps be? @AlexanderFengler @digicosmos86 @krishnbera

Script:
# test-hssm-inference.py

# bayeux-ml package needed
import pymc as pm
import hssm
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--inference-method", type=str, default="blackjax_nuts")

args = parser.parse_args()

# Load a package-supplied dataset
cav_data = hssm.load_data("cavanagh_theta")

# Define a basic hierarchical model with trial-level covariates
model = hssm.HSSM(
    model="ddm",
    data=cav_data,
    include=[
        {
            "name": "v",
            "prior": {
                "Intercept": {"name": "Normal", "mu": 0.0, "sigma": 0.1},
                "theta": {"name": "Normal", "mu": 0.0, "sigma": 0.1},
            },
            "formula": "v ~ theta + (1|participant_id)",
            "link": "identity",
        },
    ],
)

# these are used in the notebook --  not sure if they are relevant here
kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 250,
    "num_adapt_draws": 250,
}


with model.pymc_model:
    pm.sample(inference_method=args.inference_method, **kwargs)

blackjax_nuts

Output:
✗ python test-hssm-inference.py --inference-method blackjax_nuts
No common intercept. Bounds for parameter v is not applied due to a current limitation of Bambi. This will change in the future.
Model initialized successfully.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, z, t, v_Intercept, v_theta, v_1|participant_id_mu, v_1|participant_id_sigma, v_1|participant_id_offset]
Sampling 4 chains, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:08:45
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 526 seconds.

tfp_nuts

Output:
✗ python test-hssm-inference.py --inference-method tfp_nuts
No common intercept. Bounds for parameter v is not applied due to a current limitation of Bambi. This will change in the future.
Model initialized successfully.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, z, t, v_Intercept, v_theta, v_1|participant_id_mu, v_1|participant_id_sigma, v_1|participant_id_offset]
Sampling 4 chains, 0 divergences ━━━━ 100% 0:0… / 0:17…
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1058 seconds.

numpyro_nuts

Output:
✗ python test-hssm-inference.py --inference-method numpyro_nuts
No common intercept. Bounds for parameter v is not applied due to a current limitation of Bambi. This will change in the future.
Model initialized successfully.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [z, a, t, v_Intercept, v_theta, v_1|participant_id_mu, v_1|participant_id_sigma, v_1|participant_id_offset]
Sampling 4 chains, 1 divergences ━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:17:55
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1076 seconds.
There were 1 divergences after tuning. Increase `target_accept` or reparameterize.

flowmc_realnvp_hmc

Output:
✗ python test-hssm-inference.py --inference-method flowmc_realnvp_hmc
No common intercept. Bounds for parameter v is not applied due to a current limitation of Bambi. This will change in the future.
Model initialized successfully.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, z, t, v_Intercept, v_theta, v_1|participant_id_mu, v_1|participant_id_sigma, v_1|participant_id_offset]
Sampling 4 chains, 0 divergences ━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:17:23
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1044 seconds.

nutpie

Output:
✗ python test-hssm-inference.py --inference-method nutpie         
No common intercept. Bounds for parameter v is not applied due to a current limitation of Bambi. This will change in the future.
Model initialized successfully.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, t, z, v_Intercept, v_theta, v_1|participant_id_mu, v_1|participant_id_sigma, v_1|participant_id_offset]
Sampling 4 chains, 0 divergences ━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:17:15
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1035 seconds.
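
One generic sanity check for these runs (a sketch, not from the notebook) would be to compare ArviZ summaries and convergence diagnostics across backends; the posterior means should agree within Monte Carlo error. This assumes the pm.sample(...) call in the script is changed to idata = pm.sample(...):

# Sketch: backend-agnostic convergence checks with ArviZ.
import arviz as az

summary = az.summary(idata, var_names=["a", "z", "t", "v_Intercept", "v_theta"])
print(summary[["mean", "sd", "ess_bulk", "r_hat"]])

# Common heuristics: r_hat close to 1.00 and healthy bulk ESS per parameter.
assert (summary["r_hat"] < 1.01).all(), "chains may not have mixed"
assert (summary["ess_bulk"] > 400).all(), "low effective sample size"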

@cpaniaguam (Collaborator)

One thing I can distill is that blackjax_nuts is about 2x faster than the others. Is this related to #608, @digicosmos86?

@AlexanderFengler (Collaborator, Author)

@cpaniaguam I think what happened here is that the inference_method argument to pm.sample got ignored.

I will have to double-check, but the outputs look quite suspicious: it appears the basic PyMC NUTS sampler was applied throughout.

Note that in the notebook attached to this issue, what is called is model.fit(), the Bambi wrapper for sampling, which does a bunch of setup before dispatching to a sampler.

Our sample() method, hssm_model.sample(), actually calls .fit() internally.

If you do

with hssm_model.pymc_model:
    pm.sample()

that setup doesn't happen.

I will have to double-check the details, but I wanted to leave a reaction here.
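
If the goal is to engage Bambi's fit() machinery (and hence bayeux), a sketch of the intended path is below; the sampler keyword and backend name are assumptions based on HSSM's docs and may differ in the installed version:

# Sketch: route through HSSM's own wrapper so Bambi's fit() (and hence
# bayeux) is actually engaged. The `sampler` keyword and backend name
# are assumptions to verify against the HSSM docs.
idata = model.sample(
    sampler="nuts_numpyro",  # hypothetical backend name
    chains=4,
    draws=500,
    tune=500,
)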

@cpaniaguam (Collaborator)

I tried using pm.fit, but it doesn't seem to expose an option to use backends from bayeux.

Using fit with the method argument like this

with model.pymc_model:
    pm.fit(method=args.inference_method, **kwargs)

produces this error:

KeyError: "method should be one of {'asvgd', 'svgd', 'fullrank_advi', 'advi'} or Inference instance

If one passes a value for inference_method instead, one gets this error:

TypeError: ObjectiveFunction.step_function() got an unexpected keyword argument 'inference_method'
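
For context, pm.fit is PyMC's variational-inference entry point, so its method argument only accepts VI algorithms; the bayeux backend names are never valid there. A minimal sketch of what pm.fit does accept:

# Sketch: pm.fit runs variational inference (ADVI family) only; it cannot
# dispatch to bayeux sampler backends.
with model.pymc_model:
    approx = pm.fit(method="advi", n=10_000)  # mean-field ADVI
    idata = approx.sample(1_000)  # draws from the fitted approximation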

What do you think? @AlexanderFengler

@AlexanderFengler (Collaborator, Author)

Ok @cpaniaguam, let's talk about this during the next meeting. It should be something simple in principle; the real problems should come up only downstream.

I will look into it, but I remember that we had reached the point of routing through bayeux before; it just ended up not working because of shape issues.
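
For reference, a minimal sketch of routing the PyMC model through bayeux directly, following the bayeux-ml README (treat the exact method names and seeding as assumptions to verify against the installed version):

# Sketch: drive the HSSM-built PyMC model through bayeux, bypassing
# Bambi's fit() wrapper entirely.
import bayeux as bx
import jax

bx_model = bx.Model.from_pymc(model.pymc_model)
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))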
