diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index 3366d955938..091850b68f5 100644 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -47,6 +47,7 @@ from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln from pymc3.distributions.distribution import Continuous, Discrete +from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple from pymc3.math import kron_diag, kron_dot __all__ = [ @@ -739,6 +740,25 @@ def __str__(self): matrix_pos_def = PosDefMatrix() +class WishartRV(RandomVariable): + name = "wishart" + ndim_supp = 2 + ndims_params = [0, 2] + dtype = "floatX" + _print_name = ("Wishart", "\\operatorname{Wishart}") + + def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): + # The shape of second parameter `V` defines the shape of the output. + return dist_params[1].shape + + @classmethod + def rng_fn(cls, rng, nu, V, size=None): + return stats.wishart.rvs(np.int(nu), V, size=size, random_state=rng) + + +wishart = WishartRV() + + class Wishart(Continuous): r""" Wishart log-likelihood. @@ -775,9 +795,10 @@ class Wishart(Continuous): This distribution is unusable in a PyMC3 model. You should instead use LKJCholeskyCov or LKJCorr. """ + rv_op = wishart - def __init__(self, nu, V, *args, **kwargs): - super().__init__(*args, **kwargs) + @classmethod + def dist(cls, nu, V, *args, **kwargs): warnings.warn( "The Wishart distribution can currently not be used " "for MCMC sampling. The probability of sampling a " @@ -787,34 +808,15 @@ def __init__(self, nu, V, *args, **kwargs): "https://github.com/pymc-devs/pymc3/issues/538.", UserWarning, ) - self.nu = nu = at.as_tensor_variable(nu) - self.p = p = at.as_tensor_variable(V.shape[0]) - self.V = V = at.as_tensor_variable(V) - self.mean = nu * V - self.mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan) - - def random(self, point=None, size=None): - """ - Draw random values from Wishart distribution. - - Parameters - ---------- - point: dict, optional - Dict of variable values on which random values are to be - conditioned (uses default point if not specified). - size: int, optional - Desired size of random sample (returns one sample if not - specified). + nu = at.as_tensor_variable(intX(nu)) + V = at.as_tensor_variable(floatX(V)) - Returns - ------- - array - """ - # nu, V = draw_values([self.nu, self.V], point=point, size=size) - # size = 1 if size is None else size - # return generate_samples(stats.wishart.rvs, nu.item(), V, broadcast_shape=(size,)) + # mean = nu * V + # p = V.shape[0] + # mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan) + return super().dist([nu, V], *args, **kwargs) - def logp(self, X): + def logp(X, nu, V): """ Calculate log-probability of Wishart distribution at specified value. @@ -828,9 +830,8 @@ def logp(self, X): ------- TensorVariable """ - nu = self.nu - p = self.p - V = self.V + + p = V.shape[0] IVI = det(V) IXI = det(X) @@ -1445,6 +1446,36 @@ def _distr_parameters_for_repr(self): return ["eta", "n"] +class MatrixNormalRV(RandomVariable): + name = "matrixnormal" + ndim_supp = 2 + ndims_params = [2, 2, 2] + dtype = "floatX" + _print_name = ("MatrixNormal", "\\operatorname{MatrixNormal}") + + @classmethod + def rng_fn(cls, rng, mu, rowchol, colchol, size=None): + + size = to_tuple(size) + dist_shape = to_tuple([rowchol.shape[0], colchol.shape[0]]) + output_shape = size + dist_shape + + # Broadcasting all parameters + (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size) + rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:]) + + colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:]) + colchol = np.swapaxes(colchol, -1, -2) # Take transpose + + standard_normal = rng.standard_normal(output_shape) + samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol)) + + return samples + + +matrixnormal = MatrixNormalRV() + + class MatrixNormal(Continuous): r""" Matrix-valued normal log-likelihood. @@ -1533,175 +1564,101 @@ class MatrixNormal(Continuous): vals = pm.MatrixNormal('vals', mu=mu, colchol=colchol, rowcov=rowcov, observed=data, shape=(m, n)) """ + rv_op = matrixnormal - def __init__( - self, - mu=0, + @classmethod + def dist( + cls, + mu, rowcov=None, rowchol=None, - rowtau=None, colcov=None, colchol=None, - coltau=None, shape=None, *args, **kwargs, ): - self._setup_matrices(colcov, colchol, coltau, rowcov, rowchol, rowtau) - if shape is None: - raise TypeError("shape is a required argument") - assert len(shape) == 2, "shape must have length 2: mxn" - self.shape = shape - super().__init__(shape=shape, *args, **kwargs) - self.mu = at.as_tensor_variable(mu) - self.mean = self.median = self.mode = self.mu - self.solve_lower = solve_lower_triangular - self.solve_upper = solve_upper_triangular - - def _setup_matrices(self, colcov, colchol, coltau, rowcov, rowchol, rowtau): + cholesky = Cholesky(lower=True, on_error="raise") + if mu.ndim == 1: + raise ValueError( + "1x1 Matrix was provided, " " Please use Normal distribution " "for such cases." + ) + # Among-row matrices - if len([i for i in [rowtau, rowcov, rowchol] if i is not None]) != 1: + if len([i for i in [rowcov, rowchol] if i is not None]) != 1: raise ValueError( - "Incompatible parameterization. " - "Specify exactly one of rowtau, rowcov, " - "or rowchol." + "Incompatible parameterization. " "Specify exactly one of rowcov, " "or rowchol." ) if rowcov is not None: - self.m = rowcov.shape[0] - self._rowcov_type = "cov" - rowcov = at.as_tensor_variable(rowcov) if rowcov.ndim != 2: raise ValueError("rowcov must be two dimensional.") - self.rowchol_cov = cholesky(rowcov) - self.rowcov = rowcov - elif rowtau is not None: - raise ValueError("rowtau not supported at this time") - self.m = rowtau.shape[0] - self._rowcov_type = "tau" - rowtau = at.as_tensor_variable(rowtau) - if rowtau.ndim != 2: - raise ValueError("rowtau must be two dimensional.") - self.rowchol_tau = cholesky(rowtau) - self.rowtau = rowtau + rowchol_cov = cholesky(rowcov) else: - self.m = rowchol.shape[0] - self._rowcov_type = "chol" if rowchol.ndim != 2: raise ValueError("rowchol must be two dimensional.") - self.rowchol_cov = at.as_tensor_variable(rowchol) + rowchol_cov = at.as_tensor_variable(rowchol) # Among-column matrices - if len([i for i in [coltau, colcov, colchol] if i is not None]) != 1: + if len([i for i in [colcov, colchol] if i is not None]) != 1: raise ValueError( - "Incompatible parameterization. " - "Specify exactly one of coltau, colcov, " - "or colchol." + "Incompatible parameterization. " "Specify exactly one of colcov, " "or colchol." ) if colcov is not None: - self.n = colcov.shape[0] - self._colcov_type = "cov" colcov = at.as_tensor_variable(colcov) if colcov.ndim != 2: raise ValueError("colcov must be two dimensional.") - self.colchol_cov = cholesky(colcov) - self.colcov = colcov - elif coltau is not None: - raise ValueError("coltau not supported at this time") - self.n = coltau.shape[0] - self._colcov_type = "tau" - coltau = at.as_tensor_variable(coltau) - if coltau.ndim != 2: - raise ValueError("coltau must be two dimensional.") - self.colchol_tau = cholesky(coltau) - self.coltau = coltau + colchol_cov = cholesky(colcov) else: - self.n = colchol.shape[0] - self._colcov_type = "chol" if colchol.ndim != 2: raise ValueError("colchol must be two dimensional.") - self.colchol_cov = at.as_tensor_variable(colchol) + colchol_cov = at.as_tensor_variable(colchol) - def random(self, point=None, size=None): + mu = at.as_tensor_variable(floatX(mu)) + # mean = median = mode = mu + + return super().dist([mu, rowchol_cov, colchol_cov], **kwargs) + + def logp(value, mu, rowchol, colchol): """ - Draw random values from Matrix-valued Normal distribution. + Calculate log-probability of Matrix-valued Normal distribution + at specified value. Parameters ---------- - point: dict, optional - Dict of variable values on which random values are to be - conditioned (uses default point if not specified). - size: int, optional - Desired size of random sample (returns one sample if not - specified). + value: numeric + Value for which log-probability is calculated. Returns ------- - array + TensorVariable """ - # mu, colchol, rowchol = draw_values( - # [self.mu, self.colchol_cov, self.rowchol_cov], point=point, size=size - # ) - # size = to_tuple(size) - # dist_shape = to_tuple(self.shape) - # output_shape = size + dist_shape - # - # # Broadcasting all parameters - # (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size) - # rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:]) - # - # colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:]) - # colchol = np.swapaxes(colchol, -1, -2) # Take transpose - # - # standard_normal = np.random.standard_normal(output_shape) - # samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol)) - # return samples - - def _trquaddist(self, value): - """Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and - the logdet of colcov and rowcov.""" - delta = value - self.mu - rowchol_cov = self.rowchol_cov - colchol_cov = self.colchol_cov + # Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and + # the logdet of colcov and rowcov. + delta = value - mu # Find exponent piece by piece - right_quaddist = self.solve_lower(rowchol_cov, delta) + right_quaddist = solve_lower_triangular(rowchol, delta) quaddist = at.nlinalg.matrix_dot(right_quaddist.T, right_quaddist) - quaddist = self.solve_lower(colchol_cov, quaddist) - quaddist = self.solve_upper(colchol_cov.T, quaddist) + quaddist = solve_lower_triangular(colchol, quaddist) + quaddist = solve_upper_triangular(colchol.T, quaddist) trquaddist = at.nlinalg.trace(quaddist) - coldiag = at.diag(colchol_cov) - rowdiag = at.diag(rowchol_cov) + coldiag = at.diag(colchol) + rowdiag = at.diag(rowchol) half_collogdet = at.sum(at.log(coldiag)) # logdet(M) = 2*Tr(log(L)) half_rowlogdet = at.sum(at.log(rowdiag)) # Using Cholesky: M = L L^T - return trquaddist, half_collogdet, half_rowlogdet - - def logp(self, value): - """ - Calculate log-probability of Matrix-valued Normal distribution - at specified value. - Parameters - ---------- - value: numeric - Value for which log-probability is calculated. + m = rowchol.shape[0] + n = colchol.shape[0] - Returns - ------- - TensorVariable - """ - trquaddist, half_collogdet, half_rowlogdet = self._trquaddist(value) - m = self.m - n = self.n norm = -0.5 * m * n * pm.floatX(np.log(2 * np.pi)) return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet def _distr_parameters_for_repr(self): - mapping = {"tau": "tau", "cov": "cov", "chol": "chol_cov"} - return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]] + return ["mu"] class KroneckerNormalRV(RandomVariable): diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index dbf2d425152..1d0332af7d3 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1893,8 +1893,7 @@ def test_mvnormal_init_fail(self): with pytest.raises(ValueError): x = MvNormal("x", mu=np.zeros(3), cov=np.eye(3), tau=np.eye(3), size=3) - @pytest.mark.parametrize("n", [1, 2, 3]) - @pytest.mark.xfail(reason="Distribution not refactored yet") + @pytest.mark.parametrize("n", [2, 3]) def test_matrixnormal(self, n): mat_scale = 1e3 # To reduce logp magnitude mean_scale = 0.1 @@ -1907,6 +1906,8 @@ def test_matrixnormal(self, n): "colcov": PdMatrix(n) * mat_scale, }, matrix_normal_logpdf_cov, + extra_args={"size": n}, + decimal=select_by_precision(float64=5, float32=3), ) self.check_logp( MatrixNormal, @@ -1917,6 +1918,8 @@ def test_matrixnormal(self, n): "colcov": PdMatrix(n) * mat_scale, }, matrix_normal_logpdf_cov, + extra_args={"size": n}, + decimal=select_by_precision(float64=5, float32=3), ) self.check_logp( MatrixNormal, @@ -1927,7 +1930,8 @@ def test_matrixnormal(self, n): "colchol": PdMatrixChol(n) * mat_scale, }, matrix_normal_logpdf_chol, - decimal=select_by_precision(float64=6, float32=-1), + extra_args={"size": n}, + decimal=select_by_precision(float64=5, float32=3), ) self.check_logp( MatrixNormal, @@ -1938,7 +1942,8 @@ def test_matrixnormal(self, n): "colchol": PdMatrixChol(3) * mat_scale, }, matrix_normal_logpdf_chol, - decimal=select_by_precision(float64=6, float32=0), + extra_args={"size": n}, + decimal=select_by_precision(float64=5, float32=3), ) @pytest.mark.parametrize("n", [2, 3]) @@ -2040,7 +2045,6 @@ def test_AR1(self, n): self.check_logp(AR1, Vector(R, n), {"k": Unit, "tau_e": Rplus}, AR1_logpdf) @pytest.mark.parametrize("n", [2, 3]) - @pytest.mark.xfail(reason="Distribution not refactored yet") def test_wishart(self, n): # This check compares the autodiff gradient to the numdiff gradient. # However, due to the strict constraints of the wishart, diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 0d4fb6248c5..7c9e794c49c 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -41,11 +41,8 @@ from pymc3.tests.test_distributions import ( Domain, Nat, - PdMatrix, - PdMatrixChol, R, RandomPdMatrix, - RealMatrix, Rplus, Rplusbig, Simplex, @@ -1306,6 +1303,91 @@ class TestOrderedProbit(BaseTestDistribution): ] +class TestWishart(BaseTestDistribution): + def wishart_rng_fn(self, size, nu, V, rng): + return st.wishart.rvs(np.int(nu), V, size=size, random_state=rng) + + pymc_dist = pm.Wishart + + V = np.eye(3) + pymc_dist_params = {"nu": 4, "V": V} + reference_dist_params = {"nu": 4, "V": V} + expected_rv_op_params = {"nu": 4, "V": V} + sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)] + sizes_expected = [ + (3, 3), + (3, 3), + (1, 3, 3), + (1, 3, 3), + (5, 3, 3), + (4, 5, 3, 3), + (2, 4, 2, 3, 3), + ] + reference_dist = lambda self: functools.partial( + self.wishart_rng_fn, rng=self.get_random_state() + ) + tests_to_run = [ + "check_rv_size", + "check_pymc_params_match_rv_op", + "check_pymc_draws_match_reference", + ] + + +class TestMatrixNormal(BaseTestDistribution): + + pymc_dist = pm.MatrixNormal + + mu = np.random.random((3, 3)) + row_cov = np.eye(3) + col_cov = np.eye(3) + + pymc_dist_params = {"mu": mu, "rowcov": row_cov, "colcov": col_cov} + expected_rv_op_params = {"mu": mu, "rowcov": row_cov, "colcov": col_cov} + + sizes_to_check = [None, (1), (2, 3)] + sizes_expected = [(3,), (1, 3), (2, 3, 3)] + + tests_to_run = ["check_rv_size", "check_pymc_params_match_rv_op", "test_matrix_normal"] + + def test_matrix_normal(self): + delta = 0.05 # limit for KS p-value + n_fails = 10 # Allows the KS fails a certain number of times + size = (100,) + + def ref_rand(size, mu, rowcov, colcov): + return st.matrix_normal.rvs(mean=mu, rowcov=rowcov, colcov=colcov, size=size) + + with pm.Model(rng_seeder=1): + matrixnormal = pm.MatrixNormal( + "mvnormal", + mu=np.random.random((3, 3)), + rowcov=np.eye(3), + colcov=np.eye(3), + size=size, + ) + check = pm.sample_prior_predictive(n_fails) + + ref_smp = ref_rand(size[0], mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3)) + + p, f = delta, n_fails + while p <= delta and f > 0: + matrixnormal_smp = check["mvnormal"][f - 1, :, :] + curr_ref_smp = ref_smp[f - 1, :, :] + + p = np.min( + [ + st.ks_2samp( + np.atleast_1d(matrixnormal_smp[..., idx]).flatten(), + np.atleast_1d(curr_ref_smp[..., idx]).flatten(), + )[1] + for idx in range(matrixnormal_smp.shape[-1]) + ] + ) + f -= 1 + + assert p > delta + + class TestInterpolated(BaseTestDistribution): def interpolated_rng_fn(self, size, mu, sigma, rng): return st.norm.rvs(loc=mu, scale=sigma, size=size) @@ -1434,70 +1516,6 @@ def ref_rand(size, mu, lam, alpha): ref_rand=ref_rand, ) - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") - def test_matrix_normal(self): - def ref_rand(size, mu, rowcov, colcov): - return st.matrix_normal.rvs(mean=mu, rowcov=rowcov, colcov=colcov, size=size) - - # def ref_rand_tau(size, mu, tau): - # return ref_rand(size, mu, linalg.inv(tau)) - - def ref_rand_chol(size, mu, rowchol, colchol): - return ref_rand( - size, mu, rowcov=np.dot(rowchol, rowchol.T), colcov=np.dot(colchol, colchol.T) - ) - - def ref_rand_chol_transpose(size, mu, rowchol, colchol): - colchol = colchol.T - return ref_rand( - size, mu, rowcov=np.dot(rowchol, rowchol.T), colcov=np.dot(colchol, colchol.T) - ) - - def ref_rand_uchol(size, mu, rowchol, colchol): - return ref_rand( - size, mu, rowcov=np.dot(rowchol.T, rowchol), colcov=np.dot(colchol.T, colchol) - ) - - for n in [2, 3]: - pymc3_random( - pm.MatrixNormal, - {"mu": RealMatrix(n, n), "rowcov": PdMatrix(n), "colcov": PdMatrix(n)}, - size=100, - valuedomain=RealMatrix(n, n), - ref_rand=ref_rand, - ) - # pymc3_random(pm.MatrixNormal, {'mu': RealMatrix(n, n), 'tau': PdMatrix(n)}, - # size=n, valuedomain=RealMatrix(n, n), ref_rand=ref_rand_tau) - pymc3_random( - pm.MatrixNormal, - {"mu": RealMatrix(n, n), "rowchol": PdMatrixChol(n), "colchol": PdMatrixChol(n)}, - size=100, - valuedomain=RealMatrix(n, n), - ref_rand=ref_rand_chol, - ) - # pymc3_random( - # pm.MvNormal, - # {'mu': RealMatrix(n, n), 'rowchol': PdMatrixCholUpper(n), 'colchol': PdMatrixCholUpper(n)}, - # size=n, valuedomain=RealMatrix(n, n), ref_rand=ref_rand_uchol, - # extra_args={'lower': False} - # ) - - # 2 sample test fails because cov becomes different if chol is transposed beforehand. - # This implicity means we need transpose of chol after drawing values in - # MatrixNormal.random method to match stats.matrix_normal.rvs method - with pytest.raises(AssertionError): - pymc3_random( - pm.MatrixNormal, - { - "mu": RealMatrix(n, n), - "rowchol": PdMatrixChol(n), - "colchol": PdMatrixChol(n), - }, - size=100, - valuedomain=RealMatrix(n, n), - ref_rand=ref_rand_chol_transpose, - ) - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") def test_dirichlet_multinomial(self): def ref_rand(size, a, n): @@ -1577,22 +1595,6 @@ def ref_rand(size, mu, sigma): pymc3_random(pm.Moyal, {"mu": R, "sigma": Rplus}, ref_rand=ref_rand) - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") - @pytest.mark.skip( - "Wishart random sampling not implemented.\n" - "See https://github.com/pymc-devs/pymc3/issues/538" - ) - def test_wishart(self): - # Wishart non current recommended for use: - # https://github.com/pymc-devs/pymc3/issues/538 - # for n in [2, 3]: - # pymc3_random_discrete(Wisvaluedomainhart, - # {'n': Domain([2, 3, 4, 2000]) , 'V': PdMatrix(n) }, - # valuedomain=PdMatrix(n), - # ref_rand=lambda n=None, V=None, size=None: \ - # st.wishart(V, df=n, size=size)) - pass - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") def test_lkj(self): for n in [2, 10, 50]: