diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py index c03226279f..017a65e248 100644 --- a/pymc3/step_methods/hmc/base_hmc.py +++ b/pymc3/step_methods/hmc/base_hmc.py @@ -9,7 +9,8 @@ class BaseHMC(arraystep.GradientSharedStep): - """Superclass to implement Hamiltonian/hybrid monte carlo""" + """Superclass to implement Hamiltonian/hybrid monte carlo.""" + default_blocked = True def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, diff --git a/pymc3/step_methods/hmc/integration.py b/pymc3/step_methods/hmc/integration.py index 04605071b1..8ae76c5602 100644 --- a/pymc3/step_methods/hmc/integration.py +++ b/pymc3/step_methods/hmc/integration.py @@ -8,7 +8,10 @@ class CpuLeapfrogIntegrator(object): + """Optimized leapfrog integration using numpy.""" + def __init__(self, ndim, potential, logp_dlogp_func): + """Leapfrog integrator using CPU.""" self._ndim = ndim self._potential = potential self._logp_dlogp_func = logp_dlogp_func @@ -19,6 +22,7 @@ def __init__(self, ndim, potential, logp_dlogp_func): % (self._potential.dtype, self._dtype)) def compute_state(self, q, p): + """Compute Hamiltonian functions using a position and momentum.""" if q.dtype != self._dtype or p.dtype != self._dtype: raise ValueError('Invalid dtype. Must be %s' % self._dtype) logp, dlogp = self._logp_dlogp_func(q) @@ -28,6 +32,23 @@ def compute_state(self, q, p): return State(q, p, v, dlogp, energy) def step(self, epsilon, state, out=None): + """Leapfrog integrator step. + + Half a momentum update, full position update, half momentum update. + + Parameters + ---------- + epsilon: float, > 0 + step scale + state: State namedtuple, + current position data + out: (optional) State namedtuple, + preallocated arrays to write to in place + + Returns + ------- + None if `out` is provided, else a State namedtuple + """ pot = self._potential axpy = linalg.blas.get_blas_funcs('axpy', dtype=self._dtype) diff --git a/pymc3/step_methods/hmc/quadpotential.py b/pymc3/step_methods/hmc/quadpotential.py index 0f3319d820..1041f5e168 100644 --- a/pymc3/step_methods/hmc/quadpotential.py +++ b/pymc3/step_methods/hmc/quadpotential.py @@ -262,6 +262,7 @@ def current_mean(self): class QuadPotentialDiag(QuadPotential): """Quad potential using a diagonal covariance matrix.""" + def __init__(self, v, dtype=None): """Use a vector to represent a diagonal matrix for a covariance matrix.