Make Multinomial robust against batches (#4169)
* Allow Multinomial to work with batches of n and p that have more than 2 dimensions

* Fix failing tests

* Fix the float32 errors

* Added line to release notes
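
To make "batches of n and p that have more than 2 dimensions" concrete, here is a minimal NumPy sketch of a multinomial log-probability that broadcasts over arbitrary leading batch axes. `multinomial_logp` is a hypothetical helper written for illustration and is not PyMC3's implementation; it assumes the last axis indexes categories and does not check that each count vector sums to `n`.

```python
import math

import numpy as np

# Element-wise log-gamma built from the standard library.
_lgamma = np.vectorize(math.lgamma, otypes=[float])


def multinomial_logp(value, n, p):
    """Log-probability of `value` under Multinomial(n, p), batched over leading axes.

    Hypothetical helper for illustration; not PyMC3's implementation.
    """
    value = np.asarray(value, dtype=float)
    p = np.asarray(p, dtype=float)
    # Give n a trailing length-1 axis so it broadcasts against the category axis.
    n = np.asarray(n, dtype=float)[..., None]

    # log(n! / prod_k value_k!)
    log_coef = _lgamma(n + 1) - _lgamma(value + 1).sum(axis=-1, keepdims=True)
    # sum_k value_k * log(p_k), treating 0 * log(0) as 0
    safe_p = np.where(value > 0, p, 1.0)
    with np.errstate(divide="ignore"):
        log_prob = (value * np.log(safe_p)).sum(axis=-1, keepdims=True)
    return log_coef + log_prob


# A (2, 2) batch of 2-category counts, each with n = 2 and p = (0.5, 0.5).
value = np.array([[[1, 1], [2, 0]], [[0, 2], [1, 1]]])
p = np.full(value.shape, 0.5)
logp = multinomial_logp(value, 2, p)
print(logp.shape)  # (2, 2, 1)
```

The result keeps a trailing length-1 axis, which mirrors the `vals.shape[:-1] + (1,)` shape the new test in this commit expects.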
lucianopaz authored Oct 14, 2020
1 parent b31b42a commit d8bfe93
Showing 3 changed files with 23 additions and 4 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -8,6 +8,7 @@
 - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
 - Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)).
 - Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution. [#4129](https://github.com/pymc-devs/pymc3/pull/4129)
+- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)

 ### Documentation

4 changes: 0 additions & 4 deletions pymc3/distributions/multivariate.py
@@ -597,14 +597,10 @@ def __init__(self, n, p, *args, **kwargs):
         super().__init__(*args, **kwargs)

         p = p / tt.sum(p, axis=-1, keepdims=True)
-        n = np.squeeze(n)  # works also if n is a tensor

         if len(self.shape) > 1:
             self.n = tt.shape_padright(n)
             self.p = p if p.ndim > 1 else tt.shape_padleft(p)
-        elif n.ndim == 1:
-            self.n = tt.shape_padright(n)
-            self.p = p if p.ndim > 1 else tt.shape_padleft(p)
         else:
             # n is a scalar, p is a 1d array
             self.n = tt.as_tensor_variable(n)
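
The padding in this hunk can be mimicked in plain NumPy to see why the shapes line up. This is a sketch under the assumption that `tt.shape_padright` appends one length-1 axis and `tt.shape_padleft` prepends one; `pad_right` and `pad_left` are hypothetical stand-ins, not Theano functions.

```python
import numpy as np


def pad_right(x):
    """NumPy stand-in for tt.shape_padright: append one length-1 axis."""
    return np.asarray(x)[..., None]


def pad_left(x):
    """NumPy stand-in for tt.shape_padleft: prepend one length-1 axis."""
    return np.asarray(x)[None, ...]


n = np.full((4, 5), 10)          # one trial count per batch element
p = np.full((4, 5, 3), 1 / 3)    # one probability vector per batch element

n_padded = pad_right(n)          # shape (4, 5, 1)
mean = n_padded * p              # broadcasts to shape (4, 5, 3)

p_1d = np.array([0.2, 0.3, 0.5])
p_padded = pad_left(p_1d)        # shape (1, 3), broadcasts across a batch
```

With the trailing axis on `n`, a batch of any rank broadcasts against the category axis of `p`. By contrast, `np.squeeze` drops every length-1 axis, so it can change the rank of `n` before the padding is applied.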
22 changes: 22 additions & 0 deletions pymc3/tests/test_distributions.py
@@ -1447,6 +1447,28 @@ def test_multinomial_vec_2d_p(self):
             decimal=4,
         )

+    def test_batch_multinomial(self):
+        n = 10
+        vals = np.zeros((4, 5, 3), dtype="int32")
+        p = np.zeros_like(vals, dtype=theano.config.floatX)
+        inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
+        np.put_along_axis(vals, inds, n, axis=-1)
+        np.put_along_axis(p, inds, 1, axis=-1)
+
+        dist = Multinomial.dist(n=n, p=p, shape=vals.shape)
+        value = tt.tensor3(dtype="int32")
+        value.tag.test_value = np.zeros_like(vals, dtype="int32")
+        logp = tt.exp(dist.logp(value))
+        f = theano.function(inputs=[value], outputs=logp)
+        assert_almost_equal(
+            f(vals),
+            np.ones(vals.shape[:-1] + (1,)),
+            decimal=select_by_precision(float64=6, float32=3),
+        )
+
+        sample = dist.random(size=2)
+        assert_allclose(sample, np.stack([vals, vals], axis=0))
+
     def test_categorical_bounds(self):
         with Model():
             x = Categorical("x", p=np.array([0.2, 0.3, 0.5]))
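
The fixture in `test_batch_multinomial` relies on a deterministic case: each probability vector is one-hot, so the only possible draw puts all `n` counts in the chosen category. That is why the test can expect a probability of 1 and samples equal to `vals`. The construction can be reproduced with NumPy alone; the seeded `default_rng` generator and the `float64` dtype are assumptions made here for a reproducible sketch, whereas the test itself uses `np.random.randint` and `theano.config.floatX`.

```python
import numpy as np

n = 10
rng = np.random.default_rng(0)  # seed chosen for illustration only

vals = np.zeros((4, 5, 3), dtype="int32")
p = np.zeros(vals.shape, dtype="float64")

# Pick one category per batch element; the trailing axis makes the index
# array broadcast along the category axis.
inds = rng.integers(vals.shape[-1], size=vals.shape[:-1])[..., None]

# Write n (or 1) at the chosen category of every batch element.
np.put_along_axis(vals, inds, n, axis=-1)
np.put_along_axis(p, inds, 1, axis=-1)

# Every count vector sums to n and every probability vector sums to 1.
print(vals.sum(axis=-1).min(), vals.sum(axis=-1).max())
print(p.sum(axis=-1).min(), p.sum(axis=-1).max())
```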
