Chain straight from dict / chain from numpyro with only certain fields #121
Sounds reasonable. For someone who's never used numpyro, do you have some examples of this, or of how you'd modify the `from_numpyro` function to include deterministic samples?
Straight from the code I was running, together with the workaround:

```python
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro as npy
import numpyro.distributions as dist
import pandas as pd
from chainconsumer import Chain, ChainConsumer, Truth


def sampling_fn(data_obj, model_class):
    # Define priors
    dra, ddec, log_flux = [
        npy.sample("Δ RA", dist.Uniform(100, 200)),
        npy.sample("Δ Dec", dist.Uniform(100, 200)),
        npy.sample("log Flux", dist.Uniform(-5, -3)),
    ]
    # TODO: how do we set appropriate priors in Cartesian coordinates?
    # TODO: it would be good for this to be defined by the model_class itself
    flux = 10**log_flux  # sample in log flux

    # Predict the data; npy.deterministic records it as the "pred" site
    pred_data = npy.deterministic("pred", data_obj.model(model_class(dra, ddec, flux)))
    data, errors = data_obj.flatten_data()

    # Gaussian likelihood of the observed data given the prediction
    model_sampler = dist.Normal(pred_data, errors)
    return npy.sample("Sampler", model_sampler, obs=data)


# Using the model above, we can now sample from the posterior distribution
# using the No U-Turn Sampler (NUTS).
sampler = npy.infer.MCMC(
    npy.infer.NUTS(sampling_fn),
    num_warmup=5000,
    num_samples=5000,
)
# oidata_sim and BinaryModelCartesian are defined elsewhere in my project
sampler.run(jr.PRNGKey(0), oidata_sim, BinaryModelCartesian)

# Workaround: split off the deterministic "pred" site, which has a different
# shape from the scalar parameters, before building the DataFrame
results = sampler.get_samples()
posterior_pred = results.pop("pred")
cc_df = pd.DataFrame({key: np.asarray(val) for key, val in results.items()})

c = ChainConsumer()
c.add_chain(Chain(
    samples=cc_df,
    name="MCMC Results",
    plot_point=True,
    plot_cloud=True,
    marker_style="*",
    marker_size=100,
))
# Key the true values by parameter name so they don't depend on dict ordering
truths = {"Δ RA": 150.0, "Δ Dec": 150.0, "log Flux": np.log10(5e-4)}
c.add_truth(Truth(location=truths))
c.plotter.plot()
plt.show()
```
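Rather than popping `pred` by name, the same workaround can be generalised: keep only the sites that hold a single scalar per draw, which is what a corner plot can actually show. The sketch below assumes `samples` is a flat dict like the one `MCMC.get_samples()` returns with `group_by_chain=False` (leading axis = draws); `scalar_sites` is a hypothetical helper, not part of ChainConsumer or numpyro, and the arrays are made-up stand-ins.

```python
import numpy as np
import pandas as pd


def scalar_sites(samples):
    """Keep only sites with exactly one scalar value per posterior draw."""
    return {
        name: np.asarray(value)
        for name, value in samples.items()
        if np.ndim(value) == 1  # shape (num_draws,) means scalar per draw
    }


# Stand-in for sampler.get_samples(): 4 draws of two scalar parameters
# plus a deterministic "pred" site holding a length-3 vector per draw.
samples = {
    "Δ RA": np.linspace(149.0, 151.0, 4),
    "log Flux": np.full(4, np.log10(5e-4)),
    "pred": np.zeros((4, 3)),
}
df = pd.DataFrame(scalar_sites(samples))
print(df.columns.tolist())  # ['Δ RA', 'log Flux']
```

If the samples are grouped by chain, scalar sites become 2-D, so the `ndim` test would need to change accordingly.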
This is a good idea; I need this kind of feature for my project too, so I'll implement it ASAP.
I have opened a PR where you can specify variables to keep or discard using a new parameter, `var_names`, for the `Chain.from_arviz` and `Chain.from_numpyro` constructors. I'll wait for Sam's review before merging it, but in the meantime you can try it using:

```shell
pip install ***@***.***/pull/122/merge
```

and give me your feedback! To use it, simply pass a list of strings:

```python
# keep only the mu and sigma parameters
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu", "sigma"])
# keep everything except the mu and sigma parameters
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu", "~sigma"])
```

I'll try to get this in ASAP if we get discharged tomorrow; otherwise it might be a few days. Thanks for doing the hard part, Simon!

> On Tue, 16 Apr 2024, 9:54 pm, Simon Dupourqué wrote:
> I have opened a PR where you can specify arguments to keep or discard …
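The keep/discard semantics of `var_names` described above (plain names are kept, `~name` is dropped, the same prefix convention ArviZ uses for its own `var_names`) can be sketched in a few lines. This is a hypothetical `select_var_names` helper written for illustration, not the code in PR 122; in particular, rejecting a mix of kept and dropped names is an assumption made here to keep the behaviour unambiguous.

```python
def select_var_names(available, var_names=None):
    """Hypothetical sketch: choose which variables to keep, by name.

    Plain names are kept; names prefixed with "~" are dropped.
    Mixing the two forms raises, to avoid ambiguous requests.
    """
    if var_names is None:
        return list(available)
    excluded = [v[1:] for v in var_names if v.startswith("~")]
    included = [v for v in var_names if not v.startswith("~")]
    if excluded and included:
        raise ValueError("use either names to keep or '~names' to drop, not both")
    unknown = set(included + excluded) - set(available)
    if unknown:
        raise KeyError(f"unknown variables: {sorted(unknown)}")
    if included:
        # Preserve the original site order rather than the request order
        return [v for v in available if v in included]
    return [v for v in available if v not in excluded]


names = ["mu", "sigma", "tau", "pred"]
print(select_var_names(names, ["mu", "sigma"]))    # ['mu', 'sigma']
print(select_var_names(names, ["~mu", "~sigma"]))  # ['tau', 'pred']
```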
Hopefully the
Hero, thanks mate!
Hi Sam,

Great that you have a `from_numpyro` method in `Chain`; it's very helpful. One thing that is common, though, is to have numpyro track some parameters as `Deterministic`: for example, saving model samples at the time they are calculated, or precalculating posterior predictive draws for plotting (examples). These will in general have a different shape from the other samples of scalar parameters. While I understand that it is straightforward to pass them into pandas and make a DataFrame, this isn't very elegant.

It would be convenient to be able to just pass a dict without `pd` wrapping, or to call `from_numpyro` with keys to use (or indeed keys to drop, with all the others used).

Cheers,
Ben
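An alternative to dropping the differently shaped `Deterministic` sites is to flatten them into one column per component, so a plain dict of samples can go straight into a DataFrame (and from there into `Chain(samples=...)`). A minimal sketch, assuming each value's leading axis indexes draws; `flatten_samples` and the `name[i]` column labels are hypothetical choices, not ChainConsumer's API.

```python
import numpy as np
import pandas as pd


def flatten_samples(samples):
    """Build a DataFrame with one column per scalar component of each site.

    A site with per-draw shape () keeps its name; a site with per-draw
    shape (2,) becomes columns "name[0]" and "name[1]", and so on.
    """
    columns = {}
    for name, value in samples.items():
        value = np.asarray(value)
        if value.ndim == 1:
            columns[name] = value
            continue
        # Iterate over every index of the per-draw shape (all axes but the first)
        for idx in np.ndindex(*value.shape[1:]):
            label = f"{name}[{','.join(map(str, idx))}]"
            columns[label] = value[(slice(None), *idx)]
    return pd.DataFrame(columns)


samples = {"mu": np.arange(4.0), "pred": np.arange(8.0).reshape(4, 2)}
df = flatten_samples(samples)
print(df.columns.tolist())  # ['mu', 'pred[0]', 'pred[1]']
```

For a site with many components (a full predicted data vector, say), this produces a lot of columns, so it is probably best combined with a name filter.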