From 7fde39a210ec6b319c979e96d75900034a244ff8 Mon Sep 17 00:00:00 2001 From: Purna Chandra Mansingh Date: Wed, 10 Aug 2022 01:02:23 +0530 Subject: [PATCH 1/7] allow alpha to take batched data for StickBreakingWeights Co-authored-by: Sayam Kumar --- pymc/distributions/multivariate.py | 16 ++++------- pymc/tests/test_distributions.py | 18 ++++++++++++ pymc/tests/test_distributions_random.py | 37 +++++++++++++------------ 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 99568861bf1..3a8856a7fd6 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2192,9 +2192,6 @@ def make_node(self, rng, size, dtype, alpha, K): alpha = at.as_tensor_variable(alpha) K = at.as_tensor_variable(intX(K)) - if alpha.ndim > 0: - raise ValueError("The concentration parameter needs to be a scalar.") - if K.ndim > 0: raise ValueError("K must be a scalar.") @@ -2205,20 +2202,17 @@ def _infer_shape(self, size, dist_params, param_shapes=None): size = tuple(size) - return size + (K + 1,) + return size + tuple(alpha.shape) + (K + 1,) @classmethod def rng_fn(cls, rng, alpha, K, size): if K < 0: raise ValueError("K needs to be positive.") - if size is None: - size = (K,) - elif isinstance(size, int): - size = (size,) + (K,) - else: - size = tuple(size) + (K,) + distribution_shape = alpha.shape + (K,) + size = to_tuple(size) + distribution_shape + alpha = alpha[..., np.newaxis] betas = rng.beta(1, alpha, size=size) sticks = np.concatenate( @@ -2262,7 +2256,7 @@ class StickBreakingWeights(SimplexContinuous): Parameters ---------- - alpha : tensor_like of float + alpha: float or array_like of floats Concentration parameter (alpha > 0). K : tensor_like of int The number of "sticks" to break off from an initial one-unit stick. The length of the weight diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 28705ba1b83..a5c7aba91f8 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -2291,6 +2291,24 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size): 3, np.array([1.29317672, 1.50126157]), ), + ( + np.array([5, 4, 3, 2, 1]) / 15, + np.array([0.5, 1, 2], dtype="float64"), + 4, + np.array([1.51263013, 2.93119375, 2.99573227]), + ), + ( + np.array([5, 4, 3, 2, 1]) / 15, + np.arange(1, 10, dtype="float64").reshape(3, 3), + 4, + np.array( + [ + [2.93119375, 2.99573227, 1.9095425], + [0.35222059, -1.4632554, -3.44201938], + [-5.53346686, -7.70739149, -9.94430955], + ] + ), + ), ], ) def test_stickbreakingweights_logp(self, value, alpha, K, logp): diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 5818c5adf5d..9a927bb381c 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1287,25 +1287,26 @@ class TestDirichletMultinomial_1D_n_2D_a(BaseTestDistributionRandom): class TestStickBreakingWeights(BaseTestDistributionRandom): - pymc_dist = pm.StickBreakingWeights - pymc_dist_params = {"alpha": 2.0, "K": 19} - expected_rv_op_params = {"alpha": 2.0, "K": 19} - sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)] - sizes_expected = [ - (20,), - (17, 20), - ( - 5, - 20, - ), - (11, 5, 20), - (3, 13, 5, 20), - ] - checks_to_run = [ - "check_pymc_params_match_rv_op", - "check_rv_size", - "check_basic_properties", + parameters = [ + (np.array(3.5), 19), + (np.array([1, 2, 3], dtype="float64"), 17), + (np.arange(1, 10, dtype="float64").reshape(3, 3), 15), + (np.arange(1, 25, dtype="float64").reshape(2, 3, 4), 5), ] + for alpha, K in parameters: + pymc_dist = pm.StickBreakingWeights + pymc_dist_params = {"alpha": alpha, "K": K} + expected_rv_op_params = {"alpha": alpha, "K": K} + sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)] + sizes_expected = [] + for size in sizes_to_check: + sizes_expected.append(to_tuple(size) + alpha.shape + (K + 1,)) + + checks_to_run = [ + "check_pymc_params_match_rv_op", + "check_rv_size", + "check_basic_properties", + ] def check_basic_properties(self): default_rng = aesara.shared(np.random.default_rng(1234)) From 11cd987275be04d3e4e6c7dff5e0e9fb7dfd58cf Mon Sep 17 00:00:00 2001 From: Purna Chandra Mansingh Date: Mon, 15 Aug 2022 15:24:17 +0530 Subject: [PATCH 2/7] add supp_shape_from_params() and fix shape --- pymc/distributions/multivariate.py | 15 ++++------ pymc/tests/test_distributions_random.py | 37 ++++++++++++------------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 3a8856a7fd6..313dcc7ca63 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2197,20 +2197,17 @@ def make_node(self, rng, size, dtype, alpha, K): return super().make_node(rng, size, dtype, alpha, K) - def _infer_shape(self, size, dist_params, param_shapes=None): - alpha, K = dist_params - - size = tuple(size) - - return size + tuple(alpha.shape) + (K + 1,) + def _supp_shape_from_params(self, dist_params, **kwargs): + K = dist_params[1] + return (K + 1,) @classmethod def rng_fn(cls, rng, alpha, K, size): if K < 0: raise ValueError("K needs to be positive.") - distribution_shape = alpha.shape + (K,) - size = to_tuple(size) + distribution_shape + size = to_tuple(size) + size = np.broadcast_shapes(alpha.shape, size) + (K,) alpha = alpha[..., np.newaxis] betas = rng.beta(1, alpha, size=size) @@ -2256,7 +2253,7 @@ class StickBreakingWeights(SimplexContinuous): Parameters ---------- - alpha: float or array_like of floats + alpha : tensor_like of float Concentration parameter (alpha > 0). K : tensor_like of int The number of "sticks" to break off from an initial one-unit stick. The length of the weight diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 9a927bb381c..5818c5adf5d 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1287,26 +1287,25 @@ class TestDirichletMultinomial_1D_n_2D_a(BaseTestDistributionRandom): class TestStickBreakingWeights(BaseTestDistributionRandom): - parameters = [ - (np.array(3.5), 19), - (np.array([1, 2, 3], dtype="float64"), 17), - (np.arange(1, 10, dtype="float64").reshape(3, 3), 15), - (np.arange(1, 25, dtype="float64").reshape(2, 3, 4), 5), + pymc_dist = pm.StickBreakingWeights + pymc_dist_params = {"alpha": 2.0, "K": 19} + expected_rv_op_params = {"alpha": 2.0, "K": 19} + sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)] + sizes_expected = [ + (20,), + (17, 20), + ( + 5, + 20, + ), + (11, 5, 20), + (3, 13, 5, 20), + ] + checks_to_run = [ + "check_pymc_params_match_rv_op", + "check_rv_size", + "check_basic_properties", ] - for alpha, K in parameters: - pymc_dist = pm.StickBreakingWeights - pymc_dist_params = {"alpha": alpha, "K": K} - expected_rv_op_params = {"alpha": alpha, "K": K} - sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)] - sizes_expected = [] - for size in sizes_to_check: - sizes_expected.append(to_tuple(size) + alpha.shape + (K + 1,)) - - checks_to_run = [ - "check_pymc_params_match_rv_op", - "check_rv_size", - "check_basic_properties", - ] def check_basic_properties(self): default_rng = aesara.shared(np.random.default_rng(1234)) From afdd31a41d92e20419c3486a83624c008c8c9811 Mon Sep 17 00:00:00 2001 From: Purna Chandra Mansingh Date: Tue, 16 Aug 2022 20:47:47 +0530 Subject: [PATCH 3/7] fixed shape and added test for batched alpha --- pymc/distributions/multivariate.py | 8 ++- pymc/tests/test_distributions.py | 73 +++++++++++++++++++------ pymc/tests/test_distributions_random.py | 12 ++++ 3 files changed, 72 insertions(+), 21 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 313dcc7ca63..1ecb6d51f3f 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2206,10 +2206,12 @@ def rng_fn(cls, rng, alpha, K, size): if K < 0: raise ValueError("K needs to be positive.") - size = to_tuple(size) - size = np.broadcast_shapes(alpha.shape, size) + (K,) + if size is None: + size = alpha.shape + (K,) + alpha = alpha[..., np.newaxis] + else: + size = size + (K,) - alpha = alpha[..., np.newaxis] betas = rng.beta(1, alpha, size=size) sticks = np.concatenate( diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index a5c7aba91f8..01ad6637934 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -24,6 +24,7 @@ from aesara.tensor.random.utils import broadcast_params from pymc.distributions.continuous import get_tau_sigma +from pymc.distributions.dist_math import betaln from pymc.util import UNSET try: @@ -952,6 +953,38 @@ def test_hierarchical_obs_logp(): assert not any(isinstance(o, RandomVariable) for o in ops) +def _stickbreakingweights_logpdf(value, alpha, K): + logp = -at.sum( + at.log( + at.cumsum( + value[..., ::-1], + axis=-1, + ) + ), + axis=-1, + ) + logp += -K * betaln(1, alpha) + logp += alpha * at.log(value[..., -1]) + logp = at.switch( + at.or_( + at.any( + at.and_(at.le(value, 0), at.ge(value, 1)), + axis=-1, + ), + at.or_( + at.bitwise_not(at.allclose(value.sum(-1), 1)), + at.neq(value.shape[-1], K + 1), + ), + ), + -np.inf, + logp, + ) + return logp.eval() + + +stickbreakingweights_logpdf = np.vectorize(_stickbreakingweights_logpdf, signature="(n),(),()->()") + + class TestMatchesScipy: def test_uniform(self): check_logp( @@ -2291,24 +2324,6 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size): 3, np.array([1.29317672, 1.50126157]), ), - ( - np.array([5, 4, 3, 2, 1]) / 15, - np.array([0.5, 1, 2], dtype="float64"), - 4, - np.array([1.51263013, 2.93119375, 2.99573227]), - ), - ( - np.array([5, 4, 3, 2, 1]) / 15, - np.arange(1, 10, dtype="float64").reshape(3, 3), - 4, - np.array( - [ - [2.93119375, 2.99573227, 1.9095425], - [0.35222059, -1.4632554, -3.44201938], - [-5.53346686, -7.70739149, -9.94430955], - ] - ), - ), ], ) def test_stickbreakingweights_logp(self, value, alpha, K, logp): @@ -2330,6 +2345,28 @@ def test_stickbreakingweights_invalid(self): assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf + @pytest.mark.parametrize( + "value, alpha, K", + [ + (np.array([5, 4, 3, 2, 1]) / 15, [0.5, 1.0, 2.0], 19), + ( + np.append(0.5 ** np.arange(1, 20), 0.5**20), + np.arange(1, 7, dtype="float64").reshape(2, 3), + 4, + ), + ], + ) + def test_stickbreakingweights_vectorized(self, value, alpha, K): + with Model(): + sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) + pt = {"sbw": value} + assert_almost_equal( + pm.logp(sbw, value).eval(), + stickbreakingweights_logpdf(value, alpha, K), + decimal=select_by_precision(float64=6, float32=2), + err_msg=str(pt), + ) + @aesara.config.change_flags(compute_test_value="raise") def test_categorical_bounds(self): with Model(): diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 5818c5adf5d..52aae93bd99 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1321,6 +1321,18 @@ def check_basic_properties(self): assert np.all(draws <= 1) +class TestStickBreakingWeights_1D_alpha(BaseTestDistributionRandom): + pymc_dist = pm.StickBreakingWeights + pymc_dist_params = {"alpha": [1.0, 2.0, 3.0], "K": 19} + expected_rv_op_params = {"alpha": [1.0, 2.0, 3.0], "K": 19} + sizes_to_check = [None] + sizes_expected = [(3, 20)] + checks_to_run = [ + "check_pymc_params_match_rv_op", + "check_rv_size", + ] + + class TestCategorical(BaseTestDistributionRandom): pymc_dist = pm.Categorical pymc_dist_params = {"p": np.array([0.28, 0.62, 0.10])} From d933c070352deb3b1a3261e59cf07de9a511f93c Mon Sep 17 00:00:00 2001 From: Purna Chandra Mansingh Date: Tue, 16 Aug 2022 22:48:53 +0530 Subject: [PATCH 4/7] refactored tests --- pymc/distributions/multivariate.py | 8 ++--- pymc/tests/test_distributions.py | 47 +++++++------------------ pymc/tests/test_distributions_random.py | 4 +-- 3 files changed, 17 insertions(+), 42 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 1ecb6d51f3f..486d74efcc7 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2206,11 +2206,9 @@ def rng_fn(cls, rng, alpha, K, size): if K < 0: raise ValueError("K needs to be positive.") - if size is None: - size = alpha.shape + (K,) - alpha = alpha[..., np.newaxis] - else: - size = size + (K,) + size = size if size is not None else alpha.shape + size = size + (K,) + alpha = alpha[..., np.newaxis] betas = rng.beta(1, alpha, size=size) diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 01ad6637934..f215af665cf 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -23,8 +23,8 @@ from aeppl.logprob import ParameterValueError from aesara.tensor.random.utils import broadcast_params +from pymc.aesaraf import compile_pymc from pymc.distributions.continuous import get_tau_sigma -from pymc.distributions.dist_math import betaln from pymc.util import UNSET try: @@ -953,38 +953,6 @@ def test_hierarchical_obs_logp(): assert not any(isinstance(o, RandomVariable) for o in ops) -def _stickbreakingweights_logpdf(value, alpha, K): - logp = -at.sum( - at.log( - at.cumsum( - value[..., ::-1], - axis=-1, - ) - ), - axis=-1, - ) - logp += -K * betaln(1, alpha) - logp += alpha * at.log(value[..., -1]) - logp = at.switch( - at.or_( - at.any( - at.and_(at.le(value, 0), at.ge(value, 1)), - axis=-1, - ), - at.or_( - at.bitwise_not(at.allclose(value.sum(-1), 1)), - at.neq(value.shape[-1], K + 1), - ), - ), - -np.inf, - logp, - ) - return logp.eval() - - -stickbreakingweights_logpdf = np.vectorize(_stickbreakingweights_logpdf, signature="(n),(),()->()") - - class TestMatchesScipy: def test_uniform(self): check_logp( @@ -2348,15 +2316,24 @@ def test_stickbreakingweights_invalid(self): @pytest.mark.parametrize( "value, alpha, K", [ - (np.array([5, 4, 3, 2, 1]) / 15, [0.5, 1.0, 2.0], 19), + (np.array([5, 4, 3, 2, 1]) / 15, [0.5, 1.0, 2.0], 4), ( np.append(0.5 ** np.arange(1, 20), 0.5**20), np.arange(1, 7, dtype="float64").reshape(2, 3), - 4, + 19, ), ], ) def test_stickbreakingweights_vectorized(self, value, alpha, K): + _value = at.vector() + _alpha = at.scalar() + _k = at.iscalar() + _logp = logp(StickBreakingWeights.dist(_alpha, _k), _value) + _stickbreakingweights_logpdf = compile_pymc([_value, _alpha, _k], _logp) + stickbreakingweights_logpdf = np.vectorize( + _stickbreakingweights_logpdf, signature="(n),(),()->()" + ) + with Model(): sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) pt = {"sbw": value} diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 52aae93bd99..de7e6cd0884 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1325,8 +1325,8 @@ class TestStickBreakingWeights_1D_alpha(BaseTestDistributionRandom): pymc_dist = pm.StickBreakingWeights pymc_dist_params = {"alpha": [1.0, 2.0, 3.0], "K": 19} expected_rv_op_params = {"alpha": [1.0, 2.0, 3.0], "K": 19} - sizes_to_check = [None] - sizes_expected = [(3, 20)] + sizes_to_check = [None, (3,), (5, 3)] + sizes_expected = [(3, 20), (3, 20), (5, 3, 20)] checks_to_run = [ "check_pymc_params_match_rv_op", "check_rv_size", From 4a50281df7bd8fd5d9c2ea444f6bda07d34d758f Mon Sep 17 00:00:00 2001 From: Purna Chandra Mansingh Date: Thu, 18 Aug 2022 14:33:24 +0530 Subject: [PATCH 5/7] made pytest fixture and added test for batched alpha --- pymc/distributions/multivariate.py | 2 +- pymc/tests/test_distributions.py | 66 ++++++++++-------------------- 2 files changed, 22 insertions(+), 46 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 486d74efcc7..e82af48ac9d 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2206,7 +2206,7 @@ def rng_fn(cls, rng, alpha, K, size): if K < 0: raise ValueError("K needs to be positive.") - size = size if size is not None else alpha.shape + size = to_tuple(size) if size is not None else alpha.shape size = size + (K,) alpha = alpha[..., np.newaxis] diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index f215af665cf..5f8e09dd95a 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -953,6 +953,15 @@ def test_hierarchical_obs_logp(): assert not any(isinstance(o, RandomVariable) for o in ops) +@pytest.fixture(scope="module") +def _compile_stickbreakingweights_logpdf(): + _value = at.vector() + _alpha = at.scalar() + _k = at.iscalar() + _logp = logp(StickBreakingWeights.dist(_alpha, _k), _value) + return compile_pymc([_value, _alpha, _k], _logp) + + class TestMatchesScipy: def test_uniform(self): check_logp( @@ -2280,27 +2289,25 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size): ) @pytest.mark.parametrize( - "value,alpha,K,logp", + "alpha,K", [ - (np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439), - (np.tile(1, 13) / 13, 2, 12, 13.980045245672827), - (np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723), - (np.append(0.5 ** np.arange(1, 20), 0.5**20), 5, 19, 94.20462772778092), - ( - (np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])), - 2.5, - 3, - np.array([1.29317672, 1.50126157]), - ), + (0.5, 4), + (2, 12), + (np.array([0.5, 1.0, 2.0]), 3), + (np.arange(1, 7, dtype="float64").reshape(2, 3), 5), ], ) - def test_stickbreakingweights_logp(self, value, alpha, K, logp): - with Model() as model: + def test_stickbreakingweights_logp(self, alpha, K, _compile_stickbreakingweights_logpdf): + stickbreakingweights_logpdf = np.vectorize( + _compile_stickbreakingweights_logpdf, signature="(n),(),()->()" + ) + value = pm.StickBreakingWeights.dist(alpha, K).eval() + with Model(): sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) pt = {"sbw": value} assert_almost_equal( pm.logp(sbw, value).eval(), - logp, + stickbreakingweights_logpdf(value, alpha, K), decimal=select_by_precision(float64=6, float32=2), err_msg=str(pt), ) @@ -2313,37 +2320,6 @@ def test_stickbreakingweights_invalid(self): assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf - @pytest.mark.parametrize( - "value, alpha, K", - [ - (np.array([5, 4, 3, 2, 1]) / 15, [0.5, 1.0, 2.0], 4), - ( - np.append(0.5 ** np.arange(1, 20), 0.5**20), - np.arange(1, 7, dtype="float64").reshape(2, 3), - 19, - ), - ], - ) - def test_stickbreakingweights_vectorized(self, value, alpha, K): - _value = at.vector() - _alpha = at.scalar() - _k = at.iscalar() - _logp = logp(StickBreakingWeights.dist(_alpha, _k), _value) - _stickbreakingweights_logpdf = compile_pymc([_value, _alpha, _k], _logp) - stickbreakingweights_logpdf = np.vectorize( - _stickbreakingweights_logpdf, signature="(n),(),()->()" - ) - - with Model(): - sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) - pt = {"sbw": value} - assert_almost_equal( - pm.logp(sbw, value).eval(), - stickbreakingweights_logpdf(value, alpha, K), - decimal=select_by_precision(float64=6, float32=2), - err_msg=str(pt), - ) - @aesara.config.change_flags(compute_test_value="raise") def test_categorical_bounds(self): with Model(): From fe7c2b2d708fb4b2ec006bd0b2533175f42e81d5 Mon Sep 17 00:00:00 2001 From: Purna Chandra Mansingh Date: Thu, 25 Aug 2022 14:53:58 +0530 Subject: [PATCH 6/7] test logp for batched alpha --- pymc/tests/test_distributions.py | 54 +++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 5f8e09dd95a..6805891623a 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -962,6 +962,15 @@ def _compile_stickbreakingweights_logpdf(): return compile_pymc([_value, _alpha, _k], _logp) +def _stickbreakingweights_logpdf(value, alpha, k, _compile_stickbreakingweights_logpdf): + return _compile_stickbreakingweights_logpdf(value, alpha, k) + + +stickbreakingweights_logpdf = np.vectorize( + _stickbreakingweights_logpdf, signature="(n),(),(),()->()" +) + + class TestMatchesScipy: def test_uniform(self): check_logp( @@ -2289,25 +2298,27 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size): ) @pytest.mark.parametrize( - "alpha,K", + "value,alpha,K,logp", [ - (0.5, 4), - (2, 12), - (np.array([0.5, 1.0, 2.0]), 3), - (np.arange(1, 7, dtype="float64").reshape(2, 3), 5), + (np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439), + (np.tile(1, 13) / 13, 2, 12, 13.980045245672827), + (np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723), + (np.append(0.5 ** np.arange(1, 20), 0.5**20), 5, 19, 94.20462772778092), + ( + (np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])), + 2.5, + 3, + np.array([1.29317672, 1.50126157]), + ), ], ) - def test_stickbreakingweights_logp(self, alpha, K, _compile_stickbreakingweights_logpdf): - stickbreakingweights_logpdf = np.vectorize( - _compile_stickbreakingweights_logpdf, signature="(n),(),()->()" - ) - value = pm.StickBreakingWeights.dist(alpha, K).eval() - with Model(): + def test_stickbreakingweights_logp(self, value, alpha, K, logp): + with Model() as model: sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) pt = {"sbw": value} assert_almost_equal( pm.logp(sbw, value).eval(), - stickbreakingweights_logpdf(value, alpha, K), + logp, decimal=select_by_precision(float64=6, float32=2), err_msg=str(pt), ) @@ -2320,6 +2331,25 @@ def test_stickbreakingweights_invalid(self): assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf + @pytest.mark.parametrize( + "alpha,K", + [ + (np.array([0.5, 1.0, 2.0]), 3), + (np.arange(1, 7, dtype="float64").reshape(2, 3), 5), + ], + ) + def test_stickbreakingweights_vectorized(self, alpha, K, _compile_stickbreakingweights_logpdf): + value = pm.StickBreakingWeights.dist(alpha, K).eval() + with Model(): + sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) + pt = {"sbw": value} + assert_almost_equal( + pm.logp(sbw, value).eval(), + stickbreakingweights_logpdf(value, alpha, K, _compile_stickbreakingweights_logpdf), + decimal=select_by_precision(float64=6, float32=2), + err_msg=str(pt), + ) + @aesara.config.change_flags(compute_test_value="raise") def test_categorical_bounds(self): with Model(): From 2a5df64899145cbb982db07ab77c5f14e91765b6 Mon Sep 17 00:00:00 2001 From: Purna Chandra Mansingh Date: Sun, 28 Aug 2022 08:57:43 +0530 Subject: [PATCH 7/7] added test for moment --- pymc/distributions/multivariate.py | 3 ++- pymc/tests/test_distributions.py | 17 +++++----------- pymc/tests/test_distributions_moments.py | 26 ++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index e82af48ac9d..265d535b481 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2277,9 +2277,10 @@ def dist(cls, alpha, K, *args, **kwargs): return super().dist([alpha, K], **kwargs) def moment(rv, size, alpha, K): + alpha = alpha[..., np.newaxis] moment = (alpha / (1 + alpha)) ** at.arange(K) moment *= 1 / (1 + alpha) - moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1) + moment = at.concatenate([moment, (alpha / (1 + alpha)) ** K], axis=-1) if not rv_size_is_none(size): moment_size = at.concatenate( [ diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 6805891623a..d0c1135bdde 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -954,21 +954,14 @@ def test_hierarchical_obs_logp(): @pytest.fixture(scope="module") -def _compile_stickbreakingweights_logpdf(): +def stickbreakingweights_logpdf(): _value = at.vector() _alpha = at.scalar() _k = at.iscalar() _logp = logp(StickBreakingWeights.dist(_alpha, _k), _value) - return compile_pymc([_value, _alpha, _k], _logp) + core_fn = compile_pymc([_value, _alpha, _k], _logp) - -def _stickbreakingweights_logpdf(value, alpha, k, _compile_stickbreakingweights_logpdf): - return _compile_stickbreakingweights_logpdf(value, alpha, k) - - -stickbreakingweights_logpdf = np.vectorize( - _stickbreakingweights_logpdf, signature="(n),(),(),()->()" -) + return np.vectorize(core_fn, signature="(n),(),()->()") class TestMatchesScipy: @@ -2338,14 +2331,14 @@ def test_stickbreakingweights_invalid(self): (np.arange(1, 7, dtype="float64").reshape(2, 3), 5), ], ) - def test_stickbreakingweights_vectorized(self, alpha, K, _compile_stickbreakingweights_logpdf): + def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf): value = pm.StickBreakingWeights.dist(alpha, K).eval() with Model(): sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) pt = {"sbw": value} assert_almost_equal( pm.logp(sbw, value).eval(), - stickbreakingweights_logpdf(value, alpha, K, _compile_stickbreakingweights_logpdf), + stickbreakingweights_logpdf(value, alpha, K), decimal=select_by_precision(float64=6, float32=2), err_msg=str(pt), ) diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 75c29849bc0..fac192315ce 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -1166,6 +1166,32 @@ def test_rice_moment(nu, sigma, size, expected): fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5), ), ), + ( + np.array([1, 3]), + 11, + None, + np.array( + [ + np.append((1 / 2) ** np.arange(11) * 1 / 2, (1 / 2) ** 11), + np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11), + ] + ), + ), + ( + np.array([1, 3, 5]), + 9, + (5, 3), + np.full( + shape=(5, 3, 10), + fill_value=np.array( + [ + np.append((1 / 2) ** np.arange(9) * 1 / 2, (1 / 2) ** 9), + np.append((3 / 4) ** np.arange(9) * 1 / 4, (3 / 4) ** 9), + np.append((5 / 6) ** np.arange(9) * 1 / 6, (5 / 6) ** 9), + ] + ), + ), + ), ], ) def test_stickbreakingweights_moment(alpha, K, size, expected):