Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding StickBreakingWeights distribution class #5200

Merged
merged 18 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/distributions/multivariate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ Multivariate
MatrixNormal
KroneckerNormal
CAR
StickBreakingWeights
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
MvNormal,
MvStudentT,
OrderedMultinomial,
StickBreakingWeights,
Wishart,
WishartBartlett,
)
Expand Down Expand Up @@ -159,6 +160,7 @@
"KroneckerNormal",
"MvStudentT",
"Dirichlet",
"StickBreakingWeights",
"Multinomial",
"DirichletMultinomial",
"OrderedMultinomial",
Expand Down
187 changes: 186 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@
from pymc.aesaraf import floatX, intX
from pymc.distributions import transforms
from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc.distributions.dist_math import check_parameters, factln, logpow, multigammaln
from pymc.distributions.dist_math import (
betaln,
check_parameters,
factln,
logpow,
multigammaln,
)
from pymc.distributions.distribution import Continuous, Discrete
from pymc.distributions.shape_utils import (
broadcast_dist_samples_to,
Expand All @@ -67,6 +73,7 @@
"MatrixNormal",
"KroneckerNormal",
"CAR",
"StickBreakingWeights",
]

# Step methods and advi do not catch LinAlgErrors at the
Expand Down Expand Up @@ -2167,3 +2174,181 @@ def logp(value, mu, W, alpha, tau):
tau > 0,
msg="-1 <= alpha <= 1, tau > 0",
)


class StickBreakingWeightsRV(RandomVariable):
    """Random variable op that draws truncated stick-breaking weight vectors.

    Both parameters (``alpha`` and ``K``) are scalars and every draw is a
    vector with ``K + 1`` entries.
    """

    name = "stick_breaking_weights"
    ndim_supp = 1
    ndims_params = [0, 0]
    dtype = "floatX"
    _print_name = ("StickBreakingWeights", "\\operatorname{StickBreakingWeights}")

    def make_node(self, rng, size, dtype, alpha, K):
        """Convert ``alpha`` and ``K`` to tensors and reject non-scalar inputs.

        Raises
        ------
        ValueError
            If ``alpha`` or ``K`` has one or more dimensions.
        """
        alpha = at.as_tensor_variable(alpha)
        K = at.as_tensor_variable(intX(K))

        # Checked in this order so that a non-scalar alpha is reported first.
        scalar_checks = (
            (alpha, "The concentration parameter needs to be a scalar."),
            (K, "K must be a scalar."),
        )
        for param, message in scalar_checks:
            if param.ndim > 0:
                raise ValueError(message)

        return super().make_node(rng, size, dtype, alpha, K)

    def _infer_shape(self, size, dist_params, param_shapes=None):
        """Return the draw shape: ``size`` followed by a trailing ``K + 1``."""
        _, n_breaks = dist_params
        return (*tuple(size), n_breaks + 1)

    @classmethod
    def rng_fn(cls, rng, alpha, K, size):
        """Draw weight vectors by breaking ``K`` pieces off a unit stick.

        Parameters
        ----------
        rng
            Generator providing ``beta(a, b, size=...)``.
        alpha
            Second shape parameter of the Beta(1, alpha) break fractions.
        K
            Number of breaks; each returned vector has ``K + 1`` entries.
        size
            ``None``, an int, or an iterable giving the batch shape.

        Raises
        ------
        ValueError
            If ``K`` is negative.
        """
        if K < 0:
            raise ValueError("K needs to be positive.")

        # Normalise ``size`` into a tuple of batch dimensions.
        if size is None:
            batch_shape = ()
        elif isinstance(size, int):
            batch_shape = (size,)
        else:
            batch_shape = tuple(size)
        draw_shape = batch_shape + (K,)

        betas = rng.beta(1, alpha, size=draw_shape)

        # Stick length remaining before each break: 1 for the first break,
        # then the running product of (1 - beta) over the previous breaks.
        leading_one = np.ones(shape=(batch_shape + (1,)))
        remaining = np.cumprod(1 - betas[..., :-1], axis=-1)
        sticks = np.concatenate((leading_one, remaining), axis=-1)

        broken_off = sticks * betas
        # The final entry is whatever is left, so each vector sums to one.
        leftover = 1 - broken_off.sum(axis=-1)[..., np.newaxis]

        return np.concatenate((broken_off, leftover), axis=-1)


# Module-level instance of the op; StickBreakingWeights references it as ``rv_op``.
stickbreakingweights = StickBreakingWeightsRV()


class StickBreakingWeights(Continuous):
    r"""
    Likelihood of truncated stick-breaking weights. The weights are generated from a
    stick-breaking procedure where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for
    :math:`k \in \{1, \ldots, K\}` and
    :math:`x_{K+1} = \prod_{\ell = 1}^{K} (1 - v_\ell) = 1 - \sum_{\ell=1}^K x_\ell`
    with :math:`v_k \stackrel{\text{i.i.d.}}{\sim} \text{Beta}(1, \alpha)`.

    .. math::

        f(\mathbf{x}|\alpha, K) =
            B(1, \alpha)^{-K}x_{K+1}^\alpha \prod_{k=1}^{K+1}\left\{\sum_{j=k}^{K+1} x_j\right\}^{-1}

    ========  ===============================================
    Support   :math:`x_k \in (0, 1)` for :math:`k \in \{1, \ldots, K+1\}`
              such that :math:`\sum x_k = 1`
    Mean      :math:`\mathbb{E}[x_k] = \dfrac{1}{1 + \alpha}\left(\dfrac{\alpha}{1 + \alpha}\right)^{k - 1}`
              for :math:`k \in \{1, \ldots, K\}` and :math:`\mathbb{E}[x_{K+1}] = \left(\dfrac{\alpha}{1 + \alpha}\right)^{K}`
    ========  ===============================================

    Parameters
    ----------
    alpha: float
        Concentration parameter (alpha > 0).
    K: int
        The number of "sticks" to break off from an initial one-unit stick. The length of the weight
        vector is K + 1, where the last weight is one minus the sum of all the first sticks.

    References
    ----------
    .. [1] Ishwaran, H., & James, L. F. (2001). Gibbs sampling methods for stick-breaking priors.
           Journal of the American Statistical Association, 96(453), 161-173.

    .. [2] Müller, P., Quintana, F. A., Jara, A., & Hanson, T. (2015). Bayesian nonparametric data
           analysis. New York: Springer.
    """
    rv_op = stickbreakingweights

    def __new__(cls, name, *args, **kwargs):
        # Use the simplex transform unless the caller passed ``transform`` explicitly
        # (including ``transform=None``).
        kwargs.setdefault("transform", transforms.simplex)
        return super().__new__(cls, name, *args, **kwargs)

    @classmethod
    def dist(cls, alpha, K, *args, **kwargs):
        """Create the distribution variable from ``alpha`` (cast with ``floatX``)
        and ``K`` (cast with ``intX``)."""
        alpha = at.as_tensor_variable(floatX(alpha))
        K = at.as_tensor_variable(intX(K))

        assert_negative_support(alpha, "alpha", "StickBreakingWeights")
        assert_negative_support(K, "K", "StickBreakingWeights")

        return super().dist([alpha, K], **kwargs)

    def get_moment(rv, size, alpha, K):
        """Return the mean weight vector, filled out to ``size + (K + 1,)`` when a
        size is given."""
        # First K entries: 1 / (1 + alpha) * (alpha / (1 + alpha)) ** (k - 1).
        moment = (alpha / (1 + alpha)) ** at.arange(K)
        moment *= 1 / (1 + alpha)
        # Last entry: the expected leftover stick, (alpha / (1 + alpha)) ** K.
        moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
        if not rv_size_is_none(size):
            moment_size = at.concatenate(
                [
                    size,
                    [
                        K + 1,
                    ],
                ]
            )
            moment = at.full(moment_size, moment)

        return moment

    def logp(value, alpha, K):
        """
        Calculate log-probability of the distribution induced from the stick-breaking process
        at specified value.

        Parameters
        ----------
        value: numeric
            Value for which log-probability is calculated. The last axis holds the
            weight vector.
        alpha: numeric
            Concentration parameter.
        K: numeric
            Number of sticks broken off; the last axis of ``value`` must have
            length ``K + 1``.

        Returns
        -------
        TensorVariable
            ``-inf`` where ``value`` is outside the support; otherwise the log-density.
        """
        # Cumulative sums of the reversed vector are the tail sums sum_{j >= k} x_j.
        logp = -at.sum(
            at.log(
                at.cumsum(
                    value[..., ::-1],
                    axis=-1,
                )
            ),
            axis=-1,
        )
        logp += -K * betaln(1, alpha)
        logp += alpha * at.log(value[..., -1])

        # An entry is out of support when it is <= 0 OR >= 1. The previous ``and_``
        # could never be true, so a vector with negative entries that still summed
        # to one produced NaN instead of -inf.
        entry_out_of_support = at.any(
            at.or_(at.le(value, 0), at.ge(value, 1)),
            axis=-1,
        )
        wrong_sum_or_length = at.or_(
            at.bitwise_not(at.allclose(value.sum(-1), 1)),
            at.neq(value.shape[-1], K + 1),
        )

        logp = at.switch(
            at.or_(entry_out_of_support, wrong_sum_or_length),
            -np.inf,
            logp,
        )

        return check_parameters(
            logp,
            alpha > 0,
            K > 0,
            msg="alpha > 0, K > 0",
        )
35 changes: 35 additions & 0 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def polyagamma_cdf(*args, **kwargs):
Poisson,
Rice,
SkewNormal,
StickBreakingWeights,
StudentT,
Triangular,
TruncatedNormal,
Expand Down Expand Up @@ -2123,6 +2124,40 @@ def test_dirichlet_invalid(self):
valid_dist = Dirichlet.dist(a=[1, 1, 1])
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))

@pytest.mark.parametrize(
    "value,alpha,K,logp",
    [
        (np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
        (np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
        (np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
        (np.append(0.5 ** np.arange(1, 20), 0.5 ** 20), 5, 19, 94.20462772778092),
        (
            (np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
            2.5,
            3,
            np.array([1.29317672, 1.50126157]),
        ),
    ],
)
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
    """Compare the untransformed logp against the parametrized reference values."""
    with Model():
        sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)

    point = {"sbw": value}
    computed = pm.logp(sbw, value).eval()
    # Looser tolerance when running in float32.
    n_decimals = select_by_precision(float64=6, float32=2)

    assert_almost_equal(
        computed,
        logp,
        decimal=n_decimals,
        err_msg=str(point),
    )
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved

def test_stickbreakingweights_invalid(self):
    """Values outside the support, or of the wrong length, must get logp == -inf."""
    sbw = pm.StickBreakingWeights.dist(3.0, 3)
    sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7)

    invalid_cases = (
        (sbw, [0.4, 0.3, 0.2, 0.15]),  # does not sum to one
        (sbw, [1.1, 0.3, 0.2, 0.1]),  # entry above one
        (sbw, [0.4, 0.3, 0.2, -0.1]),  # negative entry
        (sbw_wrong_K, [0.4, 0.3, 0.2, 0.1]),  # 4 entries, but K = 7
    )
    for dist, weights in invalid_cases:
        assert pm.logp(dist, np.array(weights)).eval() == -np.inf

ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize(
"a",
[
Expand Down
30 changes: 30 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
Rice,
Simulator,
SkewNormal,
StickBreakingWeights,
StudentT,
Triangular,
TruncatedNormal,
Expand Down Expand Up @@ -1087,6 +1088,35 @@ def test_matrixnormal_moment(mu, rowchol, colchol, size, expected):
def test_rice_moment(nu, sigma, size, expected):
with Model() as model:
Rice("x", nu=nu, sigma=sigma, size=size)


@pytest.mark.parametrize(
    "alpha, K, size, expected",
    [
        (3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
        (5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),
        (
            1,
            7,
            (13,),
            np.full(
                shape=(13, 8), fill_value=np.append((1 / 2) ** np.arange(7) * 1 / 2, (1 / 2) ** 7)
            ),
        ),
        (
            0.5,
            5,
            (3, 5, 7),
            np.full(
                shape=(3, 5, 7, 6),
                fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
            ),
        ),
    ],
)
def test_stickbreakingweights_moment(alpha, K, size, expected):
    """Check the moment of StickBreakingWeights for unsized and batched cases."""
    with Model() as moment_model:
        StickBreakingWeights("x", alpha=alpha, K=K, size=size)
    assert_moment_is_expected(moment_model, expected)


Expand Down
35 changes: 35 additions & 0 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,41 @@ class TestDirichlet(BaseTestDistribution):
]


class TestStickBreakingWeights(BaseTestDistribution):
    """Random-draw checks for StickBreakingWeights (K = 19, so vectors have 20 entries)."""

    pymc_dist = pm.StickBreakingWeights
    pymc_dist_params = {"alpha": 2.0, "K": 19}
    expected_rv_op_params = {"alpha": 2.0, "K": 19}
    sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
    # Each expected shape is the requested size followed by K + 1 = 20.
    sizes_expected = [
        (20,),
        (17, 20),
        (5, 20),
        (11, 5, 20),
        (3, 13, 5, 20),
    ]
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_rv_size",
        "check_basic_properties",
    ]

    def check_basic_properties(self):
        """Draws must lie in [0, 1] and sum to one along the last axis."""
        seeded_rng = aesara.shared(np.random.default_rng(1234))
        sbw_dist = pm.StickBreakingWeights.dist(
            alpha=3.5,
            K=19,
            size=(2, 3, 5),
            rng=seeded_rng,
        )
        draws = sbw_dist.eval()

        row_totals = draws.sum(-1)
        assert np.allclose(row_totals, 1)
        assert np.all(draws >= 0)
        assert np.all(draws <= 1)

ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

class TestMultinomial(BaseTestDistribution):
pymc_dist = pm.Multinomial
pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}
Expand Down