# File: pymc3_hmm/step_methods.py
# Excerpt reformatted from a flattened diff: the imports visible in the hunk,
# followed by the definitions the change appends to the end of the module.
# NOTE(review): `np` is used below but its import is not visible in this
# excerpt -- presumably `import numpy as np` near the top of the module.
from theano.tensor.var import TensorConstant

import pymc3 as pm
import scipy
import scipy.linalg
import scipy.sparse
from pymc3.distributions.distribution import draw_values
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name
from scipy.stats import invgamma

from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess
from pymc3_hmm.utils import compute_trans_freqs


def large_p_mvnormal_sampler(D_diag, Phi, a):
    r"""Efficiently sample from a large multivariate normal.

    This function draws samples from the following distribution:

    .. math::
        \beta \sim \operatorname{N}\left( \mu, \Sigma \right)

    where

    .. math::
        \mu = \Sigma \Phi^\top a, \\
        \Sigma = \left( \Phi^\top \Phi + D^{-1} \right)^{-1}

    and :math:`a \in \mathbb{R}^{n}`, :math:`\Phi \in \mathbb{R}^{n \times p}`.

    This approach is particularly effective when :math:`p \gg n`, because the
    only linear system that is solved is :math:`n \times n`.

    From "Fast sampling with Gaussian scale-mixture priors in high-dimensional
    regression", Bhattacharya, Chakraborty, and Mallick, 2015.

    Parameters
    ----------
    D_diag
        The diagonal of :math:`D` (length :math:`p`).  Its square root is used
        as a standard deviation, so the entries must be non-negative.
    Phi
        The :math:`n \times p` matrix :math:`\Phi`, given either as a dense
        array or as a SciPy sparse matrix of any format.
    a
        The length-:math:`n` vector :math:`a`.

    Returns
    -------
    A length-:math:`p` array holding one draw of :math:`\beta`.

    """
    N = a.shape[0]

    # u ~ N(0, D) and delta ~ N(0, I_n).  The order of these two draws is kept
    # fixed so that results are reproducible under a given NumPy seed.
    u = np.random.normal(0, np.sqrt(D_diag))
    delta = np.random.normal(size=N)

    if scipy.sparse.issparse(Phi):
        # `multiply` scales column j by D_diag[j]; the result's format depends
        # on the SciPy version, so normalize it to CSR before the products.
        Phi_D = Phi.multiply(D_diag).tocsr()
        # The n x n system is handed to a dense solver, so densify it here
        # rather than letting `sparse + dense` produce an `np.matrix`.
        Z = (Phi_D @ Phi.T).toarray() + np.eye(N)
    else:
        # Broadcasting scales column j of Phi by D_diag[j], i.e. Phi @ D.
        Phi_D = Phi * D_diag
        Z = Phi_D @ Phi.T + np.eye(N)

    # `@` is a matrix product for dense arrays and for every sparse type,
    # whereas `*` is element-wise for some of them.
    v = Phi @ u + delta
    # Solve (Phi D Phi^T + I_n) w = a - v.
    w = scipy.linalg.solve(Z, a - v, assume_a="sym")
    beta = u + Phi_D.T @ w

    return beta


def hs_step(
    lambda2: np.ndarray,
    tau2: np.ndarray,
    vi: np.ndarray,
    xi: np.ndarray,
    X: np.ndarray,
    y: np.ndarray,
):
    r"""Perform one Gibbs sweep over the coefficients and shrinkage variables.

    The coefficients are drawn with `large_p_mvnormal_sampler` using
    :math:`D = \operatorname{diag}(\tau^2 \lambda^2)`; each remaining quantity
    is then redrawn from an inverse-gamma distribution, in the order
    `lambda2`, `tau2`, `vi`, `xi`, with every draw using the most recent
    values of the others.

    NOTE(review): the updates appear to follow the auxiliary-variable
    horseshoe sampler with the observation variance fixed at one -- confirm
    against the intended model.

    Parameters
    ----------
    lambda2
        Current local scale parameters, one per column of `X`.
    tau2
        Current global scale parameter (a scalar is accepted).
    vi
        Current auxiliary variables paired with `lambda2`.
    xi
        Current auxiliary variable paired with `tau2` (a scalar is accepted).
    X
        Design matrix with one column per coefficient; dense or SciPy sparse.
    y
        Observation vector with one entry per row of `X`.

    Returns
    -------
    The tuple ``(beta, lambda2, tau2, vi, xi)`` of freshly drawn values.

    """
    # Only the number of coefficients is needed for the `tau2` update.
    M = X.shape[1]

    D_diag = tau2 * lambda2
    beta = large_p_mvnormal_sampler(D_diag, X, y)
    beta2 = beta ** 2

    lambda2 = invgamma.rvs(a=1, scale=1 / vi + beta2 / (2 * tau2))
    tau2 = invgamma.rvs(a=(M + 1) / 2, scale=1 / xi + (beta2 / lambda2).sum() / 2)
    vi = invgamma.rvs(a=1, scale=1 + 1 / lambda2)
    xi = invgamma.rvs(a=1, scale=1 + 1 / tau2)

    return beta, lambda2, tau2, vi, xi


class HSStep(BlockedStep):
    """Step method that redraws a single `Normal` variable with `hs_step`.

    The sampler state (`lambda2`, `tau2`, `vi`, `xi` and the latest `beta`) is
    stored on the instance and carried from one call of `step` to the next.

    Parameters
    ----------
    vars
        A sequence holding exactly one model variable with a `Normal`
        distribution.
    y
        Observation vector passed through to `hs_step`.
    X
        Design matrix passed through to `hs_step`; it must have one column per
        element along the last axis of the variable.
    values
        Unused; kept so the signature stays unchanged.
    model
        The model to use; resolved with `pm.modelcontext` when omitted.

    Raises
    ------
    ValueError
        If more than one variable is given, or if the number of columns of
        `X` does not match the size of the variable's last axis.
    TypeError
        If the variable's distribution is not a `Normal`.

    """

    name = "hsgibbs"

    def __init__(self, vars, y, X, values=None, model=None):

        if len(vars) > 1:
            raise ValueError("This sampler only takes one variable.")

        (var,) = pm.inputvars(vars)

        if not isinstance(var.distribution, pm.distributions.Normal):
            raise TypeError("This sampler only samples `Normal`s.")

        model = pm.modelcontext(model)

        self.vars = [var]

        M = model.test_point[var.name].shape[-1]

        # `hs_step` returns one coefficient per column of `X`; fail early
        # instead of writing a wrongly sized value into the sample point.
        if X.shape[1] != M:
            raise ValueError(
                f"`X` has {X.shape[1]} columns, but the variable has {M} elements."
            )

        # Every state variable starts at one, stored as floats so the state
        # has the same dtype before and after the first draw.
        self.vi = np.ones(M)
        self.lambda2 = np.ones(M)
        self.beta = np.ones(M)
        self.tau2 = 1.0
        self.xi = 1.0
        self.y = y
        self.X = X

    def step(self, point):
        """Advance the sampler by one sweep and store the new `beta` in `point`.

        `point` is updated in place and also returned.
        """
        self.beta, self.lambda2, self.tau2, self.vi, self.xi = hs_step(
            self.lambda2, self.tau2, self.vi, self.xi, self.X, self.y
        )
        point[self.vars[0].name] = self.beta
        return point
# File: tests/test_step_methods.py
# Excerpt reformatted from a flattened diff: the imports visible in the hunk,
# followed by the tests the change appends to the end of the module.
# NOTE(review): `np` and `pm` are used below but their imports are not visible
# in this excerpt -- presumably they are imported near the top of the module.
import scipy as sp
import scipy.sparse  # makes sure `sp.sparse` is actually loaded

from pymc3_hmm.distributions import DiscreteMarkovChain, PoissonZeroProcess
from pymc3_hmm.step_methods import (
    FFBSStep,
    HSStep,
    TransMatConjugateStep,
    ffbs_step,
    hs_step,
    large_p_mvnormal_sampler,
)
from pymc3_hmm.utils import compute_steady_state, compute_trans_freqs
from tests.utils import simulate_poiszero_hmm


def test_large_p_mvnormal_sampler():
    """The sampler returns one value per column for dense and CSR input."""
    np.random.seed(2032)
    X = np.random.choice([0, 1], size=25).reshape((5, 5))
    y = np.arange(5)

    # Dense first, then sparse, so the random stream is consumed in the same
    # order as when the two cases were written out separately.
    for Phi in (X, sp.sparse.csr_matrix(X)):
        samples = large_p_mvnormal_sampler(np.ones(5), Phi, y)
        assert samples.shape == (5,)


def test_hs_step():
    """A single sweep returns well-formed state for dense and CSR designs."""
    np.random.seed(2032)
    M = 5
    X = np.random.choice([0, 1], size=25).reshape((M, M))
    beta_true = np.random.normal(size=M)
    y = np.random.normal(X.dot(beta_true), 1)

    for design in (X, sp.sparse.csr_matrix(X)):
        # Each case starts from the same initial state.
        vi = np.full(M, 1)
        lambda2 = np.full(M, 1)
        tau2 = 1
        xi = 1

        beta, lambda2, tau2, vi, xi = hs_step(lambda2, tau2, vi, xi, design, y)

        assert beta.shape == beta_true.shape
        # NOTE(review): this divides by the *signed* `beta_true`, so negative
        # coefficients contribute negative terms that can cancel positive
        # ones.  It is kept unchanged as a loose sanity check; treat it as
        # such rather than as an accuracy bound.
        assert (np.abs(beta - beta_true) / beta_true).mean() < 0.5

        # The remaining outputs are inverse-gamma draws: one value per
        # coefficient for the local quantities, a scalar for the global ones,
        # and all of them strictly positive.
        assert lambda2.shape == (M,)
        assert vi.shape == (M,)
        assert np.ndim(tau2) == 0
        assert np.ndim(xi) == 0
        assert np.all(lambda2 > 0)
        assert np.all(vi > 0)
        assert tau2 > 0
        assert xi > 0


def _sample_beta_with_hsstep(X, y, draws=20):
    """Sample `beta` with `HSStep` and return the first chain's draws."""
    with pm.Model() as _:
        beta = pm.Normal("beta", 0, 1, shape=X.shape[1])
        hsstep = HSStep([beta], y, X)
        trace = pm.sample(
            draws=draws, tune=0, step=hsstep, chains=1, return_inferencedata=True
        )

    return trace.posterior["beta"][0].values


def test_Hsstep():
    """`HSStep` plugs into `pm.sample` for dense and CSR design matrices."""
    np.random.seed(2032)
    M = 5
    X = np.random.choice([0, 1], size=25).reshape((M, M))
    beta_true = np.random.normal(size=M)
    y = np.random.normal(X.dot(beta_true), 1)

    beta_samples = _sample_beta_with_hsstep(X, y)
    assert beta_samples.shape == (20, M)

    # The sparse case draws a fresh design matrix but reuses `y`; only the
    # shape of the result is checked, so the mismatch does not matter here.
    X = np.random.choice([0, 1], size=25).reshape((M, M))
    X_sp = sp.sparse.csr_matrix(X)

    beta_samples = _sample_beta_with_hsstep(X_sp, y)
    assert beta_samples.shape == (20, M)