Skip to content

Commit

Permalink
Modify linear regression test to sample in the transformed (real line…
Browse files Browse the repository at this point in the history
…) space (#388)
  • Loading branch information
albcab authored and rlouf committed Oct 27, 2022
1 parent a766c22 commit 284adf0
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 49 deletions.
19 changes: 10 additions & 9 deletions examples/Introduction.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,17 @@ import blackjax

## The Problem

We'll generate observations from a normal distribution of known `loc` and `scale` to see if we can recover the parameters in sampling. Let's take a decent-size dataset with 1,000 points:
We'll generate observations from a normal distribution of known `loc` and `scale` to see if we can recover the parameters in sampling. **MCMC algorithms usually assume samples are being drawn from an unconstrained Euclidean space.** Hence why we'll log transform the scale parameter, so that sampling is done on the real line. Samples can be transformed back to their original space in post-processing. Let's take a decent-size dataset with 1,000 points:

```{code-cell} ipython3
loc, scale = 10, 20
observed = np.random.normal(loc, scale, size=1_000)
```

```{code-cell} ipython3
def logprob_fn(loc, scale, observed=observed):
def logprob_fn(loc, log_scale, observed=observed):
"""Univariate Normal"""
scale = jnp.exp(log_scale)
logpdf = stats.norm.logpdf(observed, loc, scale)
return jnp.sum(logpdf)
Expand All @@ -51,7 +52,7 @@ logprob = lambda x: logprob_fn(**x)
### Sampler Parameters

```{code-cell} ipython3
inv_mass_matrix = np.array([0.5, 0.5])
inv_mass_matrix = np.array([0.5, 0.01])
num_integration_steps = 60
step_size = 1e-3
Expand All @@ -63,7 +64,7 @@ hmc = blackjax.hmc(logprob, step_size, inv_mass_matrix, num_integration_steps)
The initial state of the HMC algorithm requires not only an initial position, but also the potential energy and gradient of the potential energy at this position. BlackJAX provides a `new_state` function to initialize the state from an initial position.

```{code-cell} ipython3
initial_position = {"loc": 1.0, "scale": 2.0}
initial_position = {"loc": 1.0, "log_scale": 1.0}
initial_state = hmc.init(initial_position)
initial_state
```
Expand Down Expand Up @@ -100,7 +101,7 @@ rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 10_000)
loc_samples = states.position["loc"].block_until_ready()
scale_samples = states.position["scale"]
scale_samples = jnp.exp(states.position["log_scale"])
```

```{code-cell} ipython3
Expand All @@ -121,14 +122,14 @@ ax1.set_ylabel("scale")
NUTS is a *dynamic* algorithm: the number of integration steps is determined at runtime. We still need to specify a step size and a mass matrix:

```{code-cell} ipython3
inv_mass_matrix = np.array([0.5, 0.5])
inv_mass_matrix = np.array([0.5, 0.01])
step_size = 1e-3
nuts = blackjax.nuts(logprob, step_size, inv_mass_matrix)
```

```{code-cell} ipython3
initial_position = {"loc": 1.0, "scale": 2.0}
initial_position = {"loc": 1.0, "log_scale": 1.0}
initial_state = nuts.init(initial_position)
initial_state
```
Expand All @@ -139,7 +140,7 @@ rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, nuts.step, initial_state, 4_000)
loc_samples = states.position["loc"].block_until_ready()
scale_samples = states.position["scale"]
scale_samples = jnp.exp(states.position["log_scale"])
```

```{code-cell} ipython3
Expand Down Expand Up @@ -176,7 +177,7 @@ We can use the obtained parameters to define a new kernel. Note that we do not h
states = inference_loop(rng_key, kernel, state, 1_000)
loc_samples = states.position["loc"].block_until_ready()
scale_samples = states.position["scale"]
scale_samples = jnp.exp(states.position["log_scale"])
```

```{code-cell} ipython3
Expand Down
3 changes: 1 addition & 2 deletions examples/RegimeSwitchingModel.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ dist = RegimeSwitchHMM(T, y)
```

```{code-cell} ipython3
batch_fn = jax.vmap
[n_chain, n_warm, n_iter] = [128, 5000, 200]
ksam, kinit = jrnd.split(jrnd.PRNGKey(0), 2)
dist.initialize_model(kinit, n_chain)
Expand All @@ -181,7 +180,7 @@ dist.initialize_model(kinit, n_chain)
print("Running MEADS...")
tic1 = pd.Timestamp.now()
k_warm, k_sample = jrnd.split(ksam)
warmup = blackjax.meads(dist.logprob_fn, n_chain, batch_fn=batch_fn)
warmup = blackjax.meads(dist.logprob_fn, n_chain)
init_state, kernel, _ = warmup.run(k_warm, dist.init_params, n_warm)
def one_chain(k_sam, init_state):
Expand Down
3 changes: 1 addition & 2 deletions examples/SparseLogisticRegression.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,13 @@ N_OBS, N_REG = X.shape
N_PARAM = N_REG * 2 + 1
dist = HorseshoeLogisticReg(X, y)
batch_fn = jax.vmap
[n_chain, n_warm, n_iter] = [128, 20000, 10000]
ksam, kinit = jrnd.split(jrnd.PRNGKey(0), 2)
dist.initialize_model(kinit, n_chain)
tic1 = pd.Timestamp.now()
k_warm, k_sample = jrnd.split(ksam)
warmup = blackjax.meads(dist.logprob_fn, n_chain, batch_fn=batch_fn)
warmup = blackjax.meads(dist.logprob_fn, n_chain)
adaptation_results = warmup.run(k_warm, dist.init_params, n_warm)
init_state = adaptation_results.state
kernel = adaptation_results.kernel
Expand Down
7 changes: 4 additions & 3 deletions examples/howto_sample_multiple_chains.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ loc, scale = 10, 20
observed = np.random.normal(loc, scale, size=1_000)
def logprob_fn(loc, scale, observed=observed):
def logprob_fn(loc, log_scale, observed=observed):
"""Univariate Normal"""
scale = jnp.exp(log_scale)
logpdf = stats.norm.logpdf(observed, loc, scale)
return jnp.sum(logpdf)
Expand Down Expand Up @@ -83,7 +84,7 @@ To make our demonstration more dramatic we will used a NUTS sampler with poorly
import blackjax
inv_mass_matrix = np.array([0.5, 0.5])
inv_mass_matrix = np.array([0.5, 0.01])
step_size = 1e-3
nuts = blackjax.nuts(logprob, step_size, inv_mass_matrix)
Expand Down Expand Up @@ -125,7 +126,7 @@ def inference_loop_multiple_chains(
We now prepare the initial states using `jax.vmap` again, to vectorize the `init` function:

```{code-cell} ipython3
initial_positions = {"loc": np.ones(num_chains), "scale": 2.0 * np.ones(num_chains)}
initial_positions = {"loc": np.ones(num_chains), "log_scale": np.ones(num_chains)}
initial_states = jax.vmap(nuts.init, in_axes=(0))(initial_positions)
```

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"fastprogress>=0.2.0",
"jax>=0.3.13",
"jaxlib>=0.3.10",
"jaxopt>=0.4.2",
"jaxopt>=0.5.5",
],
long_description_content_type="text/markdown",
keywords="probabilistic machine learning bayesian statistics sampling algorithms",
Expand Down
7 changes: 4 additions & 3 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
import blackjax


def regression_logprob(scale, coefs, preds, x):
def regression_logprob(log_scale, coefs, preds, x):
"""Linear regression"""
scale_prior = stats.expon.logpdf(scale, 1, 1)
scale = jnp.exp(log_scale)
scale_prior = stats.expon.logpdf(scale, 0, 1) + log_scale
coefs_prior = stats.norm.logpdf(coefs, 0, 5)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
Expand Down Expand Up @@ -53,7 +54,7 @@ def run_regression(algorithm, **parameters):
is_mass_matrix_diagonal=False,
**parameters,
)
state, kernel, _ = warmup.run(warmup_key, {"scale": 1.0, "coefs": 2.0}, 1000)
state, kernel, _ = warmup.run(warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000)

states = inference_loop(kernel, 10_000, inference_key, state)

Expand Down
7 changes: 4 additions & 3 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ def test_minimize_lbfgs(self, maxiter, maxcor):
the same between two loop recursion algorthm of LBFGS and formulas of the
pathfinder paper"""

def regression_logprob(scale, coefs, preds, x):
def regression_logprob(log_scale, coefs, preds, x):
"""Linear regression"""
scale_prior = stats.expon.logpdf(scale, 1, 1)
scale = jnp.exp(log_scale)
scale_prior = stats.expon.logpdf(scale, 0, 1) + log_scale
coefs_prior = stats.norm.logpdf(coefs, 0, 5)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
Expand All @@ -79,7 +80,7 @@ def regression_model(key):
return logposterior_fn

fn = regression_model(self.key)
b0 = {"scale": 1.0, "coefs": 2.0}
b0 = {"log_scale": 0.0, "coefs": 2.0}
b0_flatten, unravel_fn = ravel_pytree(b0)
objective_fn = lambda x: -fn(unravel_fn(x))
(_, status), history = self.variant(
Expand Down
37 changes: 20 additions & 17 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ def irmh_proposal_distribution(rng_key):
regression_test_cases = [
{
"algorithm": blackjax.hmc,
"initial_position": {"scale": 1.0, "coefs": 2.0},
"initial_position": {"log_scale": 0.0, "coefs": 4.0},
"parameters": {"num_integration_steps": 90},
"num_warmup_steps": 1_000,
"num_sampling_steps": 3_000,
},
{
"algorithm": blackjax.nuts,
"initial_position": {"scale": 1.0, "coefs": 2.0},
"initial_position": {"log_scale": 0.0, "coefs": 4.0},
"parameters": {},
"num_warmup_steps": 1_000,
"num_sampling_steps": 1_000,
Expand All @@ -68,9 +68,10 @@ def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(19)

def regression_logprob(self, scale, coefs, preds, x):
def regression_logprob(self, log_scale, coefs, preds, x):
"""Linear regression"""
scale_prior = stats.expon.logpdf(scale, 1, 1)
scale = jnp.exp(log_scale)
scale_prior = stats.expon.logpdf(scale, 0, 1) + log_scale
coefs_prior = stats.norm.logpdf(coefs, 0, 5)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
Expand Down Expand Up @@ -109,7 +110,7 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
)

coefs_samples = states.position["coefs"]
scale_samples = states.position["scale"]
scale_samples = np.exp(states.position["log_scale"])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
Expand All @@ -128,11 +129,11 @@ def test_mala(self):
warmup_key, inference_key = jax.random.split(rng_key, 2)

mala = blackjax.mala(logposterior_fn, 1e-5)
state = mala.init({"coefs": 1.0, "scale": 2.0})
state = mala.init({"coefs": 1.0, "log_scale": 1.0})
states = inference_loop(mala.step, 10_000, inference_key, state)

coefs_samples = states.position["coefs"][3000:]
scale_samples = states.position["scale"][3000:]
scale_samples = np.exp(states.position["log_scale"][3000:])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
Expand Down Expand Up @@ -172,7 +173,7 @@ def test_pathfinder_adaptation(
states = inference_loop(kernel, num_sampling_steps, inference_key, state)

coefs_samples = states.position["coefs"]
scale_samples = states.position["scale"]
scale_samples = np.exp(states.position["log_scale"])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
Expand All @@ -190,26 +191,28 @@ def test_meads(self):

init_key, warmup_key, inference_key = jax.random.split(rng_key, 3)

num_chains = 128
warmup = blackjax.meads(
logposterior_fn,
num_chains=128,
num_chains=num_chains,
)
scale_key, coefs_key = jax.random.split(init_key, 2)
scales = 1.0 + jax.random.normal(scale_key, (128,))
coefs = 3.0 + jax.random.normal(coefs_key, (128,))
initial_positions = {"scale": scales, "coefs": coefs}
log_scales = 1.0 + jax.random.normal(scale_key, (num_chains,))
coefs = 4.0 + jax.random.normal(coefs_key, (num_chains,))
initial_positions = {"log_scale": log_scales, "coefs": coefs}
last_states, kernel, _ = warmup.run(
warmup_key,
initial_positions,
num_steps=100,
num_steps=1000,
)

states = jax.vmap(
lambda state: inference_loop(kernel, 100, inference_key, state)
)(last_states)
chain_keys = jax.random.split(inference_key, num_chains)
states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
chain_keys, last_states
)

coefs_samples = states.position["coefs"]
scale_samples = states.position["scale"]
scale_samples = np.exp(states.position["log_scale"])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
Expand Down
25 changes: 16 additions & 9 deletions tests/test_tempered_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)

def logprob_fn(self, scale, coefs, preds, x):
def logprob_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
scale = jnp.exp(log_scale)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
return jnp.sum(logpdf)
Expand All @@ -55,12 +56,16 @@ def test_adaptive_tempered_smc(self, N, use_log):
y_data = 3 * x_data + np.random.normal(size=x_data.shape)
observations = {"x": x_data, "preds": y_data}

prior = lambda x: stats.expon.logpdf(x[0], 1, 1) + stats.norm.logpdf(x[1])
def prior(x):
return (
stats.expon.logpdf(jnp.exp(x[0]), 0, 1) + x[0] + stats.norm.logpdf(x[1])
)

conditioned_logprob = lambda x: self.logprob_fn(*x, **observations)

scale_init = 1 + np.random.exponential(1, N)
log_scale_init = np.log(np.random.exponential(1, N))
coeffs_init = 3 + 2 * np.random.randn(N)
smc_state_init = [scale_init, coeffs_init]
smc_state_init = [log_scale_init, coeffs_init]

iterates = []
results = [] # type: List[TemperedSMCState]
Expand Down Expand Up @@ -91,7 +96,9 @@ def test_adaptive_tempered_smc(self, N, use_log):
iterates.append(n_iter)
results.append(result)

np.testing.assert_allclose(np.mean(result.particles[0]), 1.0, rtol=1e-1)
np.testing.assert_allclose(
np.mean(np.exp(result.particles[0])), 1.0, rtol=1e-1
)
np.testing.assert_allclose(np.mean(result.particles[1]), 3.0, rtol=1e-1)

assert iterates[1] >= iterates[0]
Expand All @@ -103,12 +110,12 @@ def test_fixed_schedule_tempered_smc(self, N, n_schedule):
y_data = 3 * x_data + np.random.normal(size=x_data.shape)
observations = {"x": x_data, "preds": y_data}

prior = lambda x: stats.norm.logpdf(jnp.log(x[0])) + stats.norm.logpdf(x[1])
prior = lambda x: stats.norm.logpdf(x[0]) + stats.norm.logpdf(x[1])
conditionned_logprob = lambda x: self.logprob_fn(*x, **observations)

scale_init = np.exp(np.random.randn(N))
log_scale_init = np.random.randn(N)
coeffs_init = np.random.randn(N)
smc_state_init = [scale_init, coeffs_init]
smc_state_init = [log_scale_init, coeffs_init]

lambda_schedule = np.logspace(-5, 0, n_schedule)
hmc_parameters = {
Expand Down Expand Up @@ -136,7 +143,7 @@ def body_fn(carry, lmbda):
return (rng_key, new_state), (new_state, info)

(_, result), _ = jax.lax.scan(body_fn, (self.key, init_state), lambda_schedule)
np.testing.assert_allclose(np.mean(result.particles[0]), 1.0, rtol=1e-1)
np.testing.assert_allclose(np.mean(np.exp(result.particles[0])), 1.0, rtol=1e-1)
np.testing.assert_allclose(np.mean(result.particles[1]), 3.0, rtol=1e-1)


Expand Down

0 comments on commit 284adf0

Please sign in to comment.