Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tuning by skipping the first samples + add new experimental tuning method #5004

Merged
merged 12 commits into from
Sep 22, 2021
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
- The `OrderedMultinomial` distribution has been added for use on ordinal data which are _aggregated_ by trial, like multinomial observations, whereas `OrderedLogistic` only accepts ordinal data in a _disaggregated_ format, like categorical
observations (see [#4773](https://github.com/pymc-devs/pymc3/pull/4773)).
- The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc3/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment.
- A small change to the default mass matrix tuning method `jitter+adapt_diag` improves performance early on during tuning for some models (see [#5004](https://github.com/pymc-devs/pymc3/pull/5004)).
aseyboldt marked this conversation as resolved.
Show resolved Hide resolved
- A new experimental mass matrix tuning method, `jitter+adapt_diag_grad`, has been added (see [#5004](https://github.com/pymc-devs/pymc3/pull/5004)).
- ...

### Maintenance
Expand Down
23 changes: 23 additions & 0 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def sample(
random_seed=random_seed,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
**kwargs,
)
if start is None:
Expand Down Expand Up @@ -2078,6 +2079,7 @@ def init_nuts(
random_seed=None,
progressbar=True,
jitter_max_retries=10,
tune=None,
**kwargs,
):
"""Set up the mass matrix initialization for NUTS.
Expand All @@ -2099,6 +2101,9 @@ def init_nuts(
as starting point.
* jitter+adapt_diag: Same as ``adapt_diag``, but use test value plus a uniform jitter in
[-1, 1] as starting point in each chain.
* jitter+adapt_diag_grad:
An experimental initialization method that uses information from gradients and samples
during tuning.
* advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
sample variance of the tuning samples.
* advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
Expand Down Expand Up @@ -2174,6 +2179,24 @@ def init_nuts(
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
elif init == "jitter+adapt_diag_grad":
start = _init_jitter(model, model.initial_point, chains, jitter_max_retries)
mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0)
var = np.ones_like(mean)
n = len(var)

if tune is not None and tune > 200:
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
aseyboldt marked this conversation as resolved.
Show resolved Hide resolved
stop_adaptation = tune - 50
else:
stop_adaptation = None

potential = quadpotential.QuadPotentialDiagAdaptExp(
n,
mean,
alpha=0.02,
use_grads=True,
stop_adaptation=stop_adaptation,
)
elif init == "advi+adapt_diag_grad":
approx: pm.MeanField = pm.fit(
random_seed=random_seed,
Expand Down
6 changes: 2 additions & 4 deletions pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,7 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
self.start_energy = np.array(start.energy)

self.left = self.right = start
self.proposal = Proposal(
start.q.data, start.q_grad.data, start.energy, 1.0, start.model_logp
)
self.proposal = Proposal(start.q.data, start.q_grad, start.energy, 1.0, start.model_logp)
self.depth = 0
self.log_size = 0
self.log_weighted_accept_sum = -np.inf
Expand Down Expand Up @@ -350,7 +348,7 @@ def _single_step(self, left, epsilon):
log_size = -energy_change
proposal = Proposal(
right.q.data,
right.q_grad.data,
right.q_grad,
right.energy,
log_p_accept_weighted,
right.model_logp,
Expand Down
106 changes: 103 additions & 3 deletions pymc3/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ def __init__(
adaptation_window=101,
adaptation_window_multiplier=1,
dtype=None,
discard_window=50,
aseyboldt marked this conversation as resolved.
Show resolved Hide resolved
initial_weights=None,
early_update=False,
store_mass_matrix_trace=False,
):
"""Set up a diagonal mass matrix."""
if initial_diag is not None and initial_diag.ndim != 1:
Expand All @@ -175,12 +179,20 @@ def __init__(
self.dtype = dtype
self._n = n

self._discard_window = discard_window
self._early_update = early_update

self._initial_mean = initial_mean
self._initial_diag = initial_diag
self._initial_weight = initial_weight
self.adaptation_window = adaptation_window
self.adaptation_window_multiplier = float(adaptation_window_multiplier)

self._store_mass_matrix_trace = store_mass_matrix_trace
self._mass_trace = []

self._initial_weights = initial_weights

self.reset()

def reset(self):
Expand Down Expand Up @@ -222,12 +234,18 @@ def _update_from_weightvar(self, weightvar):

def update(self, sample, grad, tune):
"""Inform the potential about a new sample during tuning."""
if self._store_mass_matrix_trace:
self._mass_trace.append(self._stds.copy())
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

if not tune:
return

self._foreground_var.add_sample(sample, weight=1)
self._background_var.add_sample(sample, weight=1)
self._update_from_weightvar(self._foreground_var)
if self._n_samples > self._discard_window:
self._foreground_var.add_sample(sample, weight=1)
self._background_var.add_sample(sample, weight=1)

if self._early_update or self._n_samples > self.adaptation_window:
self._update_from_weightvar(self._foreground_var)

if self._n_samples > 0 and self._n_samples % self.adaptation_window == 0:
self._foreground_var = self._background_var
Expand Down Expand Up @@ -342,6 +360,8 @@ def __init__(

def add_sample(self, x, weight):
x = np.asarray(x)
if weight != 1:
raise ValueError("Setting weight != 1 is not supported.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def add_sample(self, x, weight):
x = np.asarray(x)
if weight != 1:
raise ValueError("Setting weight != 1 is not supported.")
def add_sample(self, x, weight=None):
if weight is not None:
warnings.warn(
"Setting weight is no longer supported and will raise an error in the future.",
DeprecationWarning,
)
x = np.asarray(x)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a hard break is fine here. This really was internal, unused, and wrong.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then I would suggest removing the weight argument altogether

self.n_samples += 1
old_diff = x - self.mean
self.mean[:] += old_diff / self.n_samples
Expand All @@ -360,6 +380,86 @@ def current_mean(self):
return self.mean.copy(dtype=self._dtype)


class _ExpWeightedVariance:
    """Running mean and variance with exponentially decaying sample weights.

    Every new sample enters the estimate with weight ``alpha`` while the
    previously accumulated state is down-weighted by ``1 - alpha``.
    """

    def __init__(self, n_vars, *, init_mean, init_var, alpha):
        """Set up the estimator.

        Parameters
        ----------
        n_vars : int
            Number of variables. Kept for interface compatibility; the shape
            of the state is taken from ``init_mean`` and ``init_var``.
        init_mean : array-like
            Starting value of the running mean.
        init_var : array-like
            Starting value of the running variance.
        alpha : float
            Weight given to each new sample; must lie in ``[0, 1]``.

        Raises
        ------
        ValueError
            If ``alpha`` is outside ``[0, 1]`` (this includes NaN).
        """
        alpha_arr = np.asarray(alpha)
        if not np.all((alpha_arr >= 0) & (alpha_arr <= 1)):
            raise ValueError(f"alpha must be in the interval [0, 1], got {alpha!r}.")

        # ``add_sample`` updates the state in place, so keep private copies
        # instead of aliasing (and silently mutating) the caller's arrays.
        self._variance = np.array(init_var)
        self._mean = np.array(init_mean)
        self._alpha = alpha

    def add_sample(self, value):
        """Fold ``value`` into the running mean and variance, in place."""
        alpha = self._alpha
        # Deviation from the mean *before* this sample is incorporated.
        delta = value - self._mean
        self._mean[...] += alpha * delta
        self._variance[...] = (1 - alpha) * (self._variance + alpha * delta ** 2)

    def current_variance(self, out=None):
        """Return the current variance estimate.

        The values are copied into ``out`` when given (and ``out`` is
        returned); otherwise a newly allocated array is returned.
        """
        if out is None:
            out = np.empty_like(self._variance)
        np.copyto(out, self._variance)
        return out

    def current_mean(self, out=None):
        """Return the current mean estimate.

        The values are copied into ``out`` when given (and ``out`` is
        returned); otherwise a newly allocated array is returned.
        """
        if out is None:
            out = np.empty_like(self._mean)
        np.copyto(out, self._mean)
        return out


class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
    """Diagonal mass matrix adapted from exponentially weighted variances.

    The draw at which ``_n_samples`` equals ``_discard_window`` seeds the
    estimators (zero variance, mean equal to that draw / gradient); earlier
    draws are ignored and later tuning draws are folded in. The mass matrix
    itself is only refreshed once ``_n_samples`` exceeds twice the discard
    window.
    """

    def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, **kwargs):
        """Set up the potential.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``QuadPotentialDiagAdapt.__init__``.
        alpha : float
            Weight handed to the ``_ExpWeightedVariance`` estimators.
        use_grads : bool
            If true, a second estimator tracks the gradients and the mass
            matrix is derived from both the draw and gradient variances.
        stop_adaptation : int, optional
            Adaptation is skipped once ``_n_samples`` reaches this value.
            ``None`` means adaptation is never stopped by this limit.
        """
        super().__init__(*args, **kwargs)
        self._alpha = alpha
        self._use_grads = use_grads
        self._stop_adaptation = np.inf if stop_adaptation is None else stop_adaptation

    def _fresh_estimator(self, first_value):
        """Build an estimator centred on ``first_value`` with zero variance."""
        return _ExpWeightedVariance(
            self._n,
            init_mean=first_value.copy(),
            init_var=np.zeros_like(first_value),
            alpha=self._alpha,
        )

    def update(self, sample, grad, tune):
        """Inform the potential about a new sample and its gradient."""
        adapting = tune and self._n_samples < self._stop_adaptation
        if adapting:
            seen = self._n_samples
            window = self._discard_window

            if seen == window:
                # First draw after the discard window: start the estimators.
                self._variance_estimator = self._fresh_estimator(sample)
                if self._use_grads:
                    self._variance_estimator_grad = self._fresh_estimator(grad)
            elif seen > window:
                self._variance_estimator.add_sample(sample)
                if self._use_grads:
                    self._variance_estimator_grad.add_sample(grad)

            # Only touch the mass matrix after a second window's worth of
            # draws has gone into the estimators.
            if seen > 2 * window:
                if not self._use_grads:
                    self._update_from_weightvar(self._variance_estimator)
                else:
                    self._update_from_variances(
                        self._variance_estimator, self._variance_estimator_grad
                    )

        # NOTE(review): the counter and the trace are advanced on every call,
        # including non-tuning ones — indentation was reconstructed, confirm
        # against the original file.
        self._n_samples += 1

        if self._store_mass_matrix_trace:
            self._mass_trace.append(self._stds.copy())

    def _update_from_variances(self, var_estimator, inv_var_estimator):
        """Set the diagonal from the draw and gradient variance estimates.

        ``_var`` becomes ``sqrt(var / inv_var)`` element-wise; ``_stds`` and
        ``_inv_stds`` are then refreshed from that value.
        """
        draw_var = var_estimator.current_variance()
        grad_var = inv_var_estimator.current_variance()
        scale = np.sqrt(draw_var / grad_var)
        self._var[:] = scale
        np.sqrt(scale, out=self._stds)
        np.divide(1, self._stds, out=self._inv_stds)


class QuadPotentialDiag(QuadPotential):
"""Quad potential using a diagonal covariance matrix."""

Expand Down
15 changes: 13 additions & 2 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,21 @@ def test_sample(self):

def test_sample_init(self):
with self.model:
for init in ("advi", "advi_map", "map"):
for init in (
"advi",
"advi_map",
"map",
"adapt_diag",
"jitter+adapt_diag",
"jitter+adapt_diag_grad",
"advi+adapt_diag_grad",
aseyboldt marked this conversation as resolved.
Show resolved Hide resolved
"advi+adapt_diag",
"adapt_full",
"jitter+adapt_full",
):
pm.sample(
init=init,
tune=0,
tune=120,
n_init=1000,
draws=50,
random_seed=self.random_seed,
Expand Down