From b71bb74e12d06daeb6ef0facdfecc369c7ae256c Mon Sep 17 00:00:00 2001
From: larryshamalama <larry.dong@mail.utoronto.ca>
Date: Tue, 18 Jan 2022 14:19:29 -0500
Subject: [PATCH] Add `StickBreakingWeights` distribution (#5200)

---
 .../source/api/distributions/multivariate.rst |   1 +
 pymc/distributions/__init__.py                |   2 +
 pymc/distributions/multivariate.py            | 187 +++++++++++++++++-
 pymc/tests/test_distributions.py              |  35 ++++
 pymc/tests/test_distributions_moments.py      |  30 +++
 pymc/tests/test_distributions_random.py       |  35 ++++
 6 files changed, 289 insertions(+), 1 deletion(-)

diff --git a/docs/source/api/distributions/multivariate.rst b/docs/source/api/distributions/multivariate.rst
index 85e3bee0267..1156f04a255 100644
--- a/docs/source/api/distributions/multivariate.rst
+++ b/docs/source/api/distributions/multivariate.rst
@@ -19,3 +19,4 @@ Multivariate
    MatrixNormal
    KroneckerNormal
    CAR
+   StickBreakingWeights
diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py
index e1e899a3e0e..c1fe94c684a 100644
--- a/pymc/distributions/__init__.py
+++ b/pymc/distributions/__init__.py
@@ -95,6 +95,7 @@
     MvNormal,
     MvStudentT,
     OrderedMultinomial,
+    StickBreakingWeights,
     Wishart,
     WishartBartlett,
 )
@@ -159,6 +160,7 @@
     "KroneckerNormal",
     "MvStudentT",
     "Dirichlet",
+    "StickBreakingWeights",
     "Multinomial",
     "DirichletMultinomial",
     "OrderedMultinomial",
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index f0f031ff028..1cb753eba45 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -44,7 +44,13 @@
 from pymc.aesaraf import floatX, intX
 from pymc.distributions import transforms
 from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
-from pymc.distributions.dist_math import check_parameters, factln, logpow, multigammaln
+from pymc.distributions.dist_math import (
+    betaln,
+    check_parameters,
+    factln,
+    logpow,
+    multigammaln,
+)
 from pymc.distributions.distribution import Continuous, Discrete
 from pymc.distributions.shape_utils import (
     broadcast_dist_samples_to,
@@ -67,6 +73,7 @@
     "MatrixNormal",
     "KroneckerNormal",
     "CAR",
+    "StickBreakingWeights",
 ]
 
 # Step methods and advi do not catch LinAlgErrors at the
@@ -2167,3 +2174,181 @@ def logp(value, mu, W, alpha, tau):
             tau > 0,
             msg="-1 <= alpha <= 1, tau > 0",
         )
+
+
+class StickBreakingWeightsRV(RandomVariable):
+    name = "stick_breaking_weights"
+    ndim_supp = 1
+    ndims_params = [0, 0]
+    dtype = "floatX"
+    _print_name = ("StickBreakingWeights", "\\operatorname{StickBreakingWeights}")
+
+    def make_node(self, rng, size, dtype, alpha, K):
+        alpha = at.as_tensor_variable(alpha)
+        K = at.as_tensor_variable(intX(K))
+
+        if alpha.ndim > 0:
+            raise ValueError("The concentration parameter needs to be a scalar.")
+
+        if K.ndim > 0:
+            raise ValueError("K must be a scalar.")
+
+        return super().make_node(rng, size, dtype, alpha, K)
+
+    def _infer_shape(self, size, dist_params, param_shapes=None):
+        alpha, K = dist_params
+
+        size = tuple(size)
+
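+        # Each draw has K + 1 entries: the K broken-off weights plus the leftover piece.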
+        return size + (K + 1,)
+
+    @classmethod
+    def rng_fn(cls, rng, alpha, K, size):
+        if K <= 0:
+            raise ValueError("K needs to be positive.")
+
+        if size is None:
+            size = (K,)
+        elif isinstance(size, int):
+            size = (size,) + (K,)
+        else:
+            size = tuple(size) + (K,)
+
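+        # Stick-breaking fractions v_k ~ Beta(1, alpha), one per break.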
+        betas = rng.beta(1, alpha, size=size)
+
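+        # Length of stick remaining before each break: 1, (1 - v_1), (1 - v_1)(1 - v_2), ...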
+        sticks = np.concatenate(
+            (
+                np.ones(shape=(size[:-1] + (1,))),
+                np.cumprod(1 - betas[..., :-1], axis=-1),
+            ),
+            axis=-1,
+        )
+
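+        # Weight k is the fraction v_k of the remaining stick; the leftover piece is
+        # appended so that the K + 1 weights sum to one.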
+        weights = sticks * betas
+        weights = np.concatenate(
+            (weights, 1 - weights.sum(axis=-1)[..., np.newaxis]),
+            axis=-1,
+        )
+
+        return weights
+
+
+stickbreakingweights = StickBreakingWeightsRV()
+
+
+class StickBreakingWeights(Continuous):
+    r"""
+    Likelihood of truncated stick-breaking weights. The weights are generated from a
+    stick-breaking procedure where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for
+    :math:`k \in \{1, \ldots, K\}` and :math:`x_{K+1} = \prod_{\ell = 1}^{K} (1 - v_\ell) = 1 - \sum_{\ell=1}^{K} x_\ell`
+    with :math:`v_k \stackrel{\text{i.i.d.}}{\sim} \text{Beta}(1, \alpha)`.
+
+    .. math::
+
+        f(\mathbf{x}|\alpha, K) =
+            B(1, \alpha)^{-K}x_{K+1}^\alpha \prod_{k=1}^{K+1}\left\{\sum_{j=k}^{K+1} x_j\right\}^{-1}
+
+    ========  ===============================================
+    Support   :math:`x_k \in (0, 1)` for :math:`k \in \{1, \ldots, K+1\}`
+              such that :math:`\sum x_k = 1`
+    Mean      :math:`\mathbb{E}[x_k] = \dfrac{1}{1 + \alpha}\left(\dfrac{\alpha}{1 + \alpha}\right)^{k - 1}`
+              for :math:`k \in \{1, \ldots, K\}` and :math:`\mathbb{E}[x_{K+1}] = \left(\dfrac{\alpha}{1 + \alpha}\right)^{K}`
+    ========  ===============================================
+
+    Parameters
+    ----------
+    alpha: float
+        Concentration parameter (alpha > 0).
+    K: int
+        The number of "sticks" to break off from an initial one-unit stick. The length of the weight
+        vector is K + 1, where the last weight is one minus the sum of all the first sticks.
+
+    References
+    ----------
+    .. [1] Ishwaran, H., & James, L. F. (2001). Gibbs sampling methods for stick-breaking priors.
+           Journal of the American Statistical Association, 96(453), 161-173.
+
+    .. [2] Müller, P., Quintana, F. A., Jara, A., & Hanson, T. (2015). Bayesian nonparametric data
+           analysis. New York: Springer.
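+
+    Examples
+    --------
+    A minimal usage sketch, assuming a truncation level of ``K = 19`` (a vector of
+    20 weights):
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        with pm.Model():
+            # 20 weights summing to one; smaller alpha puts more mass on the
+            # first few weights.
+            w = pm.StickBreakingWeights("w", alpha=2.0, K=19)
+
+        # A single prior draw can also be taken outside of a model:
+        # pm.StickBreakingWeights.dist(alpha=2.0, K=19).eval()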
+    """
+    rv_op = stickbreakingweights
+
+    def __new__(cls, name, *args, **kwargs):
+        kwargs.setdefault("transform", transforms.simplex)
+        return super().__new__(cls, name, *args, **kwargs)
+
+    @classmethod
+    def dist(cls, alpha, K, *args, **kwargs):
+        alpha = at.as_tensor_variable(floatX(alpha))
+        K = at.as_tensor_variable(intX(K))
+
+        assert_negative_support(alpha, "alpha", "StickBreakingWeights")
+        assert_negative_support(K, "K", "StickBreakingWeights")
+
+        return super().dist([alpha, K], **kwargs)
+
+    def get_moment(rv, size, alpha, K):
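+        # E[x_k] = alpha^(k - 1) / (1 + alpha)^k for k = 1, ..., K, and
+        # E[x_{K+1}] = (alpha / (1 + alpha))^K, matching the Mean entry in the class docstring.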
+        moment = (alpha / (1 + alpha)) ** at.arange(K)
+        moment *= 1 / (1 + alpha)
+        moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
+        if not rv_size_is_none(size):
+            moment_size = at.concatenate(
+                [
+                    size,
+                    [
+                        K + 1,
+                    ],
+                ]
+            )
+            moment = at.full(moment_size, moment)
+
+        return moment
+
+    def logp(value, alpha, K):
+        """
+        Calculate the log-probability of the truncated stick-breaking weights distribution
+        at the specified value.
+
+        Parameters
+        ----------
+        value: numeric
+            Value for which log-probability is calculated.
+
+        Returns
+        -------
+        TensorVariable
+        """
+        logp = -at.sum(
+            at.log(
+                at.cumsum(
+                    value[..., ::-1],
+                    axis=-1,
+                )
+            ),
+            axis=-1,
+        )
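+        # Normalising constant -K * log B(1, alpha) and the x_{K+1}^alpha factor of the density.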
+        logp += -K * betaln(1, alpha)
+        logp += alpha * at.log(value[..., -1])
+
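+        # Values with entries outside (0, 1), not summing to one, or of the wrong
+        # length (!= K + 1) get zero probability.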
+        logp = at.switch(
+            at.or_(
+                at.any(
+                    at.or_(at.le(value, 0), at.ge(value, 1)),
+                    axis=-1,
+                ),
+                at.or_(
+                    at.bitwise_not(at.allclose(value.sum(-1), 1)),
+                    at.neq(value.shape[-1], K + 1),
+                ),
+            ),
+            -np.inf,
+            logp,
+        )
+
+        return check_parameters(
+            logp,
+            alpha > 0,
+            K > 0,
+            msg="alpha > 0, K > 0",
+        )
diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index d2f65390f62..9b377825eda 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -109,6 +109,7 @@ def polyagamma_cdf(*args, **kwargs):
     Poisson,
     Rice,
     SkewNormal,
+    StickBreakingWeights,
     StudentT,
     Triangular,
     TruncatedNormal,
@@ -2123,6 +2124,40 @@ def test_dirichlet_invalid(self):
         valid_dist = Dirichlet.dist(a=[1, 1, 1])
         assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
 
+    @pytest.mark.parametrize(
+        "value,alpha,K,logp",
+        [
+            (np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
+            (np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
+            (np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
+            (np.append(0.5 ** np.arange(1, 20), 0.5 ** 20), 5, 19, 94.20462772778092),
+            (
+                (np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
+                2.5,
+                3,
+                np.array([1.29317672, 1.50126157]),
+            ),
+        ],
+    )
+    def test_stickbreakingweights_logp(self, value, alpha, K, logp):
+        with Model() as model:
+            sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
+        pt = {"sbw": value}
+        assert_almost_equal(
+            pm.logp(sbw, value).eval(),
+            logp,
+            decimal=select_by_precision(float64=6, float32=2),
+            err_msg=str(pt),
+        )
+
+    def test_stickbreakingweights_invalid(self):
+        sbw = pm.StickBreakingWeights.dist(3.0, 3)
+        sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7)
+        assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, 0.15])).eval() == -np.inf
+        assert pm.logp(sbw, np.array([1.1, 0.3, 0.2, 0.1])).eval() == -np.inf
+        assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
+        assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf
+
     @pytest.mark.parametrize(
         "a",
         [
diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py
index 295f1e53dab..6d67e73e2cd 100644
--- a/pymc/tests/test_distributions_moments.py
+++ b/pymc/tests/test_distributions_moments.py
@@ -52,6 +52,7 @@
     Rice,
     Simulator,
     SkewNormal,
+    StickBreakingWeights,
     StudentT,
     Triangular,
     TruncatedNormal,
@@ -1087,6 +1088,35 @@ def test_matrixnormal_moment(mu, rowchol, colchol, size, expected):
 def test_rice_moment(nu, sigma, size, expected):
     with Model() as model:
         Rice("x", nu=nu, sigma=sigma, size=size)
+
+
+@pytest.mark.parametrize(
+    "alpha, K, size, expected",
+    [
+        (3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
+        (5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),
+        (
+            1,
+            7,
+            (13,),
+            np.full(
+                shape=(13, 8), fill_value=np.append((1 / 2) ** np.arange(7) * 1 / 2, (1 / 2) ** 7)
+            ),
+        ),
+        (
+            0.5,
+            5,
+            (3, 5, 7),
+            np.full(
+                shape=(3, 5, 7, 6),
+                fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
+            ),
+        ),
+    ],
+)
+def test_stickbreakingweights_moment(alpha, K, size, expected):
+    with Model() as model:
+        StickBreakingWeights("x", alpha=alpha, K=K, size=size)
+    assert_moment_is_expected(model, expected)
 
 
diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py
index db4ccae82db..8afa50fd8e2 100644
--- a/pymc/tests/test_distributions_random.py
+++ b/pymc/tests/test_distributions_random.py
@@ -1176,6 +1176,41 @@ class TestDirichlet(BaseTestDistribution):
     ]
 
 
+class TestStickBreakingWeights(BaseTestDistribution):
+    pymc_dist = pm.StickBreakingWeights
+    pymc_dist_params = {"alpha": 2.0, "K": 19}
+    expected_rv_op_params = {"alpha": 2.0, "K": 19}
+    sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
+    sizes_expected = [
+        (20,),
+        (17, 20),
+        (5, 20),
+        (11, 5, 20),
+        (3, 13, 5, 20),
+    ]
+    tests_to_run = [
+        "check_pymc_params_match_rv_op",
+        "check_rv_size",
+        "check_basic_properties",
+    ]
+
+    def check_basic_properties(self):
+        default_rng = aesara.shared(np.random.default_rng(1234))
+        draws = pm.StickBreakingWeights.dist(
+            alpha=3.5,
+            K=19,
+            size=(2, 3, 5),
+            rng=default_rng,
+        ).eval()
+
+        assert np.allclose(draws.sum(-1), 1)
+        assert np.all(draws >= 0)
+        assert np.all(draws <= 1)
+
+
 class TestMultinomial(BaseTestDistribution):
     pymc_dist = pm.Multinomial
     pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}