Fix numerical issue in NUTS acceptance ratio computation
Should reduce the frequency of the "Mass matrix contains zeros on the diagonal" error during warmup and fix a number of the issues reported in https://github.com/pymc-devs/pymc3/issues/3959.
junpenglao committed Jun 16, 2020
1 parent 8d241cd commit ae6c1bb
Showing 1 changed file with 98 additions and 69 deletions.
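Background note (illustrative, not part of the commit): `stats()` previously computed the mean tree acceptance as `exp(log_accept_sum) / expm1(log_size)`. Both exponentials overflow float64 once the accumulated log weights exceed roughly 709, so the statistic degenerates to `inf / inf = nan`; NaN acceptance statistics then poison step-size and mass-matrix adaptation during warmup, which is plausibly how the "Mass matrix contains zeros on the diagonal" error surfaces. The fix stays in log space and exponentiates only the difference. A minimal sketch of the before/after behaviour, with a helper mirroring `log1mexp_numpy` from the diff and accumulator values made up to trigger the overflow:

```python
import numpy as np


def log1mexp(x):
    # log(1 - exp(-x)); same branch point (0.683, near log 2) as the diff
    return np.where(x < 0.683, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))


# Hypothetical accumulated quantities for one NUTS trajectory.
log_size = 800.0        # log of the summed Boltzmann weights
log_accept_sum = 799.5  # log of the summed weighted accept probabilities

# Before: exponentiate first, divide second -- overflows to inf/inf = nan.
old = np.exp(log_accept_sum) / np.expm1(log_size)

# After: subtract in log space, exponentiate once -- stays finite.
log_sum_weight = log_size + log1mexp(log_size)  # log(exp(log_size) - 1)
new = np.exp(log_accept_sum - log_sum_weight)

print(old, new)  # nan 0.6065...
```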
167 changes: 98 additions & 69 deletions pymc3/step_methods/hmc/nuts.py
@@ -24,7 +24,7 @@
from pymc3.theanof import floatX
from pymc3.vartypes import continuous_types

-__all__ = ['NUTS']
+__all__ = ["NUTS"]


def logbern(log_p):
@@ -33,8 +33,25 @@ def logbern(log_p):
    return np.log(nr.uniform()) < log_p


+def log1mexp_numpy(x):
+    """Return log(1 - exp(-x)).
+    This function is numerically more stable than the naive approach.
+    For details, see
+    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
+    """
+    return np.where(
+        x < 0.683,
+        np.log(-np.expm1(-x)),
+        np.log1p(-np.exp(-x)))
+
+
+def logdiffexp(a, b):
+    """log(exp(a) - exp(b))"""
+    return a + log1mexp_numpy(a - b)


class NUTS(BaseHMC):
R"""A sampler for continuous variables based on Hamiltonian mechanics.
r"""A sampler for continuous variables based on Hamiltonian mechanics.
NUTS automatically tunes the step size and the number of steps per
sample. A detailed description can be found at [1], "Algorithm 6:
@@ -84,27 +101,28 @@ class NUTS(BaseHMC):
       Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo.
    """

-    name = 'nuts'
+    name = "nuts"

    default_blocked = True
    generates_stats = True
-    stats_dtypes = [{
-        'depth': np.int64,
-        'step_size': np.float64,
-        'tune': np.bool,
-        'mean_tree_accept': np.float64,
-        'step_size_bar': np.float64,
-        'tree_size': np.float64,
-        'diverging': np.bool,
-        'energy_error': np.float64,
-        'energy': np.float64,
-        'max_energy_error': np.float64,
-        'model_logp': np.float64,
-    }]
-
-    def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8,
-                 **kwargs):
-        R"""Set up the No-U-Turn sampler.
+    stats_dtypes = [
+        {
+            "depth": np.int64,
+            "step_size": np.float64,
+            "tune": np.bool,
+            "mean_tree_accept": np.float64,
+            "step_size_bar": np.float64,
+            "tree_size": np.float64,
+            "diverging": np.bool,
+            "energy_error": np.float64,
+            "energy": np.float64,
+            "max_energy_error": np.float64,
+            "model_logp": np.float64,
+        }
+    ]
+
+    def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs):
+        r"""Set up the No-U-Turn sampler.
        Parameters
        ----------
Expand Down Expand Up @@ -184,7 +202,7 @@ def _hamiltonian_step(self, start, p0, step_size):
            self._reached_max_treedepth += 1

        stats = tree.stats()
-        accept_stat = stats['mean_tree_accept']
+        accept_stat = stats["mean_tree_accept"]
        return HMCStepData(tree.proposal, accept_stat, divergence_info, stats)

    @staticmethod
@@ -200,10 +218,11 @@ def warnings(self):
        n_treedepth = self._reached_max_treedepth

        if n_samples > 0 and n_treedepth / float(n_samples) > 0.05:
-            msg = ('The chain reached the maximum tree depth. Increase '
-                   'max_treedepth, increase target_accept or reparameterize.')
-            warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn',
-                                  None, None, None)
+            msg = (
+                "The chain reached the maximum tree depth. Increase "
+                "max_treedepth, increase target_accept or reparameterize."
+            )
+            warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
            warnings.append(warn)
        return warnings

@@ -213,8 +232,8 @@ def warnings(self):

# A subtree of the binary tree built by nuts.
Subtree = namedtuple(
"Subtree",
"left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals")
"Subtree", "left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals"
)


class _Tree:
@@ -242,11 +261,12 @@ def __init__(self, ndim, integrator, start, step_size, Emax):

        self.left = self.right = start
        self.proposal = Proposal(
-            start.q, start.q_grad, start.energy, 1.0, start.model_logp)
+            start.q, start.q_grad, start.energy, 1.0, start.model_logp
+        )
        self.depth = 0
        self.log_size = 0
        self.log_accept_sum = -np.inf
-        self.mean_tree_accept = 0.
+        self.mean_tree_accept = 0.0
        self.n_proposals = 0
        self.p_sum = start.p.copy()
        self.max_energy_change = 0
@@ -265,15 +285,17 @@ def extend(self, direction):
"""
if direction > 0:
tree, diverging, turning = self._build_subtree(
self.right, self.depth, floatX(np.asarray(self.step_size)))
self.right, self.depth, floatX(np.asarray(self.step_size))
)
leftmost_begin, leftmost_end = self.left, self.right
rightmost_begin, rightmost_end = tree.left, tree.right
leftmost_p_sum = self.p_sum
rightmost_p_sum = tree.p_sum
self.right = tree.right
else:
tree, diverging, turning = self._build_subtree(
self.left, self.depth, floatX(np.asarray(-self.step_size)))
self.left, self.depth, floatX(np.asarray(-self.step_size))
)
leftmost_begin, leftmost_end = tree.right, tree.left
rightmost_begin, rightmost_end = self.left, self.right
leftmost_p_sum = tree.p_sum
@@ -291,8 +313,7 @@ def extend(self, direction):
            self.proposal = tree.proposal

        self.log_size = np.logaddexp(self.log_size, tree.log_size)
-        self.log_accept_sum = np.logaddexp(self.log_accept_sum,
-                                           tree.log_accept_sum)
+        self.log_accept_sum = np.logaddexp(self.log_accept_sum, tree.log_accept_sum)
        self.p_sum[:] += tree.p_sum

        # Additional turning check only when tree depth > 0 to avoid redundant work
@@ -301,10 +322,14 @@ def extend(self, direction):
            p_sum = self.p_sum
            turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
            p_sum1 = leftmost_p_sum + rightmost_begin.p
-            turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
+            turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (
+                p_sum1.dot(rightmost_begin.v) <= 0
+            )
            p_sum2 = leftmost_end.p + rightmost_p_sum
-            turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
-            turning = (turning | turning1 | turning2)
+            turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (
+                p_sum2.dot(rightmost_end.v) <= 0
+            )
+            turning = turning | turning1 | turning2

        return diverging, turning

@@ -324,21 +349,23 @@ def _single_step(self, left, epsilon):
        if np.abs(energy_change) > np.abs(self.max_energy_change):
            self.max_energy_change = energy_change
        if np.abs(energy_change) < self.Emax:
-            # Acceptance statistic: the Boltzmann weight
-            # e^{H(q_0, p_0) - H(q_n, p_n)} times the saturated Metropolis
-            # accept probability min(1, e^{H(q_0, p_0) - H(q_n, p_n)})
-            log_p_accept = -energy_change + min(0., -energy_change)
+            # Acceptance statistic: the Boltzmann weight
+            # e^{H(q_0, p_0) - H(q_n, p_n)} times the saturated Metropolis
+            # accept probability min(1, e^{H(q_0, p_0) - H(q_n, p_n)})
+            log_p_accept = -energy_change + min(0.0, -energy_change)
            log_size = -energy_change
            proposal = Proposal(
-                right.q, right.q_grad, right.energy, log_p_accept,
-                right.model_logp)
-            tree = Subtree(right, right, right.p,
-                           proposal, log_size, log_p_accept, 1)
+                right.q, right.q_grad, right.energy, log_p_accept, right.model_logp
+            )
+            tree = Subtree(
+                right, right, right.p, proposal, log_size, log_p_accept, 1
+            )
            return tree, None, False
        else:
-            error_msg = ("Energy change in leapfrog step is too large: %s."
-                         % energy_change)
+            error_msg = (
+                "Energy change in leapfrog step is too large: %s." % energy_change
+            )
            error = None
            tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
            divergance_info = DivergenceInfo(error_msg, error, left)
@@ -348,13 +375,11 @@ def _build_subtree(self, left, depth, epsilon):
        if depth == 0:
            return self._single_step(left, epsilon)

-        tree1, diverging, turning = self._build_subtree(
-            left, depth - 1, epsilon)
+        tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
        if diverging or turning:
            return tree1, diverging, turning

-        tree2, diverging, turning = self._build_subtree(
-            tree1.right, depth - 1, epsilon)
+        tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)

        left, right = tree1.left, tree2.right

@@ -364,14 +389,17 @@
        # Additional U turn check only when depth > 1 to avoid redundant work.
        if depth - 1 > 0:
            p_sum1 = tree1.p_sum + tree2.left.p
-            turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
+            turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (
+                p_sum1.dot(tree2.left.v) <= 0
+            )
            p_sum2 = tree1.right.p + tree2.p_sum
-            turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
-            turning = (turning | turning1 | turning2)
+            turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (
+                p_sum2.dot(tree2.right.v) <= 0
+            )
+            turning = turning | turning1 | turning2

        log_size = np.logaddexp(tree1.log_size, tree2.log_size)
-        log_accept_sum = np.logaddexp(tree1.log_accept_sum,
-                                      tree2.log_accept_sum)
+        log_accept_sum = np.logaddexp(tree1.log_accept_sum, tree2.log_accept_sum)
        if logbern(tree2.log_size - log_size):
            proposal = tree2.proposal
        else:
@@ -384,23 +412,24 @@

        n_proposals = tree1.n_proposals + tree2.n_proposals

-        tree = Subtree(left, right, p_sum, proposal,
-                       log_size, log_accept_sum, n_proposals)
+        tree = Subtree(
+            left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals
+        )
        return tree, diverging, turning

    def stats(self):
        # Update accept stat if any subtrees were accepted
        if self.log_size > 0:
-            # Remove contribution from initial state which is always a perfect
-            # accept
-            sum_weight = np.expm1(self.log_size)
-            self.mean_tree_accept = np.exp(self.log_accept_sum) / sum_weight
+            # Remove contribution from initial state which is always a perfect
+            # accept
+            log_sum_weight = logdiffexp(self.log_size, 0.)
+            self.mean_tree_accept = np.exp(self.log_accept_sum - log_sum_weight)
        return {
-            'depth': self.depth,
-            'mean_tree_accept': self.mean_tree_accept,
-            'energy_error': self.proposal.energy - self.start.energy,
-            'energy': self.proposal.energy,
-            'tree_size': self.n_proposals,
-            'max_energy_error': self.max_energy_change,
-            'model_logp': self.proposal.logp,
+            "depth": self.depth,
+            "mean_tree_accept": self.mean_tree_accept,
+            "energy_error": self.proposal.energy - self.start.energy,
+            "energy": self.proposal.energy,
+            "tree_size": self.n_proposals,
+            "max_energy_error": self.max_energy_change,
+            "model_logp": self.proposal.logp,
        }

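As a sanity check on the helpers introduced at the top of the diff (illustrative only, not part of the commit): `log1mexp_numpy` switches between two algebraically equivalent expressions because each loses precision on the other half of the domain, and `logdiffexp` builds on it to evaluate log(exp(a) - exp(b)) without forming the overflow-prone exponentials.

```python
import numpy as np


def log1mexp_numpy(x):
    # log(1 - exp(-x)), as defined in the diff above
    return np.where(x < 0.683, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))


def logdiffexp(a, b):
    # log(exp(a) - exp(b)) for a > b, evaluated entirely in log space
    return a + log1mexp_numpy(a - b)


x = 1e-10
print(np.log(1 - np.exp(-x)))  # -23.0258508...   (naive; digits degrade)
print(log1mexp_numpy(x))       # -23.0258509299... (stable branch)

# exp(1000) overflows float64, so the direct form returns nan:
print(np.log(np.exp(1000.0) - np.exp(999.0)))  # nan
print(logdiffexp(1000.0, 999.0))               # 999.5413248546...
```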