Modify linear regressions to parameters in ℝ #388

Merged 1 commit on Oct 27, 2022
19 changes: 10 additions & 9 deletions examples/Introduction.md
@@ -29,16 +29,17 @@ import blackjax

## The Problem

-We'll generate observations from a normal distribution of known `loc` and `scale` to see if we can recover the parameters in sampling. Let's take a decent-size dataset with 1,000 points:
+We'll generate observations from a normal distribution of known `loc` and `scale` to see if we can recover the parameters in sampling. **MCMC algorithms usually assume samples are being drawn from an unconstrained Euclidean space.** This is why we log-transform the scale parameter, so that sampling is done on the real line; samples can be transformed back to their original space in post-processing. Let's take a decent-size dataset with 1,000 points:

```{code-cell} ipython3
loc, scale = 10, 20
observed = np.random.normal(loc, scale, size=1_000)
```
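For context on the reparameterization above (a standalone sketch, not part of the example or this PR): the sampler proposes `log_scale` anywhere on the real line, and `jnp.exp` maps every proposal back to a valid, strictly positive `scale` in post-processing.

```python
import jax.numpy as jnp

# Hypothetical unconstrained draws of log_scale; exp maps them back to
# the constrained (positive) scale the model is written in.
log_scale_draws = jnp.array([-2.0, 0.0, 3.5])
scale_draws = jnp.exp(log_scale_draws)   # always > 0
assert bool(jnp.all(scale_draws > 0))
```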

```{code-cell} ipython3
-def logprob_fn(loc, scale, observed=observed):
+def logprob_fn(loc, log_scale, observed=observed):
"""Univariate Normal"""
+scale = jnp.exp(log_scale)
logpdf = stats.norm.logpdf(observed, loc, scale)
return jnp.sum(logpdf)

@@ -51,7 +52,7 @@ logprob = lambda x: logprob_fn(**x)
### Sampler Parameters

```{code-cell} ipython3
-inv_mass_matrix = np.array([0.5, 0.5])
+inv_mass_matrix = np.array([0.5, 0.01])
num_integration_steps = 60
step_size = 1e-3

@@ -63,7 +64,7 @@ hmc = blackjax.hmc(logprob, step_size, inv_mass_matrix, num_integration_steps)
The initial state of the HMC algorithm requires not only an initial position, but also the potential energy and gradient of the potential energy at this position. BlackJAX provides a `new_state` function to initialize the state from an initial position.

```{code-cell} ipython3
initial_position = {"loc": 1.0, "scale": 2.0}
initial_position = {"loc": 1.0, "log_scale": 1.0}
initial_state = hmc.init(initial_position)
initial_state
```
@@ -100,7 +101,7 @@ rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 10_000)

loc_samples = states.position["loc"].block_until_ready()
-scale_samples = states.position["scale"]
+scale_samples = jnp.exp(states.position["log_scale"])
```

```{code-cell} ipython3
@@ -121,14 +122,14 @@ ax1.set_ylabel("scale")
NUTS is a *dynamic* algorithm: the number of integration steps is determined at runtime. We still need to specify a step size and a mass matrix:

```{code-cell} ipython3
-inv_mass_matrix = np.array([0.5, 0.5])
+inv_mass_matrix = np.array([0.5, 0.01])
step_size = 1e-3

nuts = blackjax.nuts(logprob, step_size, inv_mass_matrix)
```
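A note on the `inv_mass_matrix` values above (a back-of-envelope sketch, not repository code): a common heuristic is to match the diagonal of the inverse mass matrix to the rough posterior variances of the sampled coordinates, and the posterior of `log_scale` is far more concentrated than that of `scale`, hence the much smaller second entry.

```python
import numpy as np

# Rough posterior variances for (loc, log_scale) from the Fisher information
# of a Normal(loc, scale) likelihood with N observations:
#   I(loc)       = N / scale**2  ->  var(loc)       ~ scale**2 / N
#   I(log_scale) = 2 * N         ->  var(log_scale) ~ 1 / (2 * N)
N, scale = 1_000, 20.0
print(scale**2 / N, 1.0 / (2 * N))  # ~0.4 and ~5e-4
```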

```{code-cell} ipython3
initial_position = {"loc": 1.0, "scale": 2.0}
initial_position = {"loc": 1.0, "log_scale": 1.0}
initial_state = nuts.init(initial_position)
initial_state
```
@@ -139,7 +140,7 @@ rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, nuts.step, initial_state, 4_000)

loc_samples = states.position["loc"].block_until_ready()
-scale_samples = states.position["scale"]
+scale_samples = jnp.exp(states.position["log_scale"])
```

```{code-cell} ipython3
@@ -176,7 +177,7 @@ We can use the obtained parameters to define a new kernel. Note that we do not h
states = inference_loop(rng_key, kernel, state, 1_000)

loc_samples = states.position["loc"].block_until_ready()
-scale_samples = states.position["scale"]
+scale_samples = jnp.exp(states.position["log_scale"])
```

```{code-cell} ipython3
3 changes: 1 addition & 2 deletions examples/RegimeSwitchingModel.md
@@ -171,7 +171,6 @@ dist = RegimeSwitchHMM(T, y)
```

```{code-cell} ipython3
-batch_fn = jax.vmap
[n_chain, n_warm, n_iter] = [128, 5000, 200]
ksam, kinit = jrnd.split(jrnd.PRNGKey(0), 2)
dist.initialize_model(kinit, n_chain)
@@ -181,7 +180,7 @@ dist.initialize_model(kinit, n_chain)
print("Running MEADS...")
tic1 = pd.Timestamp.now()
k_warm, k_sample = jrnd.split(ksam)
-warmup = blackjax.meads(dist.logprob_fn, n_chain, batch_fn=batch_fn)
+warmup = blackjax.meads(dist.logprob_fn, n_chain)
init_state, kernel, _ = warmup.run(k_warm, dist.init_params, n_warm)

def one_chain(k_sam, init_state):
3 changes: 1 addition & 2 deletions examples/SparseLogisticRegression.md
@@ -101,14 +101,13 @@ N_OBS, N_REG = X.shape
N_PARAM = N_REG * 2 + 1
dist = HorseshoeLogisticReg(X, y)

-batch_fn = jax.vmap
[n_chain, n_warm, n_iter] = [128, 20000, 10000]
ksam, kinit = jrnd.split(jrnd.PRNGKey(0), 2)
dist.initialize_model(kinit, n_chain)

tic1 = pd.Timestamp.now()
k_warm, k_sample = jrnd.split(ksam)
-warmup = blackjax.meads(dist.logprob_fn, n_chain, batch_fn=batch_fn)
+warmup = blackjax.meads(dist.logprob_fn, n_chain)
adaptation_results = warmup.run(k_warm, dist.init_params, n_warm)
init_state = adaptation_results.state
kernel = adaptation_results.kernel
7 changes: 4 additions & 3 deletions examples/howto_sample_multiple_chains.md
@@ -54,8 +54,9 @@ loc, scale = 10, 20
observed = np.random.normal(loc, scale, size=1_000)


-def logprob_fn(loc, scale, observed=observed):
+def logprob_fn(loc, log_scale, observed=observed):
"""Univariate Normal"""
+scale = jnp.exp(log_scale)
logpdf = stats.norm.logpdf(observed, loc, scale)
return jnp.sum(logpdf)

@@ -83,7 +84,7 @@ To make our demonstration more dramatic we will use a NUTS sampler with poorly
import blackjax


-inv_mass_matrix = np.array([0.5, 0.5])
+inv_mass_matrix = np.array([0.5, 0.01])
step_size = 1e-3

nuts = blackjax.nuts(logprob, step_size, inv_mass_matrix)
@@ -125,7 +126,7 @@ def inference_loop_multiple_chains(
We now prepare the initial states using `jax.vmap` again, to vectorize the `init` function:

```{code-cell} ipython3
initial_positions = {"loc": np.ones(num_chains), "scale": 2.0 * np.ones(num_chains)}
initial_positions = {"loc": np.ones(num_chains), "log_scale": np.ones(num_chains)}
initial_states = jax.vmap(nuts.init, in_axes=(0))(initial_positions)
```
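For readers new to `jax.vmap` over pytrees (an illustration with a hypothetical `toy_init`, not the BlackJAX API): `vmap` maps over the leading axis of every leaf, so a dict of per-chain arrays yields one result per chain, exactly as with `nuts.init` above.

```python
import jax
import jax.numpy as jnp

def toy_init(position):
    # Stand-in for an init function: reduces one chain's position to a scalar.
    return position["loc"] ** 2 + position["log_scale"] ** 2

num_chains = 4
positions = {"loc": jnp.ones(num_chains), "log_scale": jnp.zeros(num_chains)}
print(jax.vmap(toy_init)(positions))  # shape (4,), one value per chain
```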

2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@
"fastprogress>=0.2.0",
"jax>=0.3.13",
"jaxlib>=0.3.10",
"jaxopt>=0.4.2",
"jaxopt>=0.5.5",
],
long_description_content_type="text/markdown",
keywords="probabilistic machine learning bayesian statistics sampling algorithms",
7 changes: 4 additions & 3 deletions tests/test_benchmarks.py
@@ -15,9 +15,10 @@
import blackjax


-def regression_logprob(scale, coefs, preds, x):
+def regression_logprob(log_scale, coefs, preds, x):
"""Linear regression"""
-scale_prior = stats.expon.logpdf(scale, 1, 1)
+scale = jnp.exp(log_scale)
+scale_prior = stats.expon.logpdf(scale, 0, 1) + log_scale
coefs_prior = stats.norm.logpdf(coefs, 0, 5)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
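The `+ log_scale` term above is the log-Jacobian of the `scale = exp(log_scale)` change of variables, which keeps the exponential prior on `scale` correctly expressed in terms of `log_scale`. A quick autodiff check (illustration only, not repository code):

```python
import jax
import jax.numpy as jnp

# log |d scale / d log_scale| = log(exp(log_scale)) = log_scale
log_jacobian = lambda log_scale: jnp.log(jax.grad(jnp.exp)(log_scale))
print(log_jacobian(0.7))  # 0.7
```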
@@ -53,7 +54,7 @@ def run_regression(algorithm, **parameters):
is_mass_matrix_diagonal=False,
**parameters,
)
state, kernel, _ = warmup.run(warmup_key, {"scale": 1.0, "coefs": 2.0}, 1000)
state, kernel, _ = warmup.run(warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000)

states = inference_loop(kernel, 10_000, inference_key, state)

7 changes: 4 additions & 3 deletions tests/test_optimizers.py
@@ -58,9 +58,10 @@ def test_minimize_lbfgs(self, maxiter, maxcor):
the same between two-loop recursion algorithm of LBFGS and formulas of the
pathfinder paper"""

-def regression_logprob(scale, coefs, preds, x):
+def regression_logprob(log_scale, coefs, preds, x):
"""Linear regression"""
-scale_prior = stats.expon.logpdf(scale, 1, 1)
+scale = jnp.exp(log_scale)
+scale_prior = stats.expon.logpdf(scale, 0, 1) + log_scale
coefs_prior = stats.norm.logpdf(coefs, 0, 5)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
@@ -79,7 +80,7 @@ def regression_model(key):
return logposterior_fn

fn = regression_model(self.key)
b0 = {"scale": 1.0, "coefs": 2.0}
b0 = {"log_scale": 0.0, "coefs": 2.0}
b0_flatten, unravel_fn = ravel_pytree(b0)
objective_fn = lambda x: -fn(unravel_fn(x))
(_, status), history = self.variant(
37 changes: 20 additions & 17 deletions tests/test_sampling.py
@@ -46,14 +46,14 @@ def irmh_proposal_distribution(rng_key):
regression_test_cases = [
{
"algorithm": blackjax.hmc,
"initial_position": {"scale": 1.0, "coefs": 2.0},
"initial_position": {"log_scale": 0.0, "coefs": 4.0},
"parameters": {"num_integration_steps": 90},
"num_warmup_steps": 1_000,
"num_sampling_steps": 3_000,
},
{
"algorithm": blackjax.nuts,
"initial_position": {"scale": 1.0, "coefs": 2.0},
"initial_position": {"log_scale": 0.0, "coefs": 4.0},
"parameters": {},
"num_warmup_steps": 1_000,
"num_sampling_steps": 1_000,
@@ -68,9 +68,10 @@ def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(19)

-def regression_logprob(self, scale, coefs, preds, x):
+def regression_logprob(self, log_scale, coefs, preds, x):
"""Linear regression"""
-scale_prior = stats.expon.logpdf(scale, 1, 1)
+scale = jnp.exp(log_scale)
+scale_prior = stats.expon.logpdf(scale, 0, 1) + log_scale
coefs_prior = stats.norm.logpdf(coefs, 0, 5)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
@@ -109,7 +110,7 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
)

coefs_samples = states.position["coefs"]
scale_samples = states.position["scale"]
scale_samples = np.exp(states.position["log_scale"])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
@@ -128,11 +129,11 @@ def test_mala(self):
warmup_key, inference_key = jax.random.split(rng_key, 2)

mala = blackjax.mala(logposterior_fn, 1e-5)
state = mala.init({"coefs": 1.0, "scale": 2.0})
state = mala.init({"coefs": 1.0, "log_scale": 1.0})
states = inference_loop(mala.step, 10_000, inference_key, state)

coefs_samples = states.position["coefs"][3000:]
scale_samples = states.position["scale"][3000:]
scale_samples = np.exp(states.position["log_scale"][3000:])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
@@ -172,7 +173,7 @@ def test_pathfinder_adaptation(
states = inference_loop(kernel, num_sampling_steps, inference_key, state)

coefs_samples = states.position["coefs"]
scale_samples = states.position["scale"]
scale_samples = np.exp(states.position["log_scale"])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
@@ -190,26 +191,28 @@ def test_meads(self):

init_key, warmup_key, inference_key = jax.random.split(rng_key, 3)

+num_chains = 128
warmup = blackjax.meads(
logposterior_fn,
-num_chains=128,
+num_chains=num_chains,
)
scale_key, coefs_key = jax.random.split(init_key, 2)
-scales = 1.0 + jax.random.normal(scale_key, (128,))
-coefs = 3.0 + jax.random.normal(coefs_key, (128,))
-initial_positions = {"scale": scales, "coefs": coefs}
+log_scales = 1.0 + jax.random.normal(scale_key, (num_chains,))
+coefs = 4.0 + jax.random.normal(coefs_key, (num_chains,))
+initial_positions = {"log_scale": log_scales, "coefs": coefs}
last_states, kernel, _ = warmup.run(
warmup_key,
initial_positions,
-num_steps=100,
+num_steps=1000,
)

-states = jax.vmap(
-lambda state: inference_loop(kernel, 100, inference_key, state)
-)(last_states)
+chain_keys = jax.random.split(inference_key, num_chains)
+states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
+chain_keys, last_states
+)

coefs_samples = states.position["coefs"]
scale_samples = states.position["scale"]
scale_samples = np.exp(states.position["log_scale"])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
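The rewritten `test_meads` also gives every chain its own PRNG key; with a single shared `inference_key`, the vmapped chains would all produce identical draws. The pattern in isolation (a sketch with a placeholder step function, not the test code):

```python
import jax
import jax.numpy as jnp

num_chains = 128
chain_keys = jax.random.split(jax.random.PRNGKey(0), num_chains)

def one_chain(key, x0):
    # Placeholder for an inference loop: a single random perturbation per chain.
    return x0 + jax.random.normal(key)

samples = jax.vmap(one_chain)(chain_keys, jnp.zeros(num_chains))  # shape (128,)
```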
25 changes: 16 additions & 9 deletions tests/test_tempered_smc.py
@@ -42,8 +42,9 @@ def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)

-def logprob_fn(self, scale, coefs, preds, x):
+def logprob_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
+scale = jnp.exp(log_scale)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
return jnp.sum(logpdf)
@@ -55,12 +56,16 @@ def test_adaptive_tempered_smc(self, N, use_log):
y_data = 3 * x_data + np.random.normal(size=x_data.shape)
observations = {"x": x_data, "preds": y_data}

-prior = lambda x: stats.expon.logpdf(x[0], 1, 1) + stats.norm.logpdf(x[1])
+def prior(x):
+return (
+stats.expon.logpdf(jnp.exp(x[0]), 0, 1) + x[0] + stats.norm.logpdf(x[1])
+)

conditioned_logprob = lambda x: self.logprob_fn(*x, **observations)

-scale_init = 1 + np.random.exponential(1, N)
+log_scale_init = np.log(np.random.exponential(1, N))
coeffs_init = 3 + 2 * np.random.randn(N)
-smc_state_init = [scale_init, coeffs_init]
+smc_state_init = [log_scale_init, coeffs_init]

iterates = []
results = [] # type: List[TemperedSMCState]
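A sanity check on the reparameterized prior above (a standalone illustration using NumPy/SciPy, not part of the tests): the added `+ x[0]` log-Jacobian term keeps the density over `log_scale` properly normalized.

```python
import numpy as np
from scipy import stats

phi = np.linspace(-10.0, 10.0, 20_001)                     # grid over log_scale
log_density = stats.expon.logpdf(np.exp(phi), 0, 1) + phi  # Exponential(1) prior + log-Jacobian
mass = np.exp(log_density).sum() * (phi[1] - phi[0])       # crude quadrature
print(mass)  # ~1.0
```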
@@ -91,7 +96,9 @@ def test_adaptive_tempered_smc(self, N, use_log):
iterates.append(n_iter)
results.append(result)

-np.testing.assert_allclose(np.mean(result.particles[0]), 1.0, rtol=1e-1)
+np.testing.assert_allclose(
+np.mean(np.exp(result.particles[0])), 1.0, rtol=1e-1
+)
np.testing.assert_allclose(np.mean(result.particles[1]), 3.0, rtol=1e-1)

assert iterates[1] >= iterates[0]
@@ -103,12 +110,12 @@ def test_fixed_schedule_tempered_smc(self, N, n_schedule):
y_data = 3 * x_data + np.random.normal(size=x_data.shape)
observations = {"x": x_data, "preds": y_data}

-prior = lambda x: stats.norm.logpdf(jnp.log(x[0])) + stats.norm.logpdf(x[1])
+prior = lambda x: stats.norm.logpdf(x[0]) + stats.norm.logpdf(x[1])
conditionned_logprob = lambda x: self.logprob_fn(*x, **observations)

-scale_init = np.exp(np.random.randn(N))
+log_scale_init = np.random.randn(N)
coeffs_init = np.random.randn(N)
-smc_state_init = [scale_init, coeffs_init]
+smc_state_init = [log_scale_init, coeffs_init]

lambda_schedule = np.logspace(-5, 0, n_schedule)
hmc_parameters = {
@@ -136,7 +143,7 @@ def body_fn(carry, lmbda):
return (rng_key, new_state), (new_state, info)

(_, result), _ = jax.lax.scan(body_fn, (self.key, init_state), lambda_schedule)
-np.testing.assert_allclose(np.mean(result.particles[0]), 1.0, rtol=1e-1)
+np.testing.assert_allclose(np.mean(np.exp(result.particles[0])), 1.0, rtol=1e-1)
np.testing.assert_allclose(np.mean(result.particles[1]), 3.0, rtol=1e-1)

