Investigate alternative jax based samplers available since bambi: 0.15.0 #618
Comments
I am using the following script to reproduce the workflow in the notebook with an HSSM model. Below the script are the outputs for each inference method.

```python
# test-hssm-inference.py
# bayeux-ml package needed
import pymc as pm
import hssm
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--inference-method", type=str, default="blackjax_nuts")
args = parser.parse_args()

# Load a package-supplied dataset
cav_data = hssm.load_data("cavanagh_theta")

# Define a basic hierarchical model with trial-level covariates
model = hssm.HSSM(
    model="ddm",
    data=cav_data,
    include=[
        {
            "name": "v",
            "prior": {
                "Intercept": {"name": "Normal", "mu": 0.0, "sigma": 0.1},
                "theta": {"name": "Normal", "mu": 0.0, "sigma": 0.1},
            },
            "formula": "v ~ theta + (1|participant_id)",
            "link": "identity",
        },
    ],
)

# these are used in the notebook -- not sure if they are relevant here
kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 250,
    "num_adapt_draws": 250,
}
with model.pymc_model:
    pm.sample(inference_method=args.inference_method, **kwargs)
```

Sampler outputs (collapsed in the original issue) were attached for `blackjax_nuts`, `tfp_nuts`, `numpyro_nuts`, and `flowmc_realnvp_hmc`.
One thing I can distill is that `blackjax_nuts` is about 2x faster than the others. Is this related to #608, @digicosmos86?
@cpaniaguam I think what happened here is that the requested inference method was not actually picked up. I will have to double check this, but from the outputs it looks quite suspicious, in the sense that it looks like the basic PyMC NUTS sampler was applied throughout. If you instead sample through HSSM's own sampling interface rather than calling `pm.sample()` directly, this doesn't happen. I will have to double check the details, but just wanted to leave a reaction here.
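For concreteness, here is a minimal sketch of the sampling path alluded to above, i.e. going through HSSM's own sampling interface rather than `pm.sample()`. It assumes HSSM exposes a `sample()` method with a `sampler` argument accepting names such as `"nuts_numpyro"` or `"nuts_blackjax"`; the exact method and argument names should be checked against the HSSM version in use.

```python
# Hypothetical sketch: sample through the HSSM model object itself instead of
# calling pm.sample() on the underlying PyMC model. The `sampler` argument and
# its accepted values are assumptions based on HSSM's documented interface and
# may differ across versions.
idata = model.sample(
    sampler="nuts_numpyro",  # a JAX-based NUTS backend instead of the PyMC default
    chains=4,
    draws=250,
    tune=250,
)
```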
I tried `pm.fit` as well. Using `pm.fit(method=args.inference_method, **kwargs)` inside a `with model.pymc_model:` block produces an error. If one passes a value for `method` …

What do you think? @AlexanderFengler
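As a side note, one way to reach the extra samplers without going through `pm.sample()` or `pm.fit()` at all is to hand the PyMC model that HSSM builds directly to bayeux. The sketch below follows the bayeux documentation (`bx.Model.from_pymc` plus the `mcmc.blackjax_nuts` method); the keyword arguments mirror the kwargs used in the script above and are illustrative rather than definitive.

```python
import bayeux as bx
import jax

# Hypothetical sketch: wrap the PyMC model that HSSM builds and invoke a
# bayeux sampler directly. bx.Model.from_pymc and mcmc.blackjax_nuts follow
# the bayeux docs; keyword names may need adjusting for a given release.
bx_model = bx.Model.from_pymc(model.pymc_model)
idata = bx_model.mcmc.blackjax_nuts(
    seed=jax.random.PRNGKey(0),
    num_chains=4,
    num_draws=250,
    num_adapt_draws=250,
)
```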
Ok @cpaniaguam, let's talk about this during the next meeting. I will look at it, but I remember that we had reached the point of routing through …
Original issue description
Version 0.15.0 of Bambi adds support for a lot of extra samplers via deeper integration with the bayeux package. We need to properly test that we support them (this will essentially be a robustification exercise on the behavior of our likelihoods).
To close this issue, essentially try to replicate this notebook via HSSM.
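For context, the Bambi-side workflow that the linked notebook demonstrates looks roughly like the sketch below. The `inference_method` names and the `bmb.inference_methods.names` helper are taken from the Bambi/bayeux documentation and are assumptions here; they may change between versions.

```python
import bambi as bmb

# Sketch of the Bambi 0.15-style workflow this issue asks us to mirror in HSSM.
# Dataset, formula, and method names are illustrative.
print(bmb.inference_methods.names)  # samplers that bayeux makes available

data = bmb.load_data("sleepstudy")
bmb_model = bmb.Model("Reaction ~ 1 + Days + (Days | Subject)", data)
idata = bmb_model.fit(inference_method="blackjax_nuts")
```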