|
17 | 17 |
|
18 | 18 | import warnings |
19 | 19 |
|
| 20 | +from functools import reduce |
| 21 | + |
20 | 22 | import aesara |
21 | 23 | import aesara.tensor as at |
22 | 24 | import numpy as np |
|
45 | 47 | from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support |
46 | 48 | from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln |
47 | 49 | from pymc3.distributions.distribution import Continuous, Discrete |
48 | | -from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker |
| 50 | +from pymc3.math import kron_diag, kron_dot |
49 | 51 |
|
50 | 52 | __all__ = [ |
51 | 53 | "MvNormal", |
@@ -1702,6 +1704,32 @@ def _distr_parameters_for_repr(self): |
1702 | 1704 | return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]] |
1703 | 1705 |
|
1704 | 1706 |
|
| 1707 | +class KroneckerNormalRV(RandomVariable): |
| 1708 | + name = "kroneckernormal" |
| 1709 | + ndim_supp = 2 |
| 1710 | + ndims_params = [1, 0, 2] |
| 1711 | + dtype = "floatX" |
| 1712 | + _print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}") |
| 1713 | + |
| 1714 | + def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None): |
| 1715 | + return default_shape_from_params(1, dist_params, rep_param_idx, param_shapes) |
| 1716 | + |
| 1717 | + def rng_fn(self, rng, mu, sigma, *covs, size=None): |
| 1718 | + size = size if size else covs[-1] |
| 1719 | + covs = covs[:-1] if covs[-1] == size else covs |
| 1720 | + |
| 1721 | + cov = reduce(linalg.kron, covs) |
| 1722 | + |
| 1723 | + if sigma: |
| 1724 | + cov = cov + sigma ** 2 * np.eye(cov.shape[0]) |
| 1725 | + |
| 1726 | + x = multivariate_normal.rng_fn(rng=rng, mean=mu, cov=cov, size=size) |
| 1727 | + return x |
| 1728 | + |
| 1729 | + |
| 1730 | +kroneckernormal = KroneckerNormalRV() |
| 1731 | + |
| 1732 | + |
1705 | 1733 | class KroneckerNormal(Continuous): |
1706 | 1734 | r""" |
1707 | 1735 | Multivariate normal log-likelihood with Kronecker-structured covariance. |
@@ -1790,160 +1818,79 @@ class KroneckerNormal(Continuous): |
1790 | 1818 | ---------- |
1791 | 1819 | .. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models" |
1792 | 1820 | """ |
| 1821 | + rv_op = kroneckernormal |
1793 | 1822 |
|
1794 | | - def __init__(self, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs): |
1795 | | - self._setup(covs, chols, evds, sigma) |
1796 | | - super().__init__(*args, **kwargs) |
1797 | | - self.mu = at.as_tensor_variable(mu) |
1798 | | - self.mean = self.median = self.mode = self.mu |
| 1823 | + @classmethod |
| 1824 | + def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs): |
1799 | 1825 |
|
1800 | | - def _setup(self, covs, chols, evds, sigma): |
1801 | | - self.cholesky = Cholesky(lower=True, on_error="raise") |
1802 | 1826 | if len([i for i in [covs, chols, evds] if i is not None]) != 1: |
1803 | 1827 | raise ValueError( |
1804 | 1828 | "Incompatible parameterization. Specify exactly one of covs, chols, or evds." |
1805 | 1829 | ) |
1806 | | - self._isEVD = False |
1807 | | - self.sigma = sigma |
1808 | | - self.is_noisy = self.sigma is not None and self.sigma != 0 |
1809 | | - if covs is not None: |
1810 | | - self._cov_type = "cov" |
1811 | | - self.covs = covs |
1812 | | - if self.is_noisy: |
1813 | | - # Noise requires eigendecomposition |
1814 | | - eigh_map = map(eigh, covs) |
1815 | | - self._setup_evd(eigh_map) |
1816 | | - else: |
1817 | | - # Otherwise use cholesky as usual |
1818 | | - self.chols = list(map(self.cholesky, self.covs)) |
1819 | | - self.chol_diags = list(map(at.diag, self.chols)) |
1820 | | - self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols]) |
1821 | | - self.N = at.prod(self.sizes) |
1822 | | - elif chols is not None: |
1823 | | - self._cov_type = "chol" |
1824 | | - if self.is_noisy: # A strange case... |
1825 | | - # Noise requires eigendecomposition |
1826 | | - covs = [at.dot(chol, chol.T) for chol in chols] |
1827 | | - eigh_map = map(eigh, covs) |
1828 | | - self._setup_evd(eigh_map) |
1829 | | - else: |
1830 | | - self.chols = chols |
1831 | | - self.chol_diags = list(map(at.diag, self.chols)) |
1832 | | - self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols]) |
1833 | | - self.N = at.prod(self.sizes) |
1834 | | - else: |
1835 | | - self._cov_type = "evd" |
1836 | | - self._setup_evd(evds) |
1837 | 1830 |
|
1838 | | - def _setup_evd(self, eigh_iterable): |
1839 | | - self._isEVD = True |
1840 | | - eigs_sep, Qs = zip(*eigh_iterable) # Unzip |
1841 | | - self.Qs = list(map(at.as_tensor_variable, Qs)) |
1842 | | - self.QTs = list(map(at.transpose, self.Qs)) |
1843 | | - |
1844 | | - self.eigs_sep = list(map(at.as_tensor_variable, eigs_sep)) |
1845 | | - self.eigs = kron_diag(*self.eigs_sep) # Combine separate eigs |
1846 | | - if self.is_noisy: |
1847 | | - self.eigs += self.sigma ** 2 |
1848 | | - self.N = self.eigs.shape[0] |
1849 | | - |
1850 | | - def _setup_random(self): |
1851 | | - if not hasattr(self, "mv_params"): |
1852 | | - self.mv_params = {"mu": self.mu} |
1853 | | - if self._cov_type == "cov": |
1854 | | - cov = kronecker(*self.covs) |
1855 | | - if self.is_noisy: |
1856 | | - cov = cov + self.sigma ** 2 * at.identity_like(cov) |
1857 | | - self.mv_params["cov"] = cov |
1858 | | - elif self._cov_type == "chol": |
1859 | | - if self.is_noisy: |
1860 | | - covs = [] |
1861 | | - for eig, Q in zip(self.eigs_sep, self.Qs): |
1862 | | - cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T)) |
1863 | | - covs.append(cov_i) |
1864 | | - cov = kronecker(*covs) |
1865 | | - if self.is_noisy: |
1866 | | - cov = cov + self.sigma ** 2 * at.identity_like(cov) |
1867 | | - self.mv_params["chol"] = self.cholesky(cov) |
1868 | | - else: |
1869 | | - self.mv_params["chol"] = kronecker(*self.chols) |
1870 | | - elif self._cov_type == "evd": |
1871 | | - covs = [] |
1872 | | - for eig, Q in zip(self.eigs_sep, self.Qs): |
1873 | | - cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T)) |
1874 | | - covs.append(cov_i) |
1875 | | - cov = kronecker(*covs) |
1876 | | - if self.is_noisy: |
1877 | | - cov = cov + self.sigma ** 2 * at.identity_like(cov) |
1878 | | - self.mv_params["cov"] = cov |
| 1831 | + sigma = sigma if sigma else 0 |
1879 | 1832 |
|
1880 | | - def random(self, point=None, size=None): |
| 1833 | + if chols is not None: |
| 1834 | + covs = [chol.dot(chol.T) for chol in chols] |
| 1835 | + elif evds is not None: |
| 1836 | + eigh_iterable = evds |
| 1837 | + covs = [] |
| 1838 | + eigs_sep, Qs = zip(*eigh_iterable) # Unzip |
| 1839 | + for eig, Q in zip(eigs_sep, Qs): |
| 1840 | + cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T)) |
| 1841 | + covs.append(cov_i) |
| 1842 | + |
| 1843 | + mu = at.as_tensor_variable(mu) |
| 1844 | + |
| 1845 | + # mean = median = mode = mu |
| 1846 | + return super().dist([mu, sigma, *covs], **kwargs) |
| 1847 | + |
| 1848 | + def logp(value, mu, sigma, *covs): |
1881 | 1849 | """ |
1882 | | - Draw random values from Multivariate Normal distribution |
1883 | | - with Kronecker-structured covariance. |
| 1850 | + Calculate log-probability of Multivariate Normal distribution |
| 1851 | + with Kronecker-structured covariance at specified value. |
1884 | 1852 |
|
1885 | 1853 | Parameters |
1886 | 1854 | ---------- |
1887 | | - point: dict, optional |
1888 | | - Dict of variable values on which random values are to be |
1889 | | - conditioned (uses default point if not specified). |
1890 | | - size: int, optional |
1891 | | - Desired size of random sample (returns one sample if not |
1892 | | - specified). |
| 1855 | + value: numeric |
| 1856 | + Value for which log-probability is calculated. |
1893 | 1857 |
|
1894 | 1858 | Returns |
1895 | 1859 | ------- |
1896 | | - array |
| 1860 | + TensorVariable |
1897 | 1861 | """ |
1898 | | - # Expand params into terms MvNormal can understand to force consistency |
1899 | | - self._setup_random() |
1900 | | - self.mv_params["shape"] = self.shape |
1901 | | - dist = MvNormal.dist(**self.mv_params) |
1902 | | - return dist.random(point, size) |
1903 | | - |
1904 | | - def _quaddist(self, value): |
1905 | | - """Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K))""" |
| 1862 | + # Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K)) |
1906 | 1863 | if value.ndim > 2 or value.ndim == 0: |
1907 | | - raise ValueError("Invalid dimension for value: %s" % value.ndim) |
| 1864 | + raise ValueError(f"Invalid dimension for value: {value.ndim}") |
1908 | 1865 | if value.ndim == 1: |
1909 | 1866 | onedim = True |
1910 | 1867 | value = value[None, :] |
1911 | 1868 | else: |
1912 | 1869 | onedim = False |
1913 | 1870 |
|
1914 | | - delta = value - self.mu |
1915 | | - if self._isEVD: |
1916 | | - sqrt_quad = kron_dot(self.QTs, delta.T) |
1917 | | - sqrt_quad = sqrt_quad / at.sqrt(self.eigs[:, None]) |
1918 | | - logdet = at.sum(at.log(self.eigs)) |
1919 | | - else: |
1920 | | - sqrt_quad = kron_solve_lower(self.chols, delta.T) |
1921 | | - logdet = 0 |
1922 | | - for chol_size, chol_diag in zip(self.sizes, self.chol_diags): |
1923 | | - logchol = at.log(chol_diag) * self.N / chol_size |
1924 | | - logdet += at.sum(2 * logchol) |
| 1871 | + delta = value - mu |
| 1872 | + |
| 1873 | + eigh_iterable = map(eigh, covs) |
| 1874 | + eigs_sep, Qs = zip(*eigh_iterable) # Unzip |
| 1875 | + Qs = list(map(at.as_tensor_variable, Qs)) |
| 1876 | + QTs = list(map(at.transpose, Qs)) |
| 1877 | + |
| 1878 | + eigs_sep = list(map(at.as_tensor_variable, eigs_sep)) |
| 1879 | + eigs = kron_diag(*eigs_sep) # Combine separate eigs |
| 1880 | + eigs += sigma ** 2 |
| 1881 | + N = eigs.shape[0] |
| 1882 | + |
| 1883 | + sqrt_quad = kron_dot(QTs, delta.T) |
| 1884 | + sqrt_quad = sqrt_quad / at.sqrt(eigs[:, None]) |
| 1885 | + logdet = at.sum(at.log(eigs)) |
| 1886 | + |
1925 | 1887 | # Square each sample |
1926 | 1888 | quad = at.batched_dot(sqrt_quad.T, sqrt_quad.T) |
1927 | 1889 | if onedim: |
1928 | 1890 | quad = quad[0] |
1929 | | - return quad, logdet |
1930 | 1891 |
|
1931 | | - def logp(self, value): |
1932 | | - """ |
1933 | | - Calculate log-probability of Multivariate Normal distribution |
1934 | | - with Kronecker-structured covariance at specified value. |
1935 | | -
|
1936 | | - Parameters |
1937 | | - ---------- |
1938 | | - value: numeric |
1939 | | - Value for which log-probability is calculated. |
1940 | | -
|
1941 | | - Returns |
1942 | | - ------- |
1943 | | - TensorVariable |
1944 | | - """ |
1945 | | - quad, logdet = self._quaddist(value) |
1946 | | - return -(quad + logdet + self.N * at.log(2 * np.pi)) / 2.0 |
| 1892 | + a = -(quad + logdet + N * at.log(2 * np.pi)) / 2.0 |
| 1893 | + return a |
1947 | 1894 |
|
1948 | 1895 | def _distr_parameters_for_repr(self): |
1949 | 1896 | return ["mu"] |
|
0 commit comments