
Make Metropolis cope better with multiple dimensions
Metropolis now updates each dimension sequentially and tunes a proposal scale parameter per dimension
ricardoV94 committed May 30, 2022
1 parent 57654dc commit 616cd90
Showing 2 changed files with 121 additions and 36 deletions.
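
For context, a minimal usage sketch (illustrative only, not part of the diff) of the behaviour this commit changes: when Metropolis steps a vector-valued variable, it now proposes and accepts or rejects each dimension sequentially and tunes a separate proposal scale per dimension, instead of using one global acceptance rate and scale. The model and sampling settings below are placeholders.

import numpy as np
import pymc as pm

with pm.Model():
    # A 10-dimensional free variable; each dimension now gets its own tuned proposal scale.
    x = pm.Normal("x", mu=np.arange(10), sigma=1.0, shape=10)
    step = pm.Metropolis([x])
    idata = pm.sample(draws=1000, chains=2, step=step)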
121 changes: 85 additions & 36 deletions pymc/step_methods/metropolis.py
@@ -168,10 +168,12 @@ def __init__(
vars = model.value_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]

vars = pm.inputvars(vars)

initial_values_shape = [initial_values[v.name].shape for v in vars]
if S is None:
S = np.ones(sum(initial_values[v.name].size for v in vars))
S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape)))

if proposal_dist is not None:
self.proposal_dist = proposal_dist(S)
@@ -186,7 +188,6 @@ def __init__(
self.tune = tune
self.tune_interval = tune_interval
self.steps_until_tune = tune_interval
self.accepted = 0

# Determine type of variables
self.discrete = np.concatenate(
@@ -195,11 +196,33 @@ def __init__(
self.any_discrete = self.discrete.any()
self.all_discrete = self.discrete.all()

# remember initial settings before tuning so they can be reset
self._untuned_settings = dict(
scaling=self.scaling, steps_until_tune=tune_interval, accepted=self.accepted
# Metropolis will try to handle one batched dimension at a time. This, however,
# is not safe for discrete multivariate distributions (looking at you Multinomial),
# due to high dependency among the support dimensions. For continuous multivariate
# distributions we assume they are being transformed in a way that makes each
# dimension semi-independent.
is_scalar = len(initial_values_shape) == 1 and initial_values_shape[0] == ()
self.elemwise_update = not (
    is_scalar
    or (
        self.any_discrete
        and max(getattr(model.values_to_rvs[var].owner.op, "ndim_supp", 1) for var in vars)
        > 0
    )
)
if self.elemwise_update:
    dims = int(sum(np.prod(ivs) for ivs in initial_values_shape))
else:
    dims = 1
self.enum_dims = np.arange(dims, dtype=int)
self.accept_rate_iter = np.zeros(dims, dtype=float)
self.accepted_iter = np.zeros(dims, dtype=bool)
self.accepted_sum = np.zeros(dims, dtype=int)

# remember initial settings before tuning so they can be reset
self._untuned_settings = dict(scaling=self.scaling, steps_until_tune=tune_interval)

# TODO: This is not being used when compiling the logp function!
self.mode = mode

shared = pm.make_shared_replacements(initial_values, vars, model)
@@ -210,6 +233,7 @@ def reset_tuning(self):
"""Resets the tuned sampler parameters to their initial values."""
for attr, initial_value in self._untuned_settings.items():
setattr(self, attr, initial_value)
self.accepted_sum[:] = 0
return

def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
@@ -219,10 +243,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:

if not self.steps_until_tune and self.tune:
# Tune scaling parameter
self.scaling = tune(self.scaling, self.accepted / float(self.tune_interval))
self.scaling = tune(self.scaling, self.accepted_sum / float(self.tune_interval))
# Reset counter
self.steps_until_tune = self.tune_interval
self.accepted = 0
self.accepted_sum[:] = 0

delta = self.proposal_dist() * self.scaling

@@ -237,23 +261,36 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
else:
q = floatX(q0 + delta)

accept = self.delta_logp(q, q0)
q_new, accepted = metrop_select(accept, q, q0)

self.accepted += accepted
if self.elemwise_update:
    q_temp = q0.copy()
    # Shuffle order of updates (probably we don't need to do this in every step)
    np.random.shuffle(self.enum_dims)
    for i in self.enum_dims:
        q_temp[i] = q[i]
        accept_rate_i = self.delta_logp(q_temp, q0)
        q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
        q_temp[i] = q_temp_[i]
        self.accept_rate_iter[i] = accept_rate_i
        self.accepted_iter[i] = accepted_i
        self.accepted_sum[i] += accepted_i
    q = q_temp
else:
    accept_rate = self.delta_logp(q, q0)
    q, accepted = metrop_select(accept_rate, q, q0)
    self.accept_rate_iter = accept_rate
    self.accepted_iter = accepted
    self.accepted_sum += accepted

self.steps_until_tune -= 1

stats = {
"tune": self.tune,
"scaling": self.scaling,
"accept": np.exp(accept),
"accepted": accepted,
"scaling": np.mean(self.scaling),
"accept": np.mean(np.exp(self.accept_rate_iter)),
"accepted": np.mean(self.accepted_iter),
}

q_new = RaveledVars(q_new, point_map_info)

return q_new, [stats]
return RaveledVars(q, point_map_info), [stats]

@staticmethod
def competence(var, has_grad):
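
To make the elementwise branch of astep above easier to follow, here is a simplified standalone sketch (not code from this commit) of a sequential per-dimension random-walk Metropolis sweep: the update order is shuffled, each coordinate gets its own proposal and accept/reject decision, and acceptances are counted per dimension. The names logp, scaling, accepted_sum, and rng are placeholders for the model log-density, the per-dimension scale vector, the acceptance counter, and a NumPy random generator.

import numpy as np

def elemwise_metropolis_sweep(q0, logp, scaling, accepted_sum, rng):
    # One sweep over all dimensions of q0, updating one coordinate at a time.
    q = q0.copy()
    for i in rng.permutation(q0.size):  # shuffled update order
        q_prop = q.copy()
        q_prop[i] = q[i] + rng.normal() * scaling[i]  # propose along dimension i only
        log_accept_rate = logp(q_prop) - logp(q)
        if np.log(rng.uniform()) < log_accept_rate:  # accept/reject this dimension alone
            q = q_prop
            accepted_sum[i] += 1
    return q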
@@ -275,26 +312,38 @@ def tune(scale, acc_rate):
>0.95 x 10
"""
if acc_rate < 0.001:
    # reduce by 90 percent
    return scale * 0.1
elif acc_rate < 0.05:
    # reduce by 50 percent
    return scale * 0.5
elif acc_rate < 0.2:
    # reduce by ten percent
    return scale * 0.9
elif acc_rate > 0.95:
    # increase by factor of ten
    return scale * 10.0
elif acc_rate > 0.75:
    # increase by double
    return scale * 2.0
elif acc_rate > 0.5:
    # increase by ten percent
    return scale * 1.1

return scale
return scale * np.where(
    acc_rate < 0.001,
    # reduce by 90 percent
    0.1,
    np.where(
        acc_rate < 0.05,
        # reduce by 50 percent
        0.5,
        np.where(
            acc_rate < 0.2,
            # reduce by ten percent
            0.9,
            np.where(
                acc_rate > 0.95,
                # increase by factor of ten
                10.0,
                np.where(
                    acc_rate > 0.75,
                    # increase by double
                    2.0,
                    np.where(
                        acc_rate > 0.5,
                        # increase by ten percent
                        1.1,
                        # Do not change
                        1.0,
                    ),
                ),
            ),
        ),
    ),
)


class BinaryMetropolis(ArrayStep):
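The rewritten tune above is now vectorized: acc_rate may be an array of per-dimension acceptance rates, and each dimension's scale is adjusted independently through the nested np.where calls. A small illustrative check of that behaviour, reusing the same thresholds:

import numpy as np

scale = np.ones(4)
acc_rate = np.array([0.0005, 0.1, 0.4, 0.98])  # per-dimension acceptance rates
factor = np.where(acc_rate < 0.001, 0.1,
         np.where(acc_rate < 0.05, 0.5,
         np.where(acc_rate < 0.2, 0.9,
         np.where(acc_rate > 0.95, 10.0,
         np.where(acc_rate > 0.75, 2.0,
         np.where(acc_rate > 0.5, 1.1, 1.0))))))
print(scale * factor)  # [ 0.1  0.9  1.  10. ]
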
36 changes: 36 additions & 0 deletions pymc/tests/test_step.py
@@ -35,7 +35,9 @@
Beta,
Binomial,
Categorical,
Dirichlet,
HalfNormal,
Multinomial,
MvNormal,
Normal,
)
@@ -403,6 +405,40 @@ def test_tuning_reset(self):
assert tuned != 0.1
np.testing.assert_array_equal(idata.sample_stats["scaling"].sel(chain=c).values, tuned)

@pytest.mark.parametrize(
    "batched_dist",
    (
        Binomial.dist(n=5, p=0.9),  # scalar case
        Binomial.dist(n=np.arange(40) + 1, p=np.linspace(0.1, 0.9, 40), shape=(40,)),
        Binomial.dist(
            n=(np.arange(20) + 1)[::-1],
            p=np.linspace(0.1, 0.9, 20),
            shape=(2, 20),
        ),
        Dirichlet.dist(a=np.ones(3) * (np.arange(40) + 1)[:, None], shape=(40, 3)),
        Dirichlet.dist(a=np.ones(3) * (np.arange(20) + 1)[:, None], shape=(2, 20, 3)),
    ),
)
def test_elemwise_update(self, batched_dist):
    with Model() as m:
        m.register_rv(batched_dist, name="batched_dist")
        step = pm.Metropolis([batched_dist])
        assert step.elemwise_update == (batched_dist.ndim > 0)
        trace = pm.sample(draws=1000, chains=2, step=step)

    assert az.rhat(trace).max()["batched_dist"].values < 1.1
    assert az.ess(trace).min()["batched_dist"].values > 50

def test_multinomial_no_elemwise_update(self):
    with Model() as m:
        batched_dist = Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4))
        with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
            step = pm.Metropolis([batched_dist])
            assert not step.elemwise_update
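
These two tests mirror the elemwise_update rule in __init__: batched univariate distributions, and continuous multivariate ones such as Dirichlet that are transformed to an unconstrained space, get per-dimension updates, while discrete multivariate distributions such as Multinomial do not, because their support dimensions are tightly coupled (the counts must sum to n). An illustrative way to see the distinction the rule relies on, via the ndim_supp attribute of the underlying random variable op (assuming the PyMC v4 / Aesara RandomVariable API):

import pymc as pm

print(pm.Binomial.dist(n=5, p=0.5).owner.op.ndim_supp)            # 0: univariate, safe to update elementwise
print(pm.Multinomial.dist(n=5, p=[0.25] * 4).owner.op.ndim_supp)  # 1: multivariate support, not updated elementwise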


class TestDEMetropolisZ:
def test_tuning_lambda_sequential(self):
