diff --git a/docs/source/api/distributions/multivariate.rst b/docs/source/api/distributions/multivariate.rst index 85e3bee0267..1156f04a255 100644 --- a/docs/source/api/distributions/multivariate.rst +++ b/docs/source/api/distributions/multivariate.rst @@ -19,3 +19,4 @@ Multivariate MatrixNormal KroneckerNormal CAR + StickBreakingWeights diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index e1e899a3e0e..c1fe94c684a 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -95,6 +95,7 @@ MvNormal, MvStudentT, OrderedMultinomial, + StickBreakingWeights, Wishart, WishartBartlett, ) @@ -159,6 +160,7 @@ "KroneckerNormal", "MvStudentT", "Dirichlet", + "StickBreakingWeights", "Multinomial", "DirichletMultinomial", "OrderedMultinomial", diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f0f031ff028..1cb753eba45 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -44,7 +44,13 @@ from pymc.aesaraf import floatX, intX from pymc.distributions import transforms from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support -from pymc.distributions.dist_math import check_parameters, factln, logpow, multigammaln +from pymc.distributions.dist_math import ( + betaln, + check_parameters, + factln, + logpow, + multigammaln, +) from pymc.distributions.distribution import Continuous, Discrete from pymc.distributions.shape_utils import ( broadcast_dist_samples_to, @@ -67,6 +73,7 @@ "MatrixNormal", "KroneckerNormal", "CAR", + "StickBreakingWeights", ] # Step methods and advi do not catch LinAlgErrors at the @@ -2167,3 +2174,181 @@ def logp(value, mu, W, alpha, tau): tau > 0, msg="-1 <= alpha <= 1, tau > 0", ) + + +class StickBreakingWeightsRV(RandomVariable): + name = "stick_breaking_weights" + ndim_supp = 1 + ndims_params = [0, 0] + dtype = "floatX" + _print_name = ("StickBreakingWeights", "\\operatorname{StickBreakingWeights}") + + def make_node(self, rng, size, dtype, alpha, K): + + alpha = at.as_tensor_variable(alpha) + K = at.as_tensor_variable(intX(K)) + + if alpha.ndim > 0: + raise ValueError("The concentration parameter needs to be a scalar.") + + if K.ndim > 0: + raise ValueError("K must be a scalar.") + + return super().make_node(rng, size, dtype, alpha, K) + + def _infer_shape(self, size, dist_params, param_shapes=None): + alpha, K = dist_params + + size = tuple(size) + + return size + (K + 1,) + + @classmethod + def rng_fn(cls, rng, alpha, K, size): + if K < 0: + raise ValueError("K needs to be positive.") + + if size is None: + size = (K,) + elif isinstance(size, int): + size = (size,) + (K,) + else: + size = tuple(size) + (K,) + + betas = rng.beta(1, alpha, size=size) + + sticks = np.concatenate( + ( + np.ones(shape=(size[:-1] + (1,))), + np.cumprod(1 - betas[..., :-1], axis=-1), + ), + axis=-1, + ) + + weights = sticks * betas + weights = np.concatenate( + (weights, 1 - weights.sum(axis=-1)[..., np.newaxis]), + axis=-1, + ) + + return weights + + +stickbreakingweights = StickBreakingWeightsRV() + + +class StickBreakingWeights(Continuous): + r""" + Likelihood of truncated stick-breaking weights. The weights are generated from a + stick-breaking proceduce where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for + :math:`k \in \{1, \ldots, K\}` and :math:`x_K = \prod_{\ell = 1}^{K} (1 - v_\ell) = 1 - \sum_{\ell=1}^K x_\ell` + with :math:`v_k \stackrel{\text{i.i.d.}}{\sim} \text{Beta}(1, \alpha)`. + + .. math: + + f(\mathbf{x}|\alpha, K) = + B(1, \alpha)^{-K}x_{K+1}^\alpha \prod_{k=1}^{K+1}\left\{\sum_{j=k}^{K+1} x_j\right\}^{-1} + + ======== =============================================== + Support :math:`x_k \in (0, 1)` for :math:`k \in \{1, \ldots, K+1\}` + such that :math:`\sum x_k = 1` + Mean :math:`\mathbb{E}[x_k] = \dfrac{1}{1 + \alpha}\left(\dfrac{\alpha}{1 + \alpha}\right)^{k - 1}` + for :math:`k \in \{1, \ldots, K\}` and :math:`\mathbb{E}[x_{K+1}] = \left(\dfrac{\alpha}{1 + \alpha}\right)^{K}` + ======== =============================================== + + Parameters + ---------- + alpha: float + Concentration parameter (alpha > 0). + K: int + The number of "sticks" to break off from an initial one-unit stick. The length of the weight + vector is K + 1, where the last weight is one minus the sum of all the first sticks. + + References + ---------- + .. [1] Ishwaran, H., & James, L. F. (2001). Gibbs sampling methods for stick-breaking priors. + Journal of the American Statistical Association, 96(453), 161-173. + + .. [2] Müller, P., Quintana, F. A., Jara, A., & Hanson, T. (2015). Bayesian nonparametric data + analysis. New York: Springer. + """ + rv_op = stickbreakingweights + + def __new__(cls, name, *args, **kwargs): + kwargs.setdefault("transform", transforms.simplex) + return super().__new__(cls, name, *args, **kwargs) + + @classmethod + def dist(cls, alpha, K, *args, **kwargs): + alpha = at.as_tensor_variable(floatX(alpha)) + K = at.as_tensor_variable(intX(K)) + + assert_negative_support(alpha, "alpha", "StickBreakingWeights") + assert_negative_support(K, "K", "StickBreakingWeights") + + return super().dist([alpha, K], **kwargs) + + def get_moment(rv, size, alpha, K): + moment = (alpha / (1 + alpha)) ** at.arange(K) + moment *= 1 / (1 + alpha) + moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1) + if not rv_size_is_none(size): + moment_size = at.concatenate( + [ + size, + [ + K + 1, + ], + ] + ) + moment = at.full(moment_size, moment) + + return moment + + def logp(value, alpha, K): + """ + Calculate log-probability of the distribution induced from the stick-breaking process + at specified value. + + Parameters + ---------- + value: numeric + Value for which log-probability is calculated. + + Returns + ------- + TensorVariable + """ + logp = -at.sum( + at.log( + at.cumsum( + value[..., ::-1], + axis=-1, + ) + ), + axis=-1, + ) + logp += -K * betaln(1, alpha) + logp += alpha * at.log(value[..., -1]) + + logp = at.switch( + at.or_( + at.any( + at.and_(at.le(value, 0), at.ge(value, 1)), + axis=-1, + ), + at.or_( + at.bitwise_not(at.allclose(value.sum(-1), 1)), + at.neq(value.shape[-1], K + 1), + ), + ), + -np.inf, + logp, + ) + + return check_parameters( + logp, + alpha > 0, + K > 0, + msg="alpha > 0, K > 0", + ) diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index d2f65390f62..9b377825eda 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -109,6 +109,7 @@ def polyagamma_cdf(*args, **kwargs): Poisson, Rice, SkewNormal, + StickBreakingWeights, StudentT, Triangular, TruncatedNormal, @@ -2123,6 +2124,40 @@ def test_dirichlet_invalid(self): valid_dist = Dirichlet.dist(a=[1, 1, 1]) assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False])) + @pytest.mark.parametrize( + "value,alpha,K,logp", + [ + (np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439), + (np.tile(1, 13) / 13, 2, 12, 13.980045245672827), + (np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723), + (np.append(0.5 ** np.arange(1, 20), 0.5 ** 20), 5, 19, 94.20462772778092), + ( + (np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])), + 2.5, + 3, + np.array([1.29317672, 1.50126157]), + ), + ], + ) + def test_stickbreakingweights_logp(self, value, alpha, K, logp): + with Model() as model: + sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) + pt = {"sbw": value} + assert_almost_equal( + pm.logp(sbw, value).eval(), + logp, + decimal=select_by_precision(float64=6, float32=2), + err_msg=str(pt), + ) + + def test_stickbreakingweights_invalid(self): + sbw = pm.StickBreakingWeights.dist(3.0, 3) + sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7) + assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, 0.15])).eval() == -np.inf + assert pm.logp(sbw, np.array([1.1, 0.3, 0.2, 0.1])).eval() == -np.inf + assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf + assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf + @pytest.mark.parametrize( "a", [ diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 295f1e53dab..6d67e73e2cd 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -52,6 +52,7 @@ Rice, Simulator, SkewNormal, + StickBreakingWeights, StudentT, Triangular, TruncatedNormal, @@ -1087,6 +1088,35 @@ def test_matrixnormal_moment(mu, rowchol, colchol, size, expected): def test_rice_moment(nu, sigma, size, expected): with Model() as model: Rice("x", nu=nu, sigma=sigma, size=size) + + +@pytest.mark.parametrize( + "alpha, K, size, expected", + [ + (3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)), + (5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)), + ( + 1, + 7, + (13,), + np.full( + shape=(13, 8), fill_value=np.append((1 / 2) ** np.arange(7) * 1 / 2, (1 / 2) ** 7) + ), + ), + ( + 0.5, + 5, + (3, 5, 7), + np.full( + shape=(3, 5, 7, 6), + fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5), + ), + ), + ], +) +def test_stickbreakingweights_moment(alpha, K, size, expected): + with Model() as model: + StickBreakingWeights("x", alpha=alpha, K=K, size=size) assert_moment_is_expected(model, expected) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index db4ccae82db..8afa50fd8e2 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1176,6 +1176,41 @@ class TestDirichlet(BaseTestDistribution): ] +class TestStickBreakingWeights(BaseTestDistribution): + pymc_dist = pm.StickBreakingWeights + pymc_dist_params = {"alpha": 2.0, "K": 19} + expected_rv_op_params = {"alpha": 2.0, "K": 19} + sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)] + sizes_expected = [ + (20,), + (17, 20), + ( + 5, + 20, + ), + (11, 5, 20), + (3, 13, 5, 20), + ] + tests_to_run = [ + "check_pymc_params_match_rv_op", + "check_rv_size", + "check_basic_properties", + ] + + def check_basic_properties(self): + default_rng = aesara.shared(np.random.default_rng(1234)) + draws = pm.StickBreakingWeights.dist( + alpha=3.5, + K=19, + size=(2, 3, 5), + rng=default_rng, + ).eval() + + assert np.allclose(draws.sum(-1), 1) + assert np.all(draws >= 0) + assert np.all(draws <= 1) + + class TestMultinomial(BaseTestDistribution): pymc_dist = pm.Multinomial pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}