Skip to content

Commit

Permalink
Rebase again
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll committed Jul 26, 2017
1 parent c2fe1a3 commit ee199ef
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pymc3/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


class BaseHMC(arraystep.GradientSharedStep):
"""Superclass to implement Hamiltonian/hybrid monte carlo"""
"""Superclass to implement Hamiltonian/hybrid monte carlo."""

default_blocked = True

def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
Expand Down
21 changes: 21 additions & 0 deletions pymc3/step_methods/hmc/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@


class CpuLeapfrogIntegrator(object):
"""Optimized leapfrog integration using numpy."""

def __init__(self, ndim, potential, logp_dlogp_func):
"""Leapfrog integrator using CPU."""
self._ndim = ndim
self._potential = potential
self._logp_dlogp_func = logp_dlogp_func
Expand All @@ -19,6 +22,7 @@ def __init__(self, ndim, potential, logp_dlogp_func):
% (self._potential.dtype, self._dtype))

def compute_state(self, q, p):
"""Compute Hamiltonian functions using a position and momentum."""
if q.dtype != self._dtype or p.dtype != self._dtype:
raise ValueError('Invalid dtype. Must be %s' % self._dtype)
logp, dlogp = self._logp_dlogp_func(q)
Expand All @@ -28,6 +32,23 @@ def compute_state(self, q, p):
return State(q, p, v, dlogp, energy)

def step(self, epsilon, state, out=None):
"""Leapfrog integrator step.
Half a momentum update, full position update, half momentum update.
Parameters
----------
epsilon: float, > 0
step scale
state: State namedtuple,
current position data
out: (optional) State namedtuple,
preallocated arrays to write to in place
Returns
-------
None if `out` is provided, else a State namedtuple
"""
pot = self._potential
axpy = linalg.blas.get_blas_funcs('axpy', dtype=self._dtype)

Expand Down
1 change: 1 addition & 0 deletions pymc3/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def current_mean(self):

class QuadPotentialDiag(QuadPotential):
"""Quad potential using a diagonal covariance matrix."""

def __init__(self, v, dtype=None):
"""Use a vector to represent a diagonal matrix for a covariance matrix.
Expand Down

0 comments on commit ee199ef

Please sign in to comment.