Skip to content

Commit

Permalink
Updating MR
Browse files Browse the repository at this point in the history
  • Loading branch information
Samreay committed Oct 15, 2023
1 parent 851f436 commit cae96f2
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 1,915 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-added-large-files
args: ["--maxkb=5000"]
Expand Down
21 changes: 14 additions & 7 deletions docs/examples/advanced_examples/plot_4_misc_chain_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
Rather than having one example for each option, let's condense things.
"""
# %%
# # Shade Gradient
# Shade Gradient
# --------------
#
# Pretty simple - it controls how much visual difference there is in your contours.
import numpy as np
Expand All @@ -22,7 +23,8 @@
fig = c.plotter.plot()

# %%
# # Shade Alpha
# Shade Alpha
# -----------
#
# Controls how opaque the contours are. Like everything else, you
# can specify this when making the chain, or apply a single override
Expand All @@ -31,7 +33,8 @@
fig = c.plotter.plot()

# %%
# # Contour Labels
# Contour Labels
# --------------
#
# Add labels to contours. I used to have this configurable to be either
# sigma levels or percentages, but there was confusion over the 1D vs 2D sigma levels,
Expand All @@ -41,7 +44,8 @@
fig = c.plotter.plot()

# %%
# # Linestyles and widths
# Linestyles and widths
# ---------------------
#
# Fairly simple to do. To show different ones, I'll remake the chains,
# rather than having a single override. Note you *could* try something
Expand All @@ -56,15 +60,17 @@
fig = c2.plotter.plot()

# %%
# # Marker styles and sizes
# Marker styles and sizes
# -----------------------
#
# Provided you have a posterior column, you can plot the maximum probability point.

c.set_override(ChainConfig(plot_point=True, marker_style="P", marker_size=100))
fig = c.plotter.plot()

# %%
# # Cloud and Sigma Levels
# Cloud and Sigma Levels
# ----------------------
#
# Choose custom sigma levels and display point cloud.
c.set_override(
Expand All @@ -78,7 +84,8 @@
fig = c.plotter.plot()

# %%
# # Smoothing (or not)
# Smoothing (or not)
# ------------------
#
# The histograms behind the scene in ChainConsumer are smoothed. But you can turn this off.
# The higher the smoothing vaule, the more subidivisions of your bins there will be.
Expand Down
6 changes: 4 additions & 2 deletions docs/examples/plot_0_contours.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
fig = c.plotter.plot()

# %% Third cell
# # Customising Chains
# Customising Chains
# ------------------
#
# There's a lot you can configure using chains, and to make it easy, Chains are defined as pydantic
# base models so you can easily see the default and values you can pass in. Don't worry, there will be
Expand All @@ -48,7 +49,8 @@


# %% Fourth cell
# # Weights and Posteriors
# Weights and Posteriors
# ----------------------
#
# If you provide the log posteriors in the chain, you can ask for the maximum probability point
# to be plotted as well. Similarly, if you have samples with non-uniform weights, you can
Expand Down
124 changes: 124 additions & 0 deletions docs/examples/plot_5_emcee_arviz_numpyro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
# Using external samples easily
`emcee`, `arviz`, and `numpyro` are all popular MCMC packages. ChainConsumer
provides class methods to turn results from these packages into chains efficiently.
If you want to request support for another type of chain, please open a
[discussion](https://github.com/Samreay/ChainConsumer/discussions) with a code
example, and we can add it in. The brave may even provide a PR!
"""

import arviz as az
import emcee
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
from scipy.stats import norm

from chainconsumer import Chain, ChainConsumer

# %%
# Emcee
# -----
#
# Let's make a dummy model here.


# Of course, your code is probably a bit more complex
def run_emcee_mcmc(n_steps, n_walkers):
rng = np.random.default_rng(42)
observed_data = rng.normal(loc=1, scale=1, size=100)

def log_likelihood(theta, data):
mu, log_sigma = theta
return np.sum(norm.logpdf(data, mu, np.exp(log_sigma)))

def log_prior(theta):
mu, log_sigma = theta
if -10 < mu < 10 and -10 < log_sigma < 10:
return 0.0
return -np.inf

def log_probability(theta, data):
lp = log_prior(theta)
if not np.isfinite(lp):
return -np.inf
return lp + log_likelihood(theta, data)

ndim = 2
p0 = rng.uniform(low=0, high=1, size=(n_walkers, ndim))
sampler = emcee.EnsembleSampler(n_walkers, ndim, log_probability, args=(observed_data,))
sampler.run_mcmc(p0, n_steps, progress=False)

return sampler


sampler = run_emcee_mcmc(8000, 16)
params = [r"$\mu$", r"$\log(\sigma)$"]
chain = Chain.from_emcee(sampler, params, "an emcee chain", discard=200, thin=2, color="indigo")
consumer = ChainConsumer().add_chain(chain)

# %%
# Let's plot the walks to make sure we've discard enough burn-in
fig = consumer.plotter.plot_walks()

# %%
# And then show the contours themselves
fig = consumer.plotter.plot()


# %%
# Numpyro
# -------
#
# Let's start with numpyro. Again, let's make a dummy model we can sample from.


def run_numpyro_mcmc(n_steps, n_chains):
rng = np.random.default_rng(42)
observed_data = rng.normal(loc=0, scale=1, size=100)

def model(data):
# Prior
mu = numpyro.sample("mu", dist.Normal(0, 10))
sigma = numpyro.sample("sigma", dist.HalfNormal(10))

# Likelihood
with numpyro.plate("data", size=len(data)):
numpyro.sample("obs", dist.Normal(mu, sigma), obs=data) # type: ignore

# Running MCMC
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=n_steps, num_chains=n_chains, progress_bar=False)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, data=observed_data)

return mcmc


mcmc = run_numpyro_mcmc(8000, 1)
chain = Chain.from_numpyro(mcmc, "numpyro chain", color="teal")
consumer = ChainConsumer().add_chain(chain)

# %%
# Let's plot the walks to make sure we've discard enough burn-in
fig = consumer.plotter.plot_walks()

# %%
# And then show the contours themselves
fig = consumer.plotter.plot()

# %%
# Arviz
# -----
#
# To simplify the process, we're going to make our arviz sample from
# the numpyro one.

# %%
arviz_id = az.from_numpyro(mcmc)
chain = Chain.from_arviz(arviz_id, "arviz chain", color="amber")
fig = ChainConsumer().add_chain(chain).plotter.plot()
Loading

0 comments on commit cae96f2

Please sign in to comment.