Skip to content

Commit

Permalink
Small updates to the change of variable example
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 12, 2022
1 parent 5b6e262 commit becd2d2
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions examples/change_of_variable_hmc.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupytext:
format_version: 0.13
jupytext_version: 1.14.1
kernelspec:
display_name: Python 3.9.7 ('blackjax')
display_name: Python 3 (ipykernel)
language: python
name: python3
mystnb:
Expand All @@ -28,7 +28,6 @@ In particular we use following binomial hierarchical model where $y_{j}$ and $N_
\end{align}
```


```{code-cell} ipython3
:tags: [hide-cell]
Expand All @@ -37,7 +36,6 @@ import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
pd.set_option("display.max_rows", 80)
Expand Down Expand Up @@ -217,21 +215,29 @@ n_rat_tumors = len(group_size)
```{code-cell} ipython3
:tags: [hide-input]
plt.figure(figsize=(12, 3))
plt.bar(range(n_rat_tumors), n_of_positives)
fig = plt.figure(figsize=(12, 3))
ax = fig.add_subplot(111)
ax.bar(range(n_rat_tumors), n_of_positives)
ax.set_xlabel("tumor type", fontsize=12)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.title("No. of positives for each tumor type", fontsize=14)
plt.xlabel("tumor type", fontsize=12)
sns.despine()
```

```{code-cell} ipython3
:tags: [hide-input]
plt.figure(figsize=(14, 4))
plt.bar(range(n_rat_tumors), group_size)
fig = plt.figure(figsize=(14, 4))
ax = fig.add_subplot(111)
ax.bar(range(n_rat_tumors), group_size)
plt.title("Group size for each tumor type", fontsize=14)
plt.xlabel("tumor type", fontsize=12)
sns.despine()
ax.set_xlabel("tumor type", fontsize=12)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
```

## Posterior Sampling
Expand Down Expand Up @@ -298,7 +304,7 @@ def call_warmup(seed, param):
initial_states, _, tuned_params = warmup.run(seed, param, 1000)
return initial_states, tuned_params
initial_states, tuned_params = call_warmup(keys, init_params)
initial_states, tuned_params = jax.jit(call_warmup)(keys, init_params)
```

Now we write inference loop for multiple chains
Expand All @@ -312,7 +318,6 @@ def inference_loop_multiple_chains(
def kernel(key, state, **params):
return step_fn(key, state, log_prob_fn, **params)
@jax.jit
def one_step(states, rng_key):
keys = jax.random.split(rng_key, num_chains)
states, infos = jax.vmap(kernel)(keys, states, **tuned_params)
Expand Down Expand Up @@ -428,7 +433,7 @@ def joint_logprob_change_of_var(params):
return logprob_ab + logprob_thetas + logprob_y + log_det_jacob
```

except change of variable in `joint_logprob()` function, everthing will remain same
except for the change of variable in `joint_logprob()` function, everthing will remain same

```{code-cell} ipython3
rng_key = jax.random.PRNGKey(0)
Expand Down Expand Up @@ -555,6 +560,7 @@ init_key, warmup_key = jax.random.split(rng_key, 2)
init_params = bijectors.inverse(pinned.sample_unpinned(n_chains, seed=init_key))
keys = jax.random.split(warmup_key, n_chains)
@jax.vmap
def call_warmup(seed, param):
initial_states, _, tuned_params = warmup.run(seed, param, 1000)
Expand Down

0 comments on commit becd2d2

Please sign in to comment.