From 825f98e01a844e3c9ab3173872a92968f20aff00 Mon Sep 17 00:00:00 2001 From: lucianopaz Date: Tue, 13 Oct 2020 17:43:28 +0200 Subject: [PATCH 1/4] Allow Multinomial to work with batches of n and p that have more than 2 dimensions --- pymc3/distributions/multivariate.py | 13 ++++--------- pymc3/tests/test_distributions.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index a117211b1e..a1f98bb434 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -597,12 +597,8 @@ def __init__(self, n, p, *args, **kwargs): super().__init__(*args, **kwargs) p = p / tt.sum(p, axis=-1, keepdims=True) - n = np.squeeze(n) # works also if n is a tensor - if len(self.shape) > 1: - self.n = tt.shape_padright(n) - self.p = p if p.ndim > 1 else tt.shape_padleft(p) - elif n.ndim == 1: + if len(self.shape) >= 1: self.n = tt.shape_padright(n) self.p = p if p.ndim > 1 else tt.shape_padleft(p) else: @@ -611,10 +607,9 @@ def __init__(self, n, p, *args, **kwargs): self.p = tt.as_tensor_variable(p) self.mean = self.n * self.p - mode = tt.cast(tt.round(self.mean), "int32") - diff = self.n - tt.sum(mode, axis=-1, keepdims=True) - inc_bool_arr = tt.abs_(diff) > 0 - mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()]) + mode_ind = tt.argmax(self.p, axis=-1, keepdims=True) + mode = tt.zeros_like(self.mean, dtype=self.dtype) + mode = tt.inc_subtensor(mode[..., mode_ind], 1) self.mode = mode def _random(self, n, p, size=None, raw_size=None): diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 595de26a6a..9360253b86 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1447,6 +1447,28 @@ def test_multinomial_vec_2d_p(self): decimal=4, ) + def test_batch_multinomial(self): + n = 10 + vals = np.zeros((4, 5, 3)) + p = np.zeros_like(vals) + inds = 
np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None] + np.put_along_axis(vals, inds, n, axis=-1) + np.put_along_axis(p, inds, 1, axis=-1) + + dist = Multinomial.dist(n=n, p=p, shape=vals.shape) + value = tt.tensor3() + value.tag.test_value = np.zeros_like(vals) + logp = tt.exp(dist.logp(value)) + f = theano.function(inputs=[value], outputs=logp) + assert_almost_equal( + f(vals), + np.ones(vals.shape[:-1] + (1,)), + decimal=select_by_precision(float64=6, float32=3), + ) + + sample = dist.random(size=2) + assert_allclose(sample, np.stack([vals, vals], axis=0)) + def test_categorical_bounds(self): with Model(): x = Categorical("x", p=np.array([0.2, 0.3, 0.5])) From 6ca644d632100465c9b39c63fb7c8d30b9c4c3dc Mon Sep 17 00:00:00 2001 From: lucianopaz Date: Wed, 14 Oct 2020 08:30:50 +0200 Subject: [PATCH 2/4] Fix failing tests --- pymc3/distributions/multivariate.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index a1f98bb434..1244fd5d12 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -598,7 +598,7 @@ def __init__(self, n, p, *args, **kwargs): p = p / tt.sum(p, axis=-1, keepdims=True) - if len(self.shape) >= 1: + if len(self.shape) > 1: self.n = tt.shape_padright(n) self.p = p if p.ndim > 1 else tt.shape_padleft(p) else: @@ -607,9 +607,10 @@ def __init__(self, n, p, *args, **kwargs): self.p = tt.as_tensor_variable(p) self.mean = self.n * self.p - mode_ind = tt.argmax(self.p, axis=-1, keepdims=True) - mode = tt.zeros_like(self.mean, dtype=self.dtype) - mode = tt.inc_subtensor(mode[..., mode_ind], 1) + mode = tt.cast(tt.round(self.mean), "int32") + diff = self.n - tt.sum(mode, axis=-1, keepdims=True) + inc_bool_arr = tt.abs_(diff) > 0 + mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()]) self.mode = mode def _random(self, n, p, size=None, raw_size=None): From 
5f712a6341a509080eec1c79fce596b7b699ba96 Mon Sep 17 00:00:00 2001 From: lucianopaz Date: Wed, 14 Oct 2020 08:51:02 +0200 Subject: [PATCH 3/4] Fix the float32 errors --- pymc3/tests/test_distributions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 9360253b86..7b19e00049 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1449,15 +1449,15 @@ def test_multinomial_vec_2d_p(self): def test_batch_multinomial(self): n = 10 - vals = np.zeros((4, 5, 3)) - p = np.zeros_like(vals) + vals = np.zeros((4, 5, 3), dtype="int32") + p = np.zeros_like(vals, dtype=theano.config.floatX) inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None] np.put_along_axis(vals, inds, n, axis=-1) np.put_along_axis(p, inds, 1, axis=-1) dist = Multinomial.dist(n=n, p=p, shape=vals.shape) - value = tt.tensor3() - value.tag.test_value = np.zeros_like(vals) + value = tt.tensor3(dtype="int32") + value.tag.test_value = np.zeros_like(vals, dtype="int32") logp = tt.exp(dist.logp(value)) f = theano.function(inputs=[value], outputs=logp) assert_almost_equal( From 50b2e6559f83b165445fdc65045b39fa80879603 Mon Sep 17 00:00:00 2001 From: lucianopaz Date: Wed, 14 Oct 2020 08:54:01 +0200 Subject: [PATCH 4/4] Added line to release notes --- RELEASE-NOTES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 645f5a47a6..741bfa7216 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -8,6 +8,7 @@ - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)). - Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)). 
- Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution. [#4129](https://github.com/pymc-devs/pymc3/pull/4129) +- Enabled the `Multinomial` distribution to handle batches of `n` and `p` with more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169) ### Documentation