Skip to content

Commit

Permalink
Add StickBreakingWeights distribution (#5200)
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama authored Jan 18, 2022
1 parent d52655d commit b71bb74
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/api/distributions/multivariate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ Multivariate
MatrixNormal
KroneckerNormal
CAR
StickBreakingWeights
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
MvNormal,
MvStudentT,
OrderedMultinomial,
StickBreakingWeights,
Wishart,
WishartBartlett,
)
Expand Down Expand Up @@ -159,6 +160,7 @@
"KroneckerNormal",
"MvStudentT",
"Dirichlet",
"StickBreakingWeights",
"Multinomial",
"DirichletMultinomial",
"OrderedMultinomial",
Expand Down
187 changes: 186 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@
from pymc.aesaraf import floatX, intX
from pymc.distributions import transforms
from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc.distributions.dist_math import check_parameters, factln, logpow, multigammaln
from pymc.distributions.dist_math import (
betaln,
check_parameters,
factln,
logpow,
multigammaln,
)
from pymc.distributions.distribution import Continuous, Discrete
from pymc.distributions.shape_utils import (
broadcast_dist_samples_to,
Expand All @@ -67,6 +73,7 @@
"MatrixNormal",
"KroneckerNormal",
"CAR",
"StickBreakingWeights",
]

# Step methods and advi do not catch LinAlgErrors at the
Expand Down Expand Up @@ -2167,3 +2174,181 @@ def logp(value, mu, W, alpha, tau):
tau > 0,
msg="-1 <= alpha <= 1, tau > 0",
)


class StickBreakingWeightsRV(RandomVariable):
name = "stick_breaking_weights"
ndim_supp = 1
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("StickBreakingWeights", "\\operatorname{StickBreakingWeights}")

def make_node(self, rng, size, dtype, alpha, K):

alpha = at.as_tensor_variable(alpha)
K = at.as_tensor_variable(intX(K))

if alpha.ndim > 0:
raise ValueError("The concentration parameter needs to be a scalar.")

if K.ndim > 0:
raise ValueError("K must be a scalar.")

return super().make_node(rng, size, dtype, alpha, K)

def _infer_shape(self, size, dist_params, param_shapes=None):
alpha, K = dist_params

size = tuple(size)

return size + (K + 1,)

@classmethod
def rng_fn(cls, rng, alpha, K, size):
if K < 0:
raise ValueError("K needs to be positive.")

if size is None:
size = (K,)
elif isinstance(size, int):
size = (size,) + (K,)
else:
size = tuple(size) + (K,)

betas = rng.beta(1, alpha, size=size)

sticks = np.concatenate(
(
np.ones(shape=(size[:-1] + (1,))),
np.cumprod(1 - betas[..., :-1], axis=-1),
),
axis=-1,
)

weights = sticks * betas
weights = np.concatenate(
(weights, 1 - weights.sum(axis=-1)[..., np.newaxis]),
axis=-1,
)

return weights


stickbreakingweights = StickBreakingWeightsRV()


class StickBreakingWeights(Continuous):
r"""
Likelihood of truncated stick-breaking weights. The weights are generated from a
stick-breaking proceduce where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for
:math:`k \in \{1, \ldots, K\}` and :math:`x_K = \prod_{\ell = 1}^{K} (1 - v_\ell) = 1 - \sum_{\ell=1}^K x_\ell`
with :math:`v_k \stackrel{\text{i.i.d.}}{\sim} \text{Beta}(1, \alpha)`.
.. math:
f(\mathbf{x}|\alpha, K) =
B(1, \alpha)^{-K}x_{K+1}^\alpha \prod_{k=1}^{K+1}\left\{\sum_{j=k}^{K+1} x_j\right\}^{-1}
======== ===============================================
Support :math:`x_k \in (0, 1)` for :math:`k \in \{1, \ldots, K+1\}`
such that :math:`\sum x_k = 1`
Mean :math:`\mathbb{E}[x_k] = \dfrac{1}{1 + \alpha}\left(\dfrac{\alpha}{1 + \alpha}\right)^{k - 1}`
for :math:`k \in \{1, \ldots, K\}` and :math:`\mathbb{E}[x_{K+1}] = \left(\dfrac{\alpha}{1 + \alpha}\right)^{K}`
======== ===============================================
Parameters
----------
alpha: float
Concentration parameter (alpha > 0).
K: int
The number of "sticks" to break off from an initial one-unit stick. The length of the weight
vector is K + 1, where the last weight is one minus the sum of all the first sticks.
References
----------
.. [1] Ishwaran, H., & James, L. F. (2001). Gibbs sampling methods for stick-breaking priors.
Journal of the American Statistical Association, 96(453), 161-173.
.. [2] Müller, P., Quintana, F. A., Jara, A., & Hanson, T. (2015). Bayesian nonparametric data
analysis. New York: Springer.
"""
rv_op = stickbreakingweights

def __new__(cls, name, *args, **kwargs):
kwargs.setdefault("transform", transforms.simplex)
return super().__new__(cls, name, *args, **kwargs)

@classmethod
def dist(cls, alpha, K, *args, **kwargs):
alpha = at.as_tensor_variable(floatX(alpha))
K = at.as_tensor_variable(intX(K))

assert_negative_support(alpha, "alpha", "StickBreakingWeights")
assert_negative_support(K, "K", "StickBreakingWeights")

return super().dist([alpha, K], **kwargs)

def get_moment(rv, size, alpha, K):
moment = (alpha / (1 + alpha)) ** at.arange(K)
moment *= 1 / (1 + alpha)
moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
if not rv_size_is_none(size):
moment_size = at.concatenate(
[
size,
[
K + 1,
],
]
)
moment = at.full(moment_size, moment)

return moment

def logp(value, alpha, K):
"""
Calculate log-probability of the distribution induced from the stick-breaking process
at specified value.
Parameters
----------
value: numeric
Value for which log-probability is calculated.
Returns
-------
TensorVariable
"""
logp = -at.sum(
at.log(
at.cumsum(
value[..., ::-1],
axis=-1,
)
),
axis=-1,
)
logp += -K * betaln(1, alpha)
logp += alpha * at.log(value[..., -1])

logp = at.switch(
at.or_(
at.any(
at.and_(at.le(value, 0), at.ge(value, 1)),
axis=-1,
),
at.or_(
at.bitwise_not(at.allclose(value.sum(-1), 1)),
at.neq(value.shape[-1], K + 1),
),
),
-np.inf,
logp,
)

return check_parameters(
logp,
alpha > 0,
K > 0,
msg="alpha > 0, K > 0",
)
35 changes: 35 additions & 0 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def polyagamma_cdf(*args, **kwargs):
Poisson,
Rice,
SkewNormal,
StickBreakingWeights,
StudentT,
Triangular,
TruncatedNormal,
Expand Down Expand Up @@ -2123,6 +2124,40 @@ def test_dirichlet_invalid(self):
valid_dist = Dirichlet.dist(a=[1, 1, 1])
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))

@pytest.mark.parametrize(
"value,alpha,K,logp",
[
(np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
(np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
(np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
(np.append(0.5 ** np.arange(1, 20), 0.5 ** 20), 5, 19, 94.20462772778092),
(
(np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
2.5,
3,
np.array([1.29317672, 1.50126157]),
),
],
)
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
with Model() as model:
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
pt = {"sbw": value}
assert_almost_equal(
pm.logp(sbw, value).eval(),
logp,
decimal=select_by_precision(float64=6, float32=2),
err_msg=str(pt),
)

def test_stickbreakingweights_invalid(self):
sbw = pm.StickBreakingWeights.dist(3.0, 3)
sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7)
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, 0.15])).eval() == -np.inf
assert pm.logp(sbw, np.array([1.1, 0.3, 0.2, 0.1])).eval() == -np.inf
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf

@pytest.mark.parametrize(
"a",
[
Expand Down
30 changes: 30 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
Rice,
Simulator,
SkewNormal,
StickBreakingWeights,
StudentT,
Triangular,
TruncatedNormal,
Expand Down Expand Up @@ -1087,6 +1088,35 @@ def test_matrixnormal_moment(mu, rowchol, colchol, size, expected):
def test_rice_moment(nu, sigma, size, expected):
with Model() as model:
Rice("x", nu=nu, sigma=sigma, size=size)


@pytest.mark.parametrize(
"alpha, K, size, expected",
[
(3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
(5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),
(
1,
7,
(13,),
np.full(
shape=(13, 8), fill_value=np.append((1 / 2) ** np.arange(7) * 1 / 2, (1 / 2) ** 7)
),
),
(
0.5,
5,
(3, 5, 7),
np.full(
shape=(3, 5, 7, 6),
fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
),
),
],
)
def test_stickbreakingweights_moment(alpha, K, size, expected):
with Model() as model:
StickBreakingWeights("x", alpha=alpha, K=K, size=size)
assert_moment_is_expected(model, expected)


Expand Down
35 changes: 35 additions & 0 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,41 @@ class TestDirichlet(BaseTestDistribution):
]


class TestStickBreakingWeights(BaseTestDistribution):
pymc_dist = pm.StickBreakingWeights
pymc_dist_params = {"alpha": 2.0, "K": 19}
expected_rv_op_params = {"alpha": 2.0, "K": 19}
sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
sizes_expected = [
(20,),
(17, 20),
(
5,
20,
),
(11, 5, 20),
(3, 13, 5, 20),
]
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
"check_basic_properties",
]

def check_basic_properties(self):
default_rng = aesara.shared(np.random.default_rng(1234))
draws = pm.StickBreakingWeights.dist(
alpha=3.5,
K=19,
size=(2, 3, 5),
rng=default_rng,
).eval()

assert np.allclose(draws.sum(-1), 1)
assert np.all(draws >= 0)
assert np.all(draws <= 1)


class TestMultinomial(BaseTestDistribution):
pymc_dist = pm.Multinomial
pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}
Expand Down

0 comments on commit b71bb74

Please sign in to comment.