diff --git a/pymc3/blocking.py b/pymc3/blocking.py index 940265b7e1..301d66952e 100644 --- a/pymc3/blocking.py +++ b/pymc3/blocking.py @@ -23,14 +23,28 @@ class ArrayOrdering(object): def __init__(self, vars): self.vmap = [] - dim = 0 + self._by_name = {} + size = 0 for var in vars: - slc = slice(dim, dim + var.dsize) - self.vmap.append(VarMap(str(var), slc, var.dshape, var.dtype)) - dim += var.dsize + name = var.name + if name is None: + raise ValueError('Unnamed variable in ArrayOrdering.') + if name in self._by_name: + raise ValueError('Name of variable not unique: %s.' % name) + if not hasattr(var, 'dshape') or not hasattr(var, 'dsize'): + raise ValueError('Shape of variable not known %s' % name) + + slc = slice(size, size + var.dsize) + varmap = VarMap(name, slc, var.dshape, var.dtype) + self.vmap.append(varmap) + self._by_name[name] = varmap + size += var.dsize + + self.size = size - self.dimensions = dim + def __getitem__(self, key): + return self._by_name[key] class DictToArrayBijection(object): @@ -58,7 +72,7 @@ def map(self, dpt): ---------- dpt : dict """ - apt = np.empty(self.ordering.dimensions, dtype=self.array_dtype) + apt = np.empty(self.ordering.size, dtype=self.array_dtype) for var, slc, _, _ in self.ordering.vmap: apt[slc] = dpt[var].ravel() return apt @@ -125,7 +139,7 @@ def __init__(self, list_arrays, intype='numpy'): dim += array.size count += 1 - self.dimensions = dim + self.size = dim class ListToArrayBijection(object): @@ -158,7 +172,7 @@ def fmap(self, list_arrays): single array comprising all the input arrays """ - array = np.empty(self.ordering.dimensions) + array = np.empty(self.ordering.size) for list_ind, slc, _, _, _ in self.ordering.vmap: array[slc] = list_arrays[list_ind].ravel() return array diff --git a/pymc3/model.py b/pymc3/model.py index 525db36992..cd6f7b82dc 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -184,9 +184,12 @@ def fastd2logp(self, vars=None): def logpt(self): """Theano scalar of log-probability of the model""" if getattr(self, 'total_size', None) is not None: - return tt.sum(self.logp_elemwiset) * self.scaling + logp = tt.sum(self.logp_elemwiset) * self.scaling else: - return tt.sum(self.logp_elemwiset) + logp = tt.sum(self.logp_elemwiset) + if self.name is not None: + logp.name = '__logp_%s' % self.name + return logp class InitContextMeta(type): @@ -277,6 +280,173 @@ def tree_contains(self, item): return dict.__contains__(self, item) +class ValueGradFunction(object): + """Create a theano function that computes a value and its gradient. + + Parameters + ---------- + cost : theano variable + The value that we compute with its gradient. + grad_vars : list of named theano variables or None + The arguments with respect to which the gradient is computed. + extra_args : list of named theano variables or None + Other arguments of the function that are assumed constant. They + are stored in shared variables and can be set using + `set_extra_values`. + dtype : str, default=theano.config.floatX + The dtype of the arrays. + casting : {'no', 'equiv', 'save', 'same_kind', 'unsafe'}, default='no' + Casting rule for casting `grad_args` to the array dtype. + See `numpy.can_cast` for a description of the options. + Keep in mind that we cast the variables to the array *and* + back from the array dtype to the variable dtype. + kwargs + Extra arguments are passed on to `theano.function`. + + Attributes + ---------- + size : int + The number of elements in the parameter array. 
+ profile : theano profiling object or None + The profiling object of the theano function that computes value and + gradient. This is None unless `profile=True` was set in the + kwargs. + """ + def __init__(self, cost, grad_vars, extra_vars=None, dtype=None, + casting='no', **kwargs): + if extra_vars is None: + extra_vars = [] + + names = [arg.name for arg in grad_vars + extra_vars] + if any(name is None for name in names): + raise ValueError('Arguments must be named.') + if len(set(names)) != len(names): + raise ValueError('Names of the arguments are not unique.') + + if cost.ndim > 0: + raise ValueError('Cost must be a scalar.') + + self._grad_vars = grad_vars + self._extra_vars = extra_vars + self._extra_var_names = set(var.name for var in extra_vars) + self._cost = cost + self._ordering = ArrayOrdering(grad_vars) + self.size = self._ordering.size + self._extra_are_set = False + if dtype is None: + dtype = theano.config.floatX + self.dtype = dtype + for var in self._grad_vars: + if not np.can_cast(var.dtype, self.dtype, casting): + raise TypeError('Invalid dtype for variable %s. Can not ' + 'cast to %s with casting rule %s.' + % (var.name, self.dtype, casting)) + if not np.issubdtype(var.dtype, float): + raise TypeError('Invalid dtype for variable %s. Must be ' + 'floating point but is %s.' + % (var.name, var.dtype)) + + givens = [] + self._extra_vars_shared = {} + for var in extra_vars: + shared = theano.shared(var.tag.test_value, var.name + '_shared__') + self._extra_vars_shared[var.name] = shared + givens.append((var, shared)) + + self._vars_joined, self._cost_joined = self._build_joined( + self._cost, grad_vars, self._ordering.vmap) + + grad = tt.grad(self._cost_joined, self._vars_joined) + grad.name = '__grad' + + inputs = [self._vars_joined] + + self._theano_function = theano.function( + inputs, [self._cost_joined, grad], givens=givens, **kwargs) + + def set_extra_values(self, extra_vars): + self._extra_are_set = True + for var in self._extra_vars: + self._extra_vars_shared[var.name].set_value(extra_vars[var.name]) + + def get_extra_values(self): + if not self._extra_are_set: + raise ValueError('Extra values are not set.') + + return {var.name: self._extra_vars_shared[var.name].get_value() + for var in self._extra_vars} + + def __call__(self, array, grad_out=None, extra_vars=None): + if extra_vars is not None: + self.set_extra_values(extra_vars) + + if not self._extra_are_set: + raise ValueError('Extra values are not set.') + + if array.shape != (self.size,): + raise ValueError('Invalid shape for array. Must be %s but is %s.' + % ((self.size,), array.shape)) + + if grad_out is None: + out = np.empty_like(array) + else: + out = grad_out + + logp, dlogp = self._theano_function(array) + if grad_out is None: + return logp, dlogp + else: + out[...] = dlogp + return logp + + @property + def profile(self): + """Profiling information of the underlying theano function.""" + return self._theano_function.profile + + def dict_to_array(self, point): + """Convert a dictionary with values for grad_vars to an array.""" + array = np.empty(self.size, dtype=self.dtype) + for varmap in self._ordering.vmap: + array[varmap.slc] = point[varmap.var].ravel().astype(self.dtype) + return array + + def array_to_dict(self, array): + """Convert an array to a dictionary containing the grad_vars.""" + if array.shape != (self.size,): + raise ValueError('Array should have shape (%s,) but has %s' + % (self.size, array.shape)) + if array.dtype != self.dtype: + raise ValueError('Array has invalid dtype. 
Should be %s but is %s' + % (self._dtype, self.dtype)) + point = {} + for varmap in self._ordering.vmap: + data = array[varmap.slc].reshape(varmap.shp) + point[varmap.var] = data.astype(varmap.dtyp) + + return point + + def array_to_full_dict(self, array): + """Convert an array to a dictionary with grad_vars and extra_vars.""" + point = self.array_to_dict(array) + for name, var in self._extra_vars_shared.items(): + point[name] = var.get_value() + return point + + def _build_joined(self, cost, args, vmap): + args_joined = tt.vector('__args_joined') + args_joined.tag.test_value = np.zeros(self.size, dtype=self.dtype) + + joined_slices = {} + for vmap in vmap: + sliced = args_joined[vmap.slc].reshape(vmap.shp) + sliced.name = vmap.var + joined_slices[vmap.var] = sliced + + replace = {var: joined_slices[var.name] for var in args} + return args_joined, theano.clone(cost, replace=replace) + + class Model(six.with_metaclass(InitContextMeta, Context, Factor)): """Encapsulates the variables and likelihood factors of a model. @@ -419,7 +589,6 @@ def bijection(self): return bij @property - @memoize def dict_to_array(self): return self.bijection.map @@ -428,23 +597,34 @@ def ndim(self): return sum(var.dsize for var in self.free_RVs) @property - @memoize def logp_array(self): return self.bijection.mapf(self.fastlogp) @property - @memoize def dlogp_array(self): vars = inputvars(self.cont_vars) return self.bijection.mapf(self.fastdlogp(vars)) + def logp_dlogp_function(self, grad_vars=None, **kwargs): + if grad_vars is None: + grad_vars = list(typefilter(self.free_RVs, continuous_types)) + else: + for var in grad_vars: + if var.dtype not in continuous_types: + raise ValueError("Can only compute the gradient of " + "continuous types: %s" % var) + varnames = [var.name for var in grad_vars] + extra_vars = [var for var in self.free_RVs if var.name not in varnames] + return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs) + @property - @memoize def logpt(self): """Theano scalar of log-probability of the model""" with self: factors = [var.logpt for var in self.basic_RVs] + self.potentials - return tt.add(*map(tt.sum, factors)) + logp = tt.add(*map(tt.sum, factors)) + logp.name = '__logp' + return logp @property def varlogpt(self): @@ -595,7 +775,6 @@ def __getitem__(self, key): except KeyError: raise e - @memoize def makefn(self, outs, mode=None, *args, **kwargs): """Compiles a Theano function which returns `outs` and takes the variable ancestors of `outs` as inputs. diff --git a/pymc3/step_methods/arraystep.py b/pymc3/step_methods/arraystep.py index 69c0a7aa6e..aead81647f 100644 --- a/pymc3/step_methods/arraystep.py +++ b/pymc3/step_methods/arraystep.py @@ -156,6 +156,30 @@ def step(self, point): return bij.rmap(apoint) +class GradientSharedStep(BlockedStep): + def __init__(self, vars, model=None, blocked=True, + dtype=None, **theano_kwargs): + model = modelcontext(model) + self.vars = vars + self.blocked = blocked + + self._logp_dlogp_func = model.logp_dlogp_function( + vars, dtype=dtype, **theano_kwargs) + + def step(self, point): + self._logp_dlogp_func.set_extra_values(point) + array = self._logp_dlogp_func.dict_to_array(point) + + if self.generates_stats: + apoint, stats = self.astep(array) + point = self._logp_dlogp_func.array_to_full_dict(apoint) + return point, stats + else: + apoint = self.astep(array) + point = self._logp_dlogp_func.array_to_full_dict(apoint) + return point + + def metrop_select(mr, q, q0): """Perform rejection/acceptance step for Metropolis class samplers. 
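A minimal usage sketch (illustrative only, not part of the patch) of the `ValueGradFunction` / `Model.logp_dlogp_function` API added above, assuming a toy model with two free variables `a` and `b` (the model and names are made up for the example):

import pymc3 as pm

with pm.Model() as model:
    a = pm.Normal('a', mu=0., sd=1.)
    b = pm.Normal('b', mu=0., sd=1., shape=3)

func = model.logp_dlogp_function([a, b])   # ValueGradFunction over both free RVs
func.set_extra_values({})                  # no extra (non-gradient) variables here
x = func.dict_to_array(model.test_point)   # flatten {'a': ..., 'b': ...} into one array
logp, dlogp = func(x)                      # joint logp and its gradient in one call
point = func.array_to_dict(x)              # and back to a {name: value} dict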
diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py index cbc702b1f9..23c0a8d67f 100644 --- a/pymc3/step_methods/hmc/base_hmc.py +++ b/pymc3/step_methods/hmc/base_hmc.py @@ -1,35 +1,33 @@ -from ..arraystep import ArrayStepShared -from .trajectory import get_theano_hamiltonian_functions - -from pymc3.tuning import guess_scaling from pymc3.model import modelcontext, Point +from pymc3.step_methods import arraystep from .quadpotential import quad_potential, QuadPotentialDiagAdapt -from pymc3.theanof import inputvars, make_shared_replacements, floatX +from pymc3.step_methods.hmc import integration +from pymc3.theanof import inputvars, floatX +from pymc3.tuning import guess_scaling import numpy as np -class BaseHMC(ArrayStepShared): +class BaseHMC(arraystep.GradientSharedStep): default_blocked = True def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, - model=None, blocked=True, use_single_leapfrog=False, - potential=None, integrator="leapfrog", **theano_kwargs): + model=None, blocked=True, potential=None, + integrator="leapfrog", dtype=None, **theano_kwargs): """Superclass to implement Hamiltonian/hybrid monte carlo Parameters ---------- vars : list of theano variables scaling : array_like, ndim = {1,2} - Scaling for momentum distribution. 1d arrays interpreted matrix diagonal. + Scaling for momentum distribution. 1d arrays interpreted matrix + diagonal. step_scale : float, default=0.25 Size of steps to take, automatically scaled down by 1/n**(1/4) is_cov : bool, default=False - Treat scaling as a covariance matrix/vector if True, else treat it as a - precision matrix/vector - model : pymc3 Model instance. default=Context model - blocked: Boolean, default True - use_single_leapfrog: Boolean, will leapfrog steps take a single step at a time. - default False. + Treat scaling as a covariance matrix/vector if True, else treat + it as a precision matrix/vector + model : pymc3 Model instance + blocked: bool, default=True potential : Potential, optional An object that represents the Hamiltonian with methods `velocity`, `energy`, and `random` methods. 
@@ -41,8 +39,12 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, vars = model.cont_vars vars = inputvars(vars) + super(BaseHMC, self).__init__(vars, blocked=blocked, model=model, + dtype=dtype, **theano_kwargs) + + size = self._logp_dlogp_func.size + if scaling is None and potential is None: - size = sum(np.prod(var.dshape, dtype=int) for var in vars) mean = floatX(np.zeros(size)) var = floatX(np.ones(size)) potential = QuadPotentialDiagAdapt(size, mean, var, 10) @@ -54,17 +56,11 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, if scaling is not None and potential is not None: raise ValueError("Can not specify both potential and scaling.") - self.step_size = step_scale / (model.ndim ** 0.25) + self.step_size = step_scale / (size ** 0.25) if potential is not None: self.potential = potential else: self.potential = quad_potential(scaling, is_cov) - shared = make_shared_replacements(vars, model) - if theano_kwargs is None: - theano_kwargs = {} - - self.H, self.compute_energy, self.compute_velocity, self.leapfrog, self.dlogp = get_theano_hamiltonian_functions( - vars, shared, model.logpt, self.potential, use_single_leapfrog, integrator, **theano_kwargs) - - super(BaseHMC, self).__init__(vars, shared, blocked=blocked) + self.integrator = integration.CpuLeapfrogIntegrator( + size, self.potential, self._logp_dlogp_func) diff --git a/pymc3/step_methods/hmc/hmc.py b/pymc3/step_methods/hmc/hmc.py index 05d7cb48f7..3d46cee771 100644 --- a/pymc3/step_methods/hmc/hmc.py +++ b/pymc3/step_methods/hmc/hmc.py @@ -58,13 +58,21 @@ def __init__(self, vars=None, path_length=2., step_rand=unif, **kwargs): def astep(self, q0): e = floatX(self.step_rand(self.step_size)) - n_steps = np.array(self.path_length / e, dtype='int32') - q = q0 - p = self.H.pot.random() # initialize momentum - initial_energy = self.compute_energy(q, p) - q, p, current_energy = self.leapfrog(q, p, e, n_steps) - energy_change = initial_energy - current_energy - return metrop_select(energy_change, q, q0)[0] + n_steps = int(self.path_length / e) + + p0 = self.potential.random() + start = self.integrator.compute_state(q0, p0) + + if not np.isfinite(start.energy): + raise ValueError('Bad initial energy: %s. The model ' + 'might be misspecified.' % start.energy) + + state = start + for _ in range(n_steps): + state = self.integrator.step(e, state) + + energy_change = start.energy - state.energy + return metrop_select(energy_change, state.q, start.q)[0] @staticmethod def competence(var): diff --git a/pymc3/step_methods/hmc/integration.py b/pymc3/step_methods/hmc/integration.py new file mode 100644 index 0000000000..04605071b1 --- /dev/null +++ b/pymc3/step_methods/hmc/integration.py @@ -0,0 +1,68 @@ +from collections import namedtuple + +import numpy as np +from scipy import linalg + + +State = namedtuple("State", 'q, p, v, q_grad, energy') + + +class CpuLeapfrogIntegrator(object): + def __init__(self, ndim, potential, logp_dlogp_func): + self._ndim = ndim + self._potential = potential + self._logp_dlogp_func = logp_dlogp_func + self._dtype = self._logp_dlogp_func.dtype + if self._potential.dtype != self._dtype: + raise ValueError("dtypes of potential and logp function " + "don't match." + % (self._potential.dtype, self._dtype)) + + def compute_state(self, q, p): + if q.dtype != self._dtype or p.dtype != self._dtype: + raise ValueError('Invalid dtype. 
Must be %s' % self._dtype) + logp, dlogp = self._logp_dlogp_func(q) + v = self._potential.velocity(p) + kinetic = self._potential.energy(p, velocity=v) + energy = kinetic - logp + return State(q, p, v, dlogp, energy) + + def step(self, epsilon, state, out=None): + pot = self._potential + axpy = linalg.blas.get_blas_funcs('axpy', dtype=self._dtype) + + q, p, v, q_grad, energy = state + if out is None: + q_new = q.copy() + p_new = p.copy() + v_new = np.empty_like(q) + q_new_grad = np.empty_like(q) + else: + q_new, p_new, v_new, q_new_grad, energy = out + q_new[:] = q + p_new[:] = p + + dt = 0.5 * epsilon + + # p is already stored in p_new + # p_new = p + dt * q_grad + axpy(q_grad, p_new, a=dt) + + pot.velocity(p_new, out=v_new) + # q is already stored in q_new + # q_new = q + epsilon * v_new + axpy(v_new, q_new, a=epsilon) + + logp = self._logp_dlogp_func(q_new, q_new_grad) + + # p_new = p_new + dt * q_new_grad + axpy(q_new_grad, p_new, a=dt) + + kinetic = pot.velocity_energy(p_new, v_new) + energy = kinetic - logp + + if out is not None: + out.energy = energy + return + else: + return State(q_new, p_new, v_new, q_new_grad, energy) diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py index b2499fa8b3..73af656c8a 100644 --- a/pymc3/step_methods/hmc/nuts.py +++ b/pymc3/step_methods/hmc/nuts.py @@ -90,7 +90,7 @@ class NUTS(BaseHMC): def __init__(self, vars=None, Emax=1000, target_accept=0.8, gamma=0.05, k=0.75, t0=10, adapt_step_size=True, max_treedepth=10, on_error='summary', - adapt_mass_matrix=True, early_max_treedepth=8, + early_max_treedepth=8, **kwargs): R""" Parameters @@ -115,14 +115,11 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8, adapt_step_size : bool, default=True Whether step size adaptation should be enabled. If this is disabled, `k`, `t0`, `gamma` and `target_accept` are ignored. - adapt_mass_matrix : bool, default=True - Whether to adapt the mass matrix during tuning if the - potential supports tuning. max_treedepth : int, default=10 The maximum tree depth. Trajectories are stoped when this depth is reached. early_max_treedepth : int, default=8 - The maximum tree depth during tuning. + The maximum tree depth during the first 200 tuning samples. integrator : str, default "leapfrog" The integrator to use for the trajectories. One of "leapfrog", "two-stage" or "three-stage". The second two can increase @@ -153,8 +150,7 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8, This is usually achieved by setting the `tune` parameter if `pm.sample` to the desired number of tuning steps. """ - super(NUTS, self).__init__(vars, use_single_leapfrog=True, **kwargs) - + super(NUTS, self).__init__(vars, **kwargs) self.Emax = Emax self.target_accept = target_accept @@ -168,7 +164,6 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8, self.log_step_size_bar = 0 self.m = 1 self.adapt_step_size = adapt_step_size - self.adapt_mass_matrix = adapt_mass_matrix self.max_treedepth = max_treedepth self.early_max_treedepth = early_max_treedepth @@ -177,11 +172,11 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8, def astep(self, q0): p0 = self.potential.random() - v0 = self.compute_velocity(p0) - start_energy = self.compute_energy(q0, p0) - if not np.all(np.isfinite(start_energy)): + start = self.integrator.compute_state(q0, p0) + + if not np.isfinite(start.energy): raise ValueError('Bad initial energy: %s. The model ' - 'might be misspecified.' % start_energy) + 'might be misspecified.' 
% start.energy) if not self.adapt_step_size: step_size = self.step_size @@ -190,13 +185,12 @@ def astep(self, q0): else: step_size = np.exp(self.log_step_size_bar) - if self.tune: + if self.tune and self.m < 200: max_treedepth = self.early_max_treedepth else: max_treedepth = self.max_treedepth - start = Edge(q0, p0, v0, self.dlogp(q0), start_energy) - tree = _Tree(len(p0), self.leapfrog, start, step_size, self.Emax) + tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax) for _ in range(max_treedepth): direction = logbern(np.log(0.5)) * 2 - 1 @@ -219,7 +213,7 @@ def astep(self, q0): self.m += 1 - if self.tune and self.adapt_mass_matrix: + if self.tune: self.potential.adapt(q, q_grad) stats = { @@ -240,9 +234,6 @@ def competence(var): return Competence.INCOMPATIBLE -# A node in the NUTS tree that is at the far right or left of the tree -Edge = namedtuple("Edge", 'q, p, v, q_grad, energy') - # A proposal for the next position Proposal = namedtuple("Proposal", "q, q_grad, energy, p_accept") @@ -253,14 +244,14 @@ def competence(var): class _Tree(object): - def __init__(self, ndim, leapfrog, start, step_size, Emax): + def __init__(self, ndim, integrator, start, step_size, Emax): """Binary tree from the NUTS algorithm. Parameters ---------- leapfrog : function A function that performs a single leapfrog step. - start : Edge + start : integration.State The starting point of the trajectory. step_size : float The step size to use in this tree @@ -269,7 +260,7 @@ def __init__(self, ndim, leapfrog, start, step_size, Emax): transition as diverging. """ self.ndim = ndim - self.leapfrog = leapfrog + self.integrator = integrator self.start = start self.step_size = step_size self.Emax = Emax @@ -328,7 +319,7 @@ def extend(self, direction): def _single_step(self, left, epsilon): """Perform a leapfrog step and handle error cases.""" try: - right = self.leapfrog(left.q, left.p, left.q_grad, epsilon) + right = self.integrator.step(epsilon, left) except linalg.LinAlgError as err: error_msg = "LinAlgError during leapfrog step." error = err @@ -341,7 +332,6 @@ def _single_step(self, left, epsilon): else: raise else: - right = Edge(*right) energy_change = right.energy - self.start_energy if np.isnan(energy_change): energy_change = np.inf diff --git a/pymc3/step_methods/hmc/quadpotential.py b/pymc3/step_methods/hmc/quadpotential.py index 5e8802b7cd..2e01ccb9b9 100644 --- a/pymc3/step_methods/hmc/quadpotential.py +++ b/pymc3/step_methods/hmc/quadpotential.py @@ -2,8 +2,6 @@ from numpy.random import normal import scipy.linalg from scipy.sparse import issparse -import theano.tensor as tt -from theano.tensor import slinalg import theano from pymc3.theanof import floatX @@ -73,15 +71,18 @@ def __str__(self): class QuadPotential(object): - def velocity(self, x): + def velocity(self, x, out=None): raise NotImplementedError('Abstract method') - def energy(self, x): + def energy(self, x, velocity=None): raise NotImplementedError('Abstract method') def random(self, x): raise NotImplementedError('Abstract method') + def velocity_energy(self, x, v_out): + raise NotImplementedError('Abstract method') + def adapt(self, sample, grad): """Inform the potential about a new sample during tuning. 
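Aside (illustrative, not part of the patch): under the numpy-based interface sketched above, a potential only has to provide the velocity v = M^-1 p, the kinetic energy 0.5 * p^T M^-1 p, and a momentum draw p ~ N(0, M). A minimal diagonal implementation along those lines might look like this; the class name and constructor are assumptions for the example, not code from this PR:

import numpy as np


class DiagPotentialSketch(object):
    """Illustrative diagonal potential following the velocity/energy/velocity_energy API."""

    def __init__(self, inv_mass_diag, dtype='float64'):
        self.dtype = dtype
        self._v = np.asarray(inv_mass_diag, dtype=dtype)  # diagonal of M^{-1}
        self._std = np.sqrt(1.0 / self._v)                # stds of the momentum draw

    def velocity(self, p, out=None):
        # v = M^{-1} p
        return np.multiply(self._v, p, out=out)

    def energy(self, p, velocity=None):
        # kinetic energy 0.5 * p^T M^{-1} p, reusing a precomputed velocity if given
        if velocity is None:
            velocity = self.velocity(p)
        return 0.5 * np.dot(p, velocity)

    def velocity_energy(self, p, v_out):
        # compute the velocity in-place and reuse it for the energy
        self.velocity(p, out=v_out)
        return 0.5 * np.dot(p, v_out)

    def random(self):
        # draw p ~ N(0, M)
        return (np.random.normal(size=self._v.shape) * self._std).astype(self.dtype)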
@@ -116,26 +117,32 @@ def __init__(self, n, initial_mean, initial_diag=None, initial_weight=0, if dtype is None: dtype = theano.config.floatX - self._dtype = dtype + self.dtype = dtype self._n = n - self._var = np.array(initial_diag, dtype=self._dtype, copy=True) + self._var = np.array(initial_diag, dtype=self.dtype, copy=True) self._var_theano = theano.shared(self._var) self._stds = np.sqrt(initial_diag) self._inv_stds = floatX(1.) / self._stds self._foreground_var = _WeightedVariance( - self._n, initial_mean, initial_diag, initial_weight, self._dtype) - self._background_var = _WeightedVariance(self._n, dtype=self._dtype) + self._n, initial_mean, initial_diag, initial_weight, self.dtype) + self._background_var = _WeightedVariance(self._n, dtype=self.dtype) self._n_samples = 0 self.adaptation_window = adaptation_window - def velocity(self, x): - return self._var_theano * x + def velocity(self, x, out=None): + return np.multiply(self._var, x, out=out) + + def energy(self, x, velocity=None): + if velocity is not None: + return 0.5 * x.dot(velocity) + return 0.5 * x.dot(self._var * x) - def energy(self, x): - return 0.5 * x.dot(self._var_theano * x) + def velocity_energy(self, x, v_out): + self.velocity(x, out=v_out) + return 0.5 * np.dot(x, v_out) def random(self): - vals = floatX(normal(size=self._n)) + vals = normal(size=self._n).astype(self.dtype) return self._inv_stds * vals def _update_from_weightvar(self, weightvar): @@ -153,7 +160,7 @@ def adapt(self, sample, grad): if self._n_samples > 0 and self._n_samples % window == 0: self._foreground_var = self._background_var - self._background_var = _WeightedVariance(self._n, dtype=self._dtype) + self._background_var = _WeightedVariance(self._n, dtype=self.dtype) self._n_samples += 1 @@ -165,9 +172,9 @@ class QuadPotentialDiagAdaptGrad(QuadPotentialDiagAdapt): """ def __init__(self, *args, **kwargs): super(QuadPotentialDiagAdaptGrad, self).__init__(*args, **kwargs) - self._grads1 = np.zeros(self._n) + self._grads1 = np.zeros(self._n, dtype=self.dtype) self._ngrads1 = 0 - self._grads2 = np.zeros(self._n) + self._grads2 = np.zeros(self._n, dtype=self.dtype) self._ngrads2 = 0 def _update(self, var): @@ -241,58 +248,88 @@ def current_mean(self): class QuadPotentialDiag(QuadPotential): - def __init__(self, v): - v = floatX(v) + def __init__(self, v, dtype=None): + if dtype is None: + dtype = theano.config.floatX + self.dtype = dtype + v = v.astype(self.dtype) s = v ** .5 self.s = s self.inv_s = 1. 
/ s self.v = v - def velocity(self, x): + def velocity(self, x, out=None): + if out is not None: + np.multiply(x, self.v, out=out) + return return self.v * x def random(self): return floatX(normal(size=self.s.shape)) * self.inv_s - def energy(self, x): + def energy(self, x, velocity=None): + if velocity is not None: + return 0.5 * np.dot(x, velocity) return .5 * x.dot(self.v * x) + def velocity_energy(self, x, v_out): + np.multiply(x, self.v, out=v_out) + return 0.5 * np.dot(x, v_out) + class QuadPotentialFullInv(QuadPotential): - def __init__(self, A): + def __init__(self, A, dtype=None): + if dtype is None: + dtype = theano.config.floatX + self.dtype = dtype self.L = floatX(scipy.linalg.cholesky(A, lower=True)) - def velocity(self, x): - solve = slinalg.Solve(lower=True) - y = solve(self.L, x) - return solve(self.L.T, y) + def velocity(self, x, out=None): + vel = scipy.linalg.cho_solve((self.L, True), x) + if out is None: + return vel + out[:] = vel def random(self): n = floatX(normal(size=self.L.shape[0])) return np.dot(self.L, n) - def energy(self, x): - L1x = slinalg.Solve(lower=True)(self.L, x) - return .5 * L1x.T.dot(L1x) + def energy(self, x, velocity=None): + if velocity is None: + velocity = self.velocity(x) + return .5 * x.dot(velocity) + + def velocity_energy(self, x, v_out): + self.velocity(x, out=v_out) + return 0.5 * np.dot(x, v_out) class QuadPotentialFull(QuadPotential): - def __init__(self, A): - self.A = floatX(A) + def __init__(self, A, dtype=None): + if dtype is None: + dtype = theano.config.floatX + self.dtype = dtype + self.A = A.astype(self.dtype) self.L = scipy.linalg.cholesky(A, lower=True) - def velocity(self, x): - return tt.dot(self.A, x) + def velocity(self, x, out=None): + return np.dot(self.A, x, out=out) def random(self): n = floatX(normal(size=self.L.shape[0])) return scipy.linalg.solve_triangular(self.L.T, n) - def energy(self, x): - return .5 * x.dot(self.A).dot(x) + def energy(self, x, velocity=None): + if velocity is None: + velocity = self.velocity(x) + return .5 * x.dot(velocity) + + def velocity_energy(self, x, v_out): + self.velocity(x, out=v_out) + return 0.5 * np.dot(x, v_out) __call__ = random diff --git a/pymc3/step_methods/hmc/trajectory.py b/pymc3/step_methods/hmc/trajectory.py deleted file mode 100644 index a7c730a9c5..0000000000 --- a/pymc3/step_methods/hmc/trajectory.py +++ /dev/null @@ -1,299 +0,0 @@ -from collections import namedtuple - -from pymc3.theanof import join_nonshared_inputs, gradient, CallableTensor, floatX - -import theano -import theano.tensor as tt -import numpy as np - - -Hamiltonian = namedtuple("Hamiltonian", "logp, dlogp, pot") - - -def _theano_hamiltonian(model_vars, shared, logpt, potential): - """Creates a Hamiltonian with shared inputs. - - Parameters - ---------- - model_vars : array of variables to be sampled - shared : theano tensors that are already shared - logpt : model log probability - potential : hamiltonian potential - - Returns - ------- - Hamiltonian : namedtuple with log pdf, gradient of log pdf, and potential functions - q : Starting position variable. - """ - dlogp = gradient(logpt, model_vars) - (logp, dlogp), q = join_nonshared_inputs([logpt, dlogp], model_vars, shared) - dlogp_func = theano.function(inputs=[q], outputs=dlogp) - dlogp_func.trust_input = True - logp = CallableTensor(logp) - dlogp = CallableTensor(dlogp) - return Hamiltonian(logp, dlogp, potential), q, dlogp_func - - -def _theano_energy_function(H, q, **theano_kwargs): - """Creates a Hamiltonian with shared inputs. 
- - Parameters - ---------- - H : Hamiltonian namedtuple - q : theano variable, starting position - theano_kwargs : passed to theano.function - - Returns - ------- - energy_function : theano function that computes the energy at a point (p, q) in phase space - p : Starting momentum variable. - """ - p = tt.vector('p') - p.tag.test_value = q.tag.test_value - total_energy = H.pot.energy(p) - H.logp(q) - energy_function = theano.function(inputs=[q, p], outputs=total_energy, **theano_kwargs) - energy_function.trust_input = True - - return energy_function, p - - -def _theano_velocity_function(H, p, **theano_kwargs): - v = H.pot.velocity(p) - velocity_function = theano.function(inputs=[p], outputs=v, **theano_kwargs) - velocity_function.trust_input = True - return velocity_function - - -def _theano_leapfrog_integrator(H, q, p, **theano_kwargs): - """Computes a theano function that computes one leapfrog step and the energy at the - end of the trajectory. - - Parameters - ---------- - H : Hamiltonian - q : theano.tensor - p : theano.tensor - theano_kwargs : passed to theano.function - - Returns - ------- - theano function which returns - q_new, p_new, energy_new - """ - epsilon = tt.scalar('epsilon') - epsilon.tag.test_value = 1. - - n_steps = tt.iscalar('n_steps') - n_steps.tag.test_value = 2 - - q_new, p_new = leapfrog(H, q, p, epsilon, n_steps) - energy_new = energy(H, q_new, p_new) - - f = theano.function([q, p, epsilon, n_steps], [q_new, p_new, energy_new], **theano_kwargs) - f.trust_input = True - return f - - -def get_theano_hamiltonian_functions(model_vars, shared, logpt, potential, - use_single_leapfrog=False, - integrator="leapfrog", **theano_kwargs): - """Construct theano functions for the Hamiltonian, energy, and leapfrog integrator. - - Parameters - ---------- - model_vars : array of variables to be sampled - shared : theano tensors that are already shared - logpt : model log probability - potential : Hamiltonian potential - theano_kwargs : dictionary of keyword arguments to pass to theano functions - use_single_leapfrog : bool - if only 1 integration step is done at a time (as in NUTS), this - provides a ~2x speedup - integrator : str - Integration scheme to use. One of "leapfog", "two-stage", or - "three-stage". - - Returns - ------- - H : Hamiltonian namedtuple - energy_function : theano function computing energy at a point in phase space - leapfrog_integrator : theano function integrating the Hamiltonian from a point in phase space - theano_variables : dictionary of variables used in the computation graph which may be useful - """ - H, q, dlogp = _theano_hamiltonian(model_vars, shared, logpt, potential) - energy_function, p = _theano_energy_function(H, q, **theano_kwargs) - velocity_function = _theano_velocity_function(H, p, **theano_kwargs) - if use_single_leapfrog: - try: - _theano_integrator = INTEGRATORS_SINGLE[integrator] - except KeyError: - raise ValueError("Unknown integrator: %s" % integrator) - integrator = _theano_integrator(H, q, p, H.dlogp(q), **theano_kwargs) - else: - if integrator != "leapfrog": - raise ValueError("Only leapfrog is supported") - integrator = _theano_leapfrog_integrator(H, q, p, **theano_kwargs) - return H, energy_function, velocity_function, integrator, dlogp - - -def energy(H, q, p): - """Compute the total energy for the Hamiltonian at a given position/momentum""" - return H.pot.energy(p) - H.logp(q) - - -def leapfrog(H, q, p, epsilon, n_steps): - """Leapfrog integrator. 
- - Estimates `p(t)` and `q(t)` at time :math:`t = n \cdot e`, by integrating the - Hamiltonian equations - - .. math:: - - \frac{dq_i}{dt} = \frac{\partial H}{\partial p_i} - - \frac{dp_i}{dt} = \frac{\partial H}{\partial q_i} - - with :math:`p(0) = p`, :math:`q(0) = q` - - Parameters - ---------- - H : Hamiltonian instance. - Tuple of `logp, dlogp, potential`. - q : Theano.tensor - initial position vector - p : Theano.tensor - initial momentum vector - epsilon : float, step size - n_steps : int, number of iterations - - Returns - ------- - position : Theano.tensor - position estimate at time :math:`n \cdot e`. - momentum : Theano.tensor - momentum estimate at time :math:`n \cdot e`. - """ - def full_update(p, q): - p = p + epsilon * H.dlogp(q) - q += epsilon * H.pot.velocity(p) - return p, q - # This first line can't be +=, possibly because of theano - p = p + 0.5 * epsilon * H.dlogp(q) # half momentum update - q += epsilon * H.pot.velocity(p) # full position update - if tt.gt(n_steps, 1): - (p_seq, q_seq), _ = theano.scan(full_update, outputs_info=[p, q], n_steps=n_steps - 1) - p, q = p_seq[-1], q_seq[-1] - p += 0.5 * epsilon * H.dlogp(q) # half momentum update - return q, p - - -def _theano_single_threestage(H, q, p, q_grad, **theano_kwargs): - """Perform a single step of a third order symplectic integration scheme. - - References - ---------- - Blanes, Sergio, Fernando Casas, and J. M. Sanz-Serna. "Numerical - Integrators for the Hybrid Monte Carlo Method." SIAM Journal on - Scientific Computing 36, no. 4 (January 2014): A1556-80. - doi:10.1137/130932740. - - Mannseth, Janne, Tore Selland Kleppe, and Hans J. Skaug. "On the - Application of Higher Order Symplectic Integrators in - Hamiltonian Monte Carlo." arXiv:1608.07048 [Stat], - August 25, 2016. http://arxiv.org/abs/1608.07048. - """ - epsilon = tt.scalar('epsilon') - epsilon.tag.test_value = 1. - - a = 12127897.0 / 102017882 - b = 4271554.0 / 14421423 - - # q_{a\epsilon} - p_ae = p + floatX(a) * epsilon * q_grad - q_be = q + floatX(b) * epsilon * H.pot.velocity(p_ae) - - # q_{\epsilon / 2} - p_e2 = p_ae + floatX(0.5 - a) * epsilon * H.dlogp(q_be) - - # p_{(1-b)\epsilon} - q_1be = q_be + floatX(1 - 2 * b) * epsilon * H.pot.velocity(p_e2) - p_1ae = p_e2 + floatX(0.5 - a) * epsilon * H.dlogp(q_1be) - - q_e = q_1be + floatX(b) * epsilon * H.pot.velocity(p_1ae) - grad_e = H.dlogp(q_e) - p_e = p_1ae + floatX(a) * epsilon * grad_e - v_e = H.pot.velocity(p_e) - - new_energy = energy(H, q_e, p_e) - - f = theano.function(inputs=[q, p, q_grad, epsilon], - outputs=[q_e, p_e, v_e, grad_e, new_energy], - **theano_kwargs) - f.trust_input = True - return f - - -def _theano_single_twostage(H, q, p, q_grad, **theano_kwargs): - """Perform a single step of a second order symplectic integration scheme. - - References - ---------- - Blanes, Sergio, Fernando Casas, and J. M. Sanz-Serna. "Numerical - Integrators for the Hybrid Monte Carlo Method." SIAM Journal on - Scientific Computing 36, no. 4 (January 2014): A1556-80. - doi:10.1137/130932740. - - Mannseth, Janne, Tore Selland Kleppe, and Hans J. Skaug. "On the - Application of Higher Order Symplectic Integrators in - Hamiltonian Monte Carlo." arXiv:1608.07048 [Stat], - August 25, 2016. http://arxiv.org/abs/1608.07048. - """ - epsilon = tt.scalar('epsilon') - epsilon.tag.test_value = 1. 
- - a = floatX((3 - np.sqrt(3)) / 6) - - p_ae = p + a * epsilon * q_grad - q_e2 = q + epsilon / 2 * H.pot.velocity(p_ae) - p_1ae = p_ae + (1 - 2 * a) * epsilon * H.dlogp(q_e2) - q_e = q_e2 + epsilon / 2 * H.pot.velocity(p_1ae) - grad_e = H.dlogp(q_e) - p_e = p_1ae + a * epsilon * grad_e - v_e = H.pot.velocity(p_e) - - new_energy = energy(H, q_e, p_e) - f = theano.function(inputs=[q, p, q_grad, epsilon], - outputs=[q_e, p_e, v_e, grad_e, new_energy], - **theano_kwargs) - f.trust_input = True - return f - - -def _theano_single_leapfrog(H, q, p, q_grad, **theano_kwargs): - """Leapfrog integrator for a single step. - - See above for documentation. This is optimized for the case where only a single step is - needed, in case of, for example, a recursive algorithm. - """ - epsilon = tt.scalar('epsilon') - epsilon.tag.test_value = 1. - - p_new = p + 0.5 * epsilon * q_grad # half momentum update - q_new = q + epsilon * H.pot.velocity(p_new) # full position update - q_new_grad = H.dlogp(q_new) - p_new += 0.5 * epsilon * q_new_grad # half momentum update - energy_new = energy(H, q_new, p_new) - v_new = H.pot.velocity(p_new) - - f = theano.function(inputs=[q, p, q_grad, epsilon], - outputs=[q_new, p_new, v_new, q_new_grad, energy_new], - **theano_kwargs) - f.trust_input = True - return f - - -INTEGRATORS_SINGLE = { - 'leapfrog': _theano_single_leapfrog, - 'two-stage': _theano_single_twostage, - 'three-stage': _theano_single_threestage, -} diff --git a/pymc3/step_methods/smc.py b/pymc3/step_methods/smc.py index 0c39d4177b..54275925c3 100644 --- a/pymc3/step_methods/smc.py +++ b/pymc3/step_methods/smc.py @@ -324,7 +324,7 @@ def select_end_points(self, mtrace): likelihoods : :class:`numpy.ndarray` Array of likelihoods of the trace end-points """ - array_population = np.zeros((self.n_chains, self.ordering.dimensions)) + array_population = np.zeros((self.n_chains, self.ordering.size)) n_steps = len(mtrace) # collect end points of each chain and put into array @@ -357,7 +357,7 @@ def get_chain_previous_lpoint(self, mtrace): chain_previous_lpoint : list all unobservedRV values, including dataset likelihoods """ - array_population = np.zeros((self.n_chains, self.lordering.dimensions)) + array_population = np.zeros((self.n_chains, self.lordering.size)) n_steps = len(mtrace) for _, slc, shp, _, var in self.lordering.vmap: slc_population = mtrace.get_values(varname=var, burn=n_steps - 1, combine=True) diff --git a/pymc3/tests/models.py b/pymc3/tests/models.py index 5e91711595..a545424d62 100644 --- a/pymc3/tests/models.py +++ b/pymc3/tests/models.py @@ -69,6 +69,16 @@ def simple_2model(): return model.test_point, model +def simple_2model_continuous(): + mu = -2.1 + tau = 1.3 + with Model() as model: + x = pm.Normal('x', mu, tau=tau, testval=.1) + pm.Deterministic('logx', tt.log(x)) + pm.Beta('y', alpha=1, beta=1, shape=2) + return model.test_point, model + + def mv_simple(): mu = floatX_array([-.1, .5, 1.1]) p = floatX_array([ diff --git a/pymc3/tests/test_hmc.py b/pymc3/tests/test_hmc.py index fcdc8cd176..72888bb4f1 100644 --- a/pymc3/tests/test_hmc.py +++ b/pymc3/tests/test_hmc.py @@ -1,62 +1,32 @@ import numpy as np +import numpy.testing as npt -from pymc3.blocking import DictToArrayBijection from . 
import models from pymc3.step_methods.hmc.base_hmc import BaseHMC import pymc3 from pymc3.theanof import floatX -from .checks import close_to -from .helpers import select_by_precision -import pytest -import theano def test_leapfrog_reversible(): n = 3 + np.random.seed(42) start, model, _ = models.non_normal(n) - step = BaseHMC(vars=model.vars, model=model) - bij = DictToArrayBijection(step.ordering, start) - q0 = bij.map(start) - p0 = floatX(np.ones(n) * .05) - precision = select_by_precision(float64=1E-8, float32=1E-4) - for epsilon in [.01, .1, 1.2]: + size = model.ndim + scaling = floatX(np.random.rand(size)) + step = BaseHMC(vars=model.vars, model=model, scaling=scaling) + step.integrator._logp_dlogp_func.set_extra_values({}) + p = floatX(step.potential.random()) + q = floatX(np.random.randn(size)) + start = step.integrator.compute_state(p, q) + for epsilon in [.01, .1]: for n_steps in [1, 2, 3, 4, 20]: - - q, p = q0, p0 - q, p, _ = step.leapfrog(q, p, floatX(np.array(epsilon)), np.array(n_steps, dtype='int32')) - q, p, _ = step.leapfrog(q, -p, floatX(np.array(epsilon)), np.array(n_steps, dtype='int32')) - close_to(q, q0, precision, str((n_steps, epsilon))) - close_to(-p, p0, precision, str((n_steps, epsilon))) - -@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32") -def test_leapfrog_reversible_single(): - n = 3 - start, model, _ = models.non_normal(n) - - integrators = ['leapfrog', 'two-stage', 'three-stage'] - steps = [BaseHMC(vars=model.vars, model=model, integrator=method, use_single_leapfrog=True) - for method in integrators] - for method, step in zip(integrators, steps): - bij = DictToArrayBijection(step.ordering, start) - q0 = bij.map(start) - p0 = floatX(np.ones(n) * .05) - precision = select_by_precision(float64=1E-8, float32=1E-5) - for epsilon in [0.01, 0.1, 1.2]: - for n_steps in [1, 2, 3, 4, 20]: - dlogp0 = step.dlogp(q0) - - q, p = q0, p0 - dlogp = dlogp0 - - energy = step.compute_energy(q, p) - for _ in range(n_steps): - q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, floatX(np.array(epsilon))) - p = -p - for _ in range(n_steps): - q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, floatX(np.array(epsilon))) - - close_to(q, q0, precision, str(('q', method, n_steps, epsilon))) - close_to(-p, p0, precision, str(('p', method, n_steps, epsilon))) + state = start + for _ in range(n_steps): + state = step.integrator.step(epsilon, state) + for _ in range(n_steps): + state = step.integrator.step(-epsilon, state) + npt.assert_allclose(state.q, start.q, rtol=1e-5) + npt.assert_allclose(state.p, start.p, rtol=1e-5) def test_nuts_tuning(): diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py index e25df56b77..6c0df289a1 100644 --- a/pymc3/tests/test_model.py +++ b/pymc3/tests/test_model.py @@ -3,10 +3,12 @@ import numpy as np import pandas as pd import numpy.testing as npt +import unittest import pymc3 as pm from pymc3.distributions import HalfCauchy, Normal, transforms from pymc3 import Potential, Deterministic +from pymc3.model import ValueGradFunction class NewModel(pm.Model): @@ -145,6 +147,7 @@ def test_nested(self): assert theano.config.compute_test_value == 'ignore' assert theano.config.compute_test_value == 'off' + def test_duplicate_vars(): with pytest.raises(ValueError) as err: with pm.Model(): @@ -179,3 +182,80 @@ def test_empty_observed(): npt.assert_allclose(a.tag.test_value, np.zeros((2, 3))) b = pm.Beta('b', alpha=1, beta=1, observed=data) npt.assert_allclose(b.tag.test_value, np.ones((2, 3)) / 2) + + +class 
TestValueGradFunction(unittest.TestCase): + def test_no_extra(self): + a = tt.vector('a') + a.tag.test_value = np.zeros(3, dtype=a.dtype) + a.dshape = (3,) + a.dsize = 3 + f_grad = ValueGradFunction(a.sum(), [a], [], mode='FAST_COMPILE') + assert f_grad.size == 3 + + def test_invalid_type(self): + a = tt.ivector('a') + a.tag.test_value = np.zeros(3, dtype=a.dtype) + a.dshape = (3,) + a.dsize = 3 + with pytest.raises(TypeError) as err: + ValueGradFunction(a.sum(), [a], [], mode='FAST_COMPILE') + err.match('Invalid dtype') + + def setUp(self): + extra1 = tt.iscalar('extra1') + extra1_ = np.array(0, dtype=extra1.dtype) + extra1.tag.test_value = extra1_ + extra1.dshape = tuple() + extra1.dsize = 1 + + val1 = tt.vector('val1') + val1_ = np.zeros(3, dtype=val1.dtype) + val1.tag.test_value = val1_ + val1.dshape = (3,) + val1.dsize = 3 + + val2 = tt.matrix('val2') + val2_ = np.zeros((2, 3), dtype=val2.dtype) + val2.tag.test_value = val2_ + val2.dshape = (2, 3) + val2.dsize = 6 + + self.val1, self.val1_ = val1, val1_ + self.val2, self.val2_ = val2, val2_ + self.extra1, self.extra1_ = extra1, extra1_ + + self.cost = extra1 * val1.sum() + val2.sum() + + self.f_grad = ValueGradFunction( + self.cost, [val1, val2], [extra1], mode='FAST_COMPILE') + + def test_extra_not_set(self): + with pytest.raises(ValueError) as err: + self.f_grad.get_extra_values() + err.match('Extra values are not set') + + with pytest.raises(ValueError) as err: + self.f_grad(np.zeros(self.f_grad.size, dtype=self.f_grad.dtype)) + err.match('Extra values are not set') + + def test_grad(self): + self.f_grad.set_extra_values({'extra1': 5}) + array = np.ones(self.f_grad.size, dtype=self.f_grad.dtype) + val, grad = self.f_grad(array) + assert val == 21 + npt.assert_allclose(grad, [5, 5, 5, 1, 1, 1, 1, 1, 1]) + + def test_bij(self): + self.f_grad.set_extra_values({'extra1': 5}) + array = np.ones(self.f_grad.size, dtype=self.f_grad.dtype) + point = self.f_grad.array_to_dict(array) + assert len(point) == 2 + npt.assert_allclose(point['val1'], 1) + npt.assert_allclose(point['val2'], 1) + + array2 = self.f_grad.dict_to_array(point) + npt.assert_allclose(array2, array) + point_ = self.f_grad.array_to_full_dict(array) + assert len(point_) == 3 + assert point_['extra1'] == 5 diff --git a/pymc3/tests/test_quadpotential.py b/pymc3/tests/test_quadpotential.py index 537bd4acb0..235837d132 100644 --- a/pymc3/tests/test_quadpotential.py +++ b/pymc3/tests/test_quadpotential.py @@ -1,85 +1,60 @@ import numpy as np -import numpy.testing as npt import scipy.sparse -import theano.tensor as tt -import theano from pymc3.step_methods.hmc import quadpotential import pymc3 from pymc3.theanof import floatX import pytest +import numpy.testing as npt -@pytest.mark.skip() def test_elemwise_posdef(): scaling = np.array([0, 2, 3]) with pytest.raises(quadpotential.PositiveDefiniteError): - quadpotential.quad_potential(scaling, True, True) + quadpotential.quad_potential(scaling, True) -@pytest.mark.skip() -def test_elemwise_posdef2(): - scaling = np.array([0, 2, 3]) - with pytest.raises(quadpotential.PositiveDefiniteError): - quadpotential.quad_potential(scaling, True, False) - - -@pytest.mark.skip() def test_elemwise_velocity(): scaling = np.array([1, 2, 3]) - x_ = floatX(np.ones_like(scaling)) - x = tt.vector() - x.tag.test_value = x_ - pot = quadpotential.quad_potential(scaling, True, False) - v = theano.function([x], pot.velocity(x)) - assert np.allclose(v(x_), scaling) - pot = quadpotential.quad_potential(scaling, True, True) - v = theano.function([x], 
pot.velocity(x)) - assert np.allclose(v(x_), 1. / scaling) - - -@pytest.mark.skip() + x = floatX(np.ones_like(scaling)) + pot = quadpotential.quad_potential(scaling, True) + v = pot.velocity(x) + npt.assert_allclose(v, scaling) + assert v.dtype == pot.dtype + + def test_elemwise_energy(): scaling = np.array([1, 2, 3]) - x_ = floatX(np.ones_like(scaling)) - x = tt.vector() - x.tag.test_value = x_ - pot = quadpotential.quad_potential(scaling, True, False) - energy = theano.function([x], pot.energy(x)) - assert np.allclose(energy(x_), 0.5 * scaling.sum()) - pot = quadpotential.quad_potential(scaling, True, True) - energy = theano.function([x], pot.energy(x)) - assert np.allclose(energy(x_), 0.5 * (1. / scaling).sum()) - - -@pytest.mark.skip() + x = floatX(np.ones_like(scaling)) + pot = quadpotential.quad_potential(scaling, True) + energy = pot.energy(x) + npt.assert_allclose(energy, 0.5 * scaling.sum()) + + def test_equal_diag(): np.random.seed(42) for _ in range(3): diag = np.random.rand(5) - x_ = floatX(np.random.randn(5)) - x = tt.vector() - x.tag.test_value = x_ + x = floatX(np.random.randn(5)) pots = [ - quadpotential.quad_potential(diag, False, False), - quadpotential.quad_potential(1. / diag, True, False), - quadpotential.quad_potential(np.diag(diag), False, False), - quadpotential.quad_potential(np.diag(1. / diag), True, False), + quadpotential.quad_potential(diag, False), + quadpotential.quad_potential(1. / diag, True), + quadpotential.quad_potential(np.diag(diag), False), + quadpotential.quad_potential(np.diag(1. / diag), True), ] if quadpotential.chol_available: diag_ = scipy.sparse.csc_matrix(np.diag(1. / diag)) - pots.append(quadpotential.quad_potential(diag_, True, False)) + pots.append(quadpotential.quad_potential(diag_, True)) - v = np.diag(1. / diag).dot(x_) - e = x_.dot(np.diag(1. / diag).dot(x_)) / 2 + v = np.diag(1. / diag).dot(x) + e = x.dot(np.diag(1. 
/ diag).dot(x)) / 2 for pot in pots: - v_function = theano.function([x], pot.velocity(x)) - e_function = theano.function([x], pot.energy(x)) - assert np.allclose(v_function(x_), v) - assert np.allclose(e_function(x_), e) + v_ = pot.velocity(x) + e_ = pot.energy(x) + npt.assert_allclose(v_, v) + npt.assert_allclose(e_, e) -@pytest.mark.skip() def test_equal_dense(): np.random.seed(42) for _ in range(3): @@ -87,46 +62,42 @@ def test_equal_dense(): cov += cov.T cov += 10 * np.eye(5) inv = np.linalg.inv(cov) - assert np.allclose(inv.dot(cov), np.eye(5)) - x_ = floatX(np.random.randn(5)) - x = tt.vector() - x.tag.test_value = x_ + npt.assert_allclose(inv.dot(cov), np.eye(5), atol=1e-10) + x = floatX(np.random.randn(5)) pots = [ - quadpotential.quad_potential(cov, False, False), - quadpotential.quad_potential(inv, True, False), + quadpotential.quad_potential(cov, False), + quadpotential.quad_potential(inv, True), ] if quadpotential.chol_available: - pots.append(quadpotential.quad_potential(cov, False, False)) + pots.append(quadpotential.quad_potential(cov, False)) - v = np.linalg.solve(cov, x_) - e = 0.5 * x_.dot(v) + v = np.linalg.solve(cov, x) + e = 0.5 * x.dot(v) for pot in pots: - v_function = theano.function([x], pot.velocity(x)) - e_function = theano.function([x], pot.energy(x)) - assert np.allclose(v_function(x_), v) - assert np.allclose(e_function(x_), e) + v_ = pot.velocity(x) + e_ = pot.energy(x) + npt.assert_allclose(v_, v, rtol=1e-4) + npt.assert_allclose(e_, e, rtol=1e-4) -@pytest.mark.skip() def test_random_diag(): d = np.arange(10) + 1 np.random.seed(42) pots = [ - quadpotential.quad_potential(d, True, False), - quadpotential.quad_potential(1./d, False, False), - quadpotential.quad_potential(np.diag(d), True, False), - quadpotential.quad_potential(np.diag(1./d), False, False), + quadpotential.quad_potential(d, True), + quadpotential.quad_potential(1./d, False), + quadpotential.quad_potential(np.diag(d), True), + quadpotential.quad_potential(np.diag(1./d), False), ] if quadpotential.chol_available: d_ = scipy.sparse.csc_matrix(np.diag(d)) - pot = quadpotential.quad_potential(d_, True, False) + pot = quadpotential.quad_potential(d_, True) pots.append(pot) for pot in pots: vals = np.array([pot.random() for _ in range(1000)]) - assert np.allclose(vals.std(0), np.sqrt(1./d), atol=0.1) + npt.assert_allclose(vals.std(0), np.sqrt(1./d), atol=0.1) -@pytest.mark.skip() def test_random_dense(): np.random.seed(42) for _ in range(3): @@ -137,8 +108,8 @@ def test_random_dense(): assert np.allclose(inv.dot(cov), np.eye(5)) pots = [ - quadpotential.QuadPotential(cov), - quadpotential.QuadPotential_Inv(inv), + quadpotential.QuadPotentialFull(cov), + quadpotential.QuadPotentialFullInv(inv), ] if quadpotential.chol_available: pot = quadpotential.QuadPotential_Sparse(scipy.sparse.csc_matrix(cov)) @@ -148,7 +119,6 @@ def test_random_dense(): assert np.allclose(cov_, inv, atol=0.1) -@pytest.mark.skip() def test_user_potential(): model = pymc3.Model() with model: @@ -157,34 +127,13 @@ def test_user_potential(): # Work around missing nonlocal in python2 called = [] - class Potential(quadpotential.ElemWiseQuadPotential): - def energy(self, x): + class Potential(quadpotential.QuadPotentialDiag): + def energy(self, x, velocity=None): called.append(1) - return super(Potential, self).energy(x) + return super(Potential, self).energy(x, velocity) - pot = Potential([1]) + pot = Potential(floatX([1])) with model: step = pymc3.NUTS(potential=pot) pymc3.sample(10, init=None, step=step) assert called - - -class 
TestWeightedVariance(object): - def test_no_init(self): - var = quadpotential._WeightedVariance(3) - with pytest.raises(ValueError) as err: - var.current_variance() - err.match('without samples') - - var.add_sample([0, 0, 0], 1) - npt.assert_allclose(var.current_variance(), [0, 0, 0]) - var.add_sample([1, 1, 1], 1) - npt.assert_allclose(var.current_variance(), 0.25) - var.add_sample([-1, 0, 1], 2) - npt.assert_allclose(var.current_variance(), [0.6875, 0.1875, 0.1875]) - - def test_with_init(self): - var = quadpotential._WeightedVariance(3, [0.5, 0.5, 0.5], [0.25, 0.25, 0.25], 2) - npt.assert_allclose(var.current_variance(), 0.25) - var.add_sample([-1, 0, 1], 2) - npt.assert_allclose(var.current_variance(), [0.6875, 0.1875, 0.1875]) diff --git a/pymc3/tests/test_stats.py b/pymc3/tests/test_stats.py index 23d698bd4b..687a685f23 100644 --- a/pymc3/tests/test_stats.py +++ b/pymc3/tests/test_stats.py @@ -17,7 +17,7 @@ def test_log_post_trace(): with pm.Model() as model: pm.Normal('y') - trace = pm.sample() + trace = pm.sample(10, tune=10) logp = pmstats._log_post_trace(trace, model) assert logp.shape == (len(trace), 0) @@ -25,7 +25,7 @@ def test_log_post_trace(): with pm.Model() as model: pm.Normal('a') pm.Normal('y', observed=np.zeros((2, 3))) - trace = pm.sample() + trace = pm.sample(10, tune=10) logp = pmstats._log_post_trace(trace, model) assert logp.shape == (len(trace), 6) @@ -40,7 +40,7 @@ def test_log_post_trace(): data = data.copy() data.values[:] = np.nan pm.Normal('y3', observed=data) - trace = pm.sample() + trace = pm.sample(10, tune=10) logp = pmstats._log_post_trace(trace, model) assert logp.shape == (len(trace), 17) diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index 0d54380872..e0f0c0b6b2 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -1,9 +1,9 @@ import shutil import tempfile -import warnings from .checks import close_to -from .models import simple_categorical, mv_simple, mv_simple_discrete, simple_2model, mv_prior_simple +from .models import (simple_categorical, mv_simple, mv_simple_discrete, + mv_prior_simple, simple_2model_continuous) from pymc3.sampling import assign_step_methods, sample from pymc3.model import Model from pymc3.step_methods import (NUTS, BinaryGibbsMetropolis, CategoricalGibbsMetropolis, @@ -287,7 +287,7 @@ class TestCompoundStep(object): reason="Test fails on 32 bit due to linalg issues") def test_non_blocked(self): """Test that samplers correctly create non-blocked compound steps.""" - _, model = simple_2model() + _, model = simple_2model_continuous() with model: for sampler in self.samplers: assert isinstance(sampler(blocked=False), CompoundStep) @@ -295,7 +295,7 @@ def test_non_blocked(self): @pytest.mark.skipif(theano.config.floatX == "float32", reason="Test fails on 32 bit due to linalg issues") def test_blocked(self): - _, model = simple_2model() + _, model = simple_2model_continuous() with model: for sampler in self.samplers: sampler_instance = sampler(blocked=True) @@ -341,9 +341,9 @@ def test_binomial(self): class TestNutsCheckTrace(object): def test_multiple_samplers(self): with Model(): - prob = Beta('prob', alpha=5, beta=3) + prob = Beta('prob', alpha=5., beta=3.) 
Binomial('outcome', n=1, p=prob) - with warnings.catch_warnings(record=True) as warns: + with pytest.warns(None) as warns: sample(3, tune=2, discard_tuned_samples=False, n_init=None) messages = [warn.message.args[0] for warn in warns] @@ -363,7 +363,7 @@ def test_linalg(self): a = tt.switch(a > 0, np.inf, a) b = tt.slinalg.solve(floatX(np.eye(2)), a) Normal('c', mu=b, shape=2) - with warnings.catch_warnings(record=True) as warns: + with pytest.warns(None) as warns: trace = sample(20, init=None, tune=5) assert np.any(trace['diverging']) assert any('diverging samples after tuning' in str(warn.message) diff --git a/pymc3/tests/test_types.py b/pymc3/tests/test_types.py index ed606fc437..e59fce37c8 100644 --- a/pymc3/tests/test_types.py +++ b/pymc3/tests/test_types.py @@ -6,6 +6,7 @@ from pymc3.model import Model from pymc3.step_methods import NUTS, Metropolis, Slice, HamiltonianMC from pymc3.distributions import Normal +from pymc3.theanof import change_flags import numpy as np @@ -21,10 +22,8 @@ def teardown_method(self): # restore theano config theano.config = self.theano_config + @change_flags({'floatX': 'float64', 'warn_float64': 'ignore'}) def test_float64(self): - theano.config.floatX = 'float64' - theano.config.warn_float64 = 'ignore' - with Model() as model: x = Normal('x', testval=np.array(1., dtype='float64')) obs = Normal('obs', mu=x, sd=1., observed=np.random.randn(5)) @@ -36,10 +35,8 @@ def test_float64(self): with model: sample(10, sampler()) + @change_flags({'floatX': 'float32', 'warn_float64': 'warn'}) def test_float32(self): - theano.config.floatX = 'float32' - theano.config.warn_float64 = 'warn' - with Model() as model: x = Normal('x', testval=np.array(1., dtype='float32')) obs = Normal('obs', mu=x, sd=1., observed=np.random.randn(5).astype('float32')) diff --git a/pymc3/variational/opvi.py b/pymc3/variational/opvi.py index 48610944aa..0395e564ed 100644 --- a/pymc3/variational/opvi.py +++ b/pymc3/variational/opvi.py @@ -1169,8 +1169,8 @@ def total_size(self): @property def local_size(self): - return self._l_order.dimensions + return self._l_order.size @property def global_size(self): - return self._g_order.dimensions + return self._g_order.size
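End-to-end sketch (illustrative, not part of the patch) of how the pieces introduced here fit together: a ValueGradFunction from Model.logp_dlogp_function, a QuadPotentialDiagAdapt mass matrix, and the numpy CpuLeapfrogIntegrator. The toy model, the number of steps, and the step size 0.1 are assumptions for the example only:

import numpy as np
import pymc3 as pm
from pymc3.theanof import floatX
from pymc3.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt
from pymc3.step_methods.hmc.integration import CpuLeapfrogIntegrator

with pm.Model() as model:
    pm.Normal('x', mu=0., sd=1., shape=5)

func = model.logp_dlogp_function()      # gradient w.r.t. all continuous free RVs
func.set_extra_values({})               # no non-gradient variables in this model
size = func.size

potential = QuadPotentialDiagAdapt(size, floatX(np.zeros(size)), floatX(np.ones(size)), 10)
integrator = CpuLeapfrogIntegrator(size, potential, func)

q0 = func.dict_to_array(model.test_point)   # starting position
p0 = potential.random()                     # starting momentum
state = integrator.compute_state(q0, p0)
for _ in range(10):                         # ten leapfrog steps of size 0.1
    state = integrator.step(0.1, state)
print(state.q, state.energy)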