Skip to content

Commit

Permalink
Allow Multinomial to work with batches of n and p that have more than…
Browse files Browse the repository at this point in the history
… 2 dimensions
  • Loading branch information
lucianopaz committed Oct 13, 2020
1 parent 0bd2d65 commit bc67073
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
13 changes: 4 additions & 9 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,8 @@ def __init__(self, n, p, *args, **kwargs):
super().__init__(*args, **kwargs)

p = p / tt.sum(p, axis=-1, keepdims=True)
n = np.squeeze(n) # works also if n is a tensor

if len(self.shape) > 1:
self.n = tt.shape_padright(n)
self.p = p if p.ndim > 1 else tt.shape_padleft(p)
elif n.ndim == 1:
if len(self.shape) >= 1:
self.n = tt.shape_padright(n)
self.p = p if p.ndim > 1 else tt.shape_padleft(p)
else:
Expand All @@ -611,10 +607,9 @@ def __init__(self, n, p, *args, **kwargs):
self.p = tt.as_tensor_variable(p)

self.mean = self.n * self.p
mode = tt.cast(tt.round(self.mean), "int32")
diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
inc_bool_arr = tt.abs_(diff) > 0
mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
mode_ind = tt.argmax(self.p, axis=-1, keepdims=True)
mode = tt.zeros_like(self.mean, dtype=self.dtype)
mode = tt.inc_subtensor(mode[..., mode_ind], 1)
self.mode = mode

def _random(self, n, p, size=None, raw_size=None):
Expand Down
22 changes: 22 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,28 @@ def test_multinomial_vec_2d_p(self):
decimal=4,
)

def test_batch_multinomial(self):
n = 10
vals = np.zeros((4, 5, 3))
p = np.zeros_like(vals)
inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
np.put_along_axis(vals, inds, n, axis=-1)
np.put_along_axis(p, inds, 1, axis=-1)

dist = Multinomial.dist(n=n, p=p, shape=vals.shape)
value = tt.tensor3()
value.tag.test_value = np.zeros_like(vals)
logp = tt.exp(dist.logp(value))
f = theano.function(inputs=[value], outputs=logp)
assert_almost_equal(
f(vals),
np.ones(vals.shape[:-1] + (1,)),
decimal=select_by_precision(float64=6, float32=3),
)

sample = dist.random(size=2)
assert_allclose(sample, np.stack([vals, vals], axis=0))

def test_categorical_bounds(self):
with Model():
x = Categorical("x", p=np.array([0.2, 0.3, 0.5]))
Expand Down

0 comments on commit bc67073

Please sign in to comment.