Skip to content

Commit

Permalink
allow alpha to take batched data for StickBreakingWeights
Browse files (browse the repository at this point in the history)
Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com>
  • Loading branch information
purna135 and Sayam753 committed Aug 9, 2022
1 parent ad16bf4 commit 7fde39a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 29 deletions.
16 changes: 5 additions & 11 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2192,9 +2192,6 @@ def make_node(self, rng, size, dtype, alpha, K):
alpha = at.as_tensor_variable(alpha)
K = at.as_tensor_variable(intX(K))

if alpha.ndim > 0:
raise ValueError("The concentration parameter needs to be a scalar.")

if K.ndim > 0:
raise ValueError("K must be a scalar.")

Expand All @@ -2205,20 +2202,17 @@ def _infer_shape(self, size, dist_params, param_shapes=None):

size = tuple(size)

return size + (K + 1,)
return size + tuple(alpha.shape) + (K + 1,)

@classmethod
def rng_fn(cls, rng, alpha, K, size):
if K < 0:
raise ValueError("K needs to be positive.")

if size is None:
size = (K,)
elif isinstance(size, int):
size = (size,) + (K,)
else:
size = tuple(size) + (K,)
distribution_shape = alpha.shape + (K,)
size = to_tuple(size) + distribution_shape

alpha = alpha[..., np.newaxis]
betas = rng.beta(1, alpha, size=size)

sticks = np.concatenate(
Expand Down Expand Up @@ -2262,7 +2256,7 @@ class StickBreakingWeights(SimplexContinuous):
Parameters
----------
alpha : tensor_like of float
alpha : float or array_like of floats
Concentration parameter (alpha > 0). May be a scalar or an array of batched values.
K : tensor_like of int
The number of "sticks" to break off from an initial one-unit stick. The length of the weight
Expand Down
18 changes: 18 additions & 0 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2291,6 +2291,24 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size):
3,
np.array([1.29317672, 1.50126157]),
),
(
np.array([5, 4, 3, 2, 1]) / 15,
np.array([0.5, 1, 2], dtype="float64"),
4,
np.array([1.51263013, 2.93119375, 2.99573227]),
),
(
np.array([5, 4, 3, 2, 1]) / 15,
np.arange(1, 10, dtype="float64").reshape(3, 3),
4,
np.array(
[
[2.93119375, 2.99573227, 1.9095425],
[0.35222059, -1.4632554, -3.44201938],
[-5.53346686, -7.70739149, -9.94430955],
]
),
),
],
)
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
Expand Down
37 changes: 19 additions & 18 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,25 +1287,26 @@ class TestDirichletMultinomial_1D_n_2D_a(BaseTestDistributionRandom):


class TestStickBreakingWeights(BaseTestDistributionRandom):
pymc_dist = pm.StickBreakingWeights
pymc_dist_params = {"alpha": 2.0, "K": 19}
expected_rv_op_params = {"alpha": 2.0, "K": 19}
sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
sizes_expected = [
(20,),
(17, 20),
(
5,
20,
),
(11, 5, 20),
(3, 13, 5, 20),
]
checks_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
"check_basic_properties",
parameters = [
(np.array(3.5), 19),
(np.array([1, 2, 3], dtype="float64"), 17),
(np.arange(1, 10, dtype="float64").reshape(3, 3), 15),
(np.arange(1, 25, dtype="float64").reshape(2, 3, 4), 5),
]
for alpha, K in parameters:
pymc_dist = pm.StickBreakingWeights
pymc_dist_params = {"alpha": alpha, "K": K}
expected_rv_op_params = {"alpha": alpha, "K": K}
sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
sizes_expected = []
for size in sizes_to_check:
sizes_expected.append(to_tuple(size) + alpha.shape + (K + 1,))

checks_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
"check_basic_properties",
]

def check_basic_properties(self):
default_rng = aesara.shared(np.random.default_rng(1234))
Expand Down

0 comments on commit 7fde39a

Please sign in to comment.