Skip to content

Commit

Permalink
Added SBW docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama committed Jan 17, 2022
1 parent 5a989e6 commit 6b5f5af
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
33 changes: 17 additions & 16 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2188,10 +2188,8 @@ def make_node(self, rng, size, dtype, alpha, K):
alpha = at.as_tensor_variable(alpha)
K = at.as_tensor_variable(intX(K))

# if at.lt(K, 0):
# print(at.lt(K, 0).eval())
# print(K.eval() < 0)
# raise ValueError("K needs to be positive.")
if K.eval() < 0:
raise ValueError("K needs to be positive.")

if alpha.ndim > 0:
raise ValueError("The concentration parameter needs to be a scalar.")
Expand Down Expand Up @@ -2238,34 +2236,37 @@ def rng_fn(cls, rng, alpha, K, size):

class StickBreakingWeights(Continuous):
r"""
Likelihood of truncated stick-breaking weights. The weights are generated
Likelihood of truncated stick-breaking weights. The weights are generated from a
stick-breaking proceduce where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for
:math:`k \in \{1, \ldots, K\}` and :math
.. math:
f(\mathbf{x}|\alpha, K) =
f(\mathbf{x}|\alpha, K) =
B(1, \alpha)^{-K}x_{K+1}^\alpha \prod_{k=1}^{K+1}\left\{\sum_{j=k}^{K+1} x_j\right\}^{-1}
======== ===============================================
Support :math:`x_i \in (0, 1)` for :math:`i \in \{1, \ldots, K+1\}`
such that :math:`\sum x_i = 1`
Mean :math:`\dfrac{a_i}{\sum a_i}`
Variance :math:`\dfrac{a_i - \sum a_0}{a_0^2 (a_0 + 1)}`
where :math:`a_0 = \sum a_i`
Support :math:`x_k \in (0, 1)` for :math:`k \in \{1, \ldots, K+1\}`
such that :math:`\sum x_k = 1`
Mean :math:`\mathbb{E}[x_k] = \dfrac{1}{1 + \alpha}\left(\dfrac{\alpha}{1 + \alpha}\right)^{k - 1}`
for :math:`k \in \{1, \ldots, K\}` and `\mathbb{E}[x_{K+1}] = \left(\dfrac{\alpha}{1 + \alpha}\right)^{K}`
======== ===============================================
Parameters
----------
alpha: float
Concentration parameters (alpha > 0).
Concentration parameter (alpha > 0).
K: int
The number of "sticks" to break off from an initial one-unit stick. The length
of categories is K + 1, where the last weight is one minus the sum of all the first sticks.
References
----------
.. [1] Ishwaran James
.. [1] Ishwaran, H., & James, L. F. (2001). Gibbs sampling methods for stick-breaking priors.
Journal of the American Statistical Association, 96(453), 161-173.
.. [2] Peter Mueller
.. [2] Müller, P., Quintana, F. A., Jara, A., & Hanson, T. (2015). Bayesian nonparametric data
analysis. New York: Springer.
"""
rv_op = stickbreakingweights

Expand Down Expand Up @@ -2329,7 +2330,7 @@ def logp(value, alpha, K):
),
axis=-1,
)
logp += -(K - 1) * betaln(1, alpha)
logp += -K * betaln(1, alpha)
logp += alpha * at.log(value[..., -1])

logp = at.switch(
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2138,7 +2138,7 @@ def test_stickbreakingweights(self, value, alpha, K, logp):
StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
pt = {"sbw": value}
assert_almost_equal(
model.fastlogp(pt),
model.compile_logp()(pt),
logp,
decimal=select_by_precision(float64=6, float32=2),
err_msg=str(pt),
Expand Down

0 comments on commit 6b5f5af

Please sign in to comment.