Chain straight from dict / chain from numpyro with only certain fields #121

Closed
benjaminpope opened this issue Apr 13, 2024 · 7 comments · Fixed by #122

Comments

@benjaminpope

Hi Sam,

Great that you have a from_numpyro method in Chain - very helpful. One thing that is common, though, is to have numpyro track some quantities as Deterministic sites: for example, saving model samples at the time they are calculated, or precalculating posterior predictive draws for plotting (egs). These will in general have a different shape from the samples of the scalar parameters.

While I understand that it is straightforward to pass it into pandas and make a DataFrame, this isn't very elegant.

It would be convenient to be able to just pass a dict without wrapping it in pandas, or to call from_numpyro with a list of keys to keep (or, conversely, keys to drop while keeping the rest).
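
The manual workaround behind this request (separating array-valued sites from scalar parameters before building a DataFrame) can be sketched roughly as follows. This is only an illustration: `split_scalar_sites` is a hypothetical helper, not part of ChainConsumer or numpyro, and it assumes the samples arrive as a dict of arrays whose first axis indexes the draws.

```python
import numpy as np


def split_scalar_sites(samples):
    """Split a dict of posterior samples into scalar and array-valued sites.

    Hypothetical helper: a site counts as "scalar" when it holds one value
    per draw, i.e. its array is one-dimensional with shape (num_draws,).
    """
    scalars, others = {}, {}
    for name, values in samples.items():
        arr = np.asarray(values)
        if arr.ndim == 1:
            scalars[name] = arr
        else:
            others[name] = arr
    return scalars, others


# Toy stand-in for a samples dict: 4 draws, one 3-point prediction per draw.
samples = {
    "log Flux": np.linspace(-4.0, -3.0, 4),
    "pred": np.zeros((4, 3)),
}
scalars, others = split_scalar_sites(samples)
```

Here `scalars` would be safe to hand to a DataFrame, while `others` keeps the per-draw predictions for plotting.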

Cheers,

Ben

@Samreay
Owner

Samreay commented Apr 13, 2024

Sounds reasonable. For someone who's never used numpyro, do you have some examples of this, or of how you'd modify the from_numpyro function to include deterministic samples?

@benjaminpope
Author

Straight from the code I was running, together with workaround:

# Imports this snippet relies on (aliases inferred from how the names are used below)
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro as npy
import numpyro.distributions as dist
import pandas as pd
from chainconsumer import Chain, ChainConsumer, Truth

def sampling_fn(data_obj, model_class):

    # Define priors
    dra = npy.sample("Δ RA", dist.Uniform(100, 200))
    ddec = npy.sample("Δ Dec", dist.Uniform(100, 200))
    log_flux = npy.sample("log Flux", dist.Uniform(-5, -3))
    # TODO: how do we set appropriate priors in Cartesian coordinates? 
    # TODO: it would be good for this to be defined by the model_class itself

    flux = 10**log_flux # sample in log flux 

    # predict the data 
    pred_data = npy.deterministic("pred", data_obj.model(model_class(dra,ddec,flux)))

    data, errors = data_obj.flatten_data()
    # Likelihood: Gaussian noise around the predicted data, conditioned on the observations
    model_sampler = dist.Normal(pred_data, errors)
    return npy.sample("Sampler", model_sampler, obs=data)

# Using the model above, we can now sample from the posterior distribution
# using the No U-Turn Sampler (NUTS).
sampler = npy.infer.MCMC(
    npy.infer.NUTS(sampling_fn),
    num_warmup=5000,
    num_samples=5000,
)
sampler.run(jr.PRNGKey(0), oidata_sim, BinaryModelCartesian)

results = sampler.get_samples()
posterior_pred = results.pop('pred')  # keep the predictions aside and remove them from the results
cc_df = pd.DataFrame.from_dict(results)

c = ChainConsumer()
c.add_chain(
    Chain(
        samples=cc_df,
        name="MCMC Results",
        plot_point=True,
        plot_cloud=True,
        marker_style="*",
        marker_size=100,
    )
)
# Map each truth value to its parameter by name rather than relying on key order
truths = {"log Flux": np.log10(5e-4), "Δ RA": 150.0, "Δ Dec": 150.0}
c.add_truth(Truth(location=truths))
c.plotter.plot()

plt.show()

@renecotyfanboy
Collaborator

This is a good idea. I need this kind of feature for my project too, so I'll implement it ASAP.

@renecotyfanboy
Collaborator

I have opened a PR in which you can specify the parameters to keep or discard using a new argument (for the Chain.from_arviz and Chain.from_numpyro constructors). I'll wait for Sam's review before merging it, but in the meantime you can try it using:

pip install git+https://github.com/Samreay/ChainConsumer.git@refs/pull/122/merge

and give me your feedback! To use it, simply pass a list of strings:

# keep only the mu and sigma parameters
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu", "sigma"])
# keep everything but the mu and sigma parameters
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu", "~sigma"])
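
To make the keep / discard convention concrete, here is a rough pure-Python sketch of how such a selection could behave. It is only an illustration of the semantics described in this comment, not the code in the PR; `select_vars` is a made-up name, and how real implementations treat a mix of plain and "~" names may differ.

```python
def select_vars(all_names, var_names=None):
    """Hypothetical illustration of keep / discard selection with a "~" prefix."""
    if var_names is None:
        # No selection requested: keep every parameter.
        return list(all_names)
    excluded = {name[1:] for name in var_names if name.startswith("~")}
    included = [name for name in var_names if not name.startswith("~")]
    if included:
        # Plain names were given: keep only those, in the original order.
        return [n for n in all_names if n in included and n not in excluded]
    # Only "~" names were given: keep everything except them.
    return [n for n in all_names if n not in excluded]


names = ["mu", "sigma", "pred"]
kept = select_vars(names, ["mu", "sigma"])
dropped = select_vars(names, ["~mu", "~sigma"])
```

With the toy names above, `kept` holds only mu and sigma, while `dropped` holds whatever remains once those two are excluded.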

@Samreay
Owner

Samreay commented Apr 16, 2024 via email

@Samreay
Owner

Samreay commented Apr 19, 2024

Hopefully the var_names kwarg in 1.1.0 does what you'd like. Thanks @renecotyfanboy and @benjaminpope :)
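
Applied to the model posted earlier in this thread, that would presumably look something like the sketch below. It assumes ChainConsumer 1.1.0 or later is installed, that the numpyro MCMC object has already been run, and that the `Chain.from_numpyro(mcmc, name, var_names=...)` call form quoted above is the one that shipped; `chain_without_pred` is a made-up wrapper name used only for illustration.

```python
def chain_without_pred(sampler, name="MCMC Results"):
    """Build a Chain from a finished numpyro MCMC run, skipping the "pred" site.

    Assumes chainconsumer >= 1.1.0; the import is done lazily so that
    defining this helper does not itself require the package.
    """
    from chainconsumer import Chain

    # "~pred" asks for every sampled site except the deterministic "pred" array.
    return Chain.from_numpyro(sampler, name, var_names=["~pred"])
```

This would replace the manual `pop('pred')` and DataFrame steps from the workaround, leaving the predictions available through `sampler.get_samples()['pred']` for plotting.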

@benjaminpope
Author

Hero - thanks mate!

3 participants