Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dirichlet multinomial (continued) #4373

Merged
merged 57 commits into from
Jan 16, 2021
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
b7492d2
Add implementation of DM distribution.
bsmith89 Oct 1, 2019
2106f7c
Fix class name mistake.
bsmith89 Oct 2, 2019
487fc8a
Add DM dist to exported multivariate distributions.
bsmith89 Oct 2, 2019
24d7ec8
Export DirichletMultinomial in pymc3.distributions
bsmith89 Dec 7, 2019
4fbd1d9
Attempt at matching Multinomial initialization.
bsmith89 Dec 16, 2019
685a428
Add some simple tests for DM.
bsmith89 Dec 16, 2019
ad8e77e
Correctly deal with 1d n and 2d alpha.
bsmith89 Dec 16, 2019
8fa717a
Fix typo in DM random.
bsmith89 Dec 16, 2019
4db6b1c
Fix faulty tests for DM.
bsmith89 Dec 16, 2019
01d359b
Drop redundant initialization test for DM.
bsmith89 Dec 16, 2019
4892355
Add test that DM is normalized for n=1 case.
bsmith89 Dec 16, 2019
bc5f3bf
Add DM test case based on BetaBinomial.
bsmith89 Dec 16, 2019
ffa705c
Update pymc3/distributions/multivariate.py
ColCarroll Sep 19, 2020
c801ef1
- Infer shape by default (copied code from Dirichlet Distribution)
ricardoV94 Dec 22, 2020
c8921ee
- Use size information in random method
ricardoV94 Dec 22, 2020
e801568
- Restore merge accidental deletions
ricardoV94 Dec 22, 2020
3483ab5
- Underscore missing
ricardoV94 Dec 22, 2020
23ba2e4
- More merge cleaning
ricardoV94 Dec 22, 2020
fe018ec
Bring DirichletMultinomial initialization into alignment with Multino…
bsmith89 Dec 29, 2020
25fd41a
Align all DM tests with Multinomial.
bsmith89 Jan 1, 2021
28b0a62
Align DirichletMultinomial random implementation with Multinomial.
bsmith89 Jan 1, 2021
d363f96
Match DM random method to Multinomial implementation.
bsmith89 Jan 3, 2021
9b6828c
Change alpha -> a
ricardoV94 Jan 4, 2021
d438dfc
Run pre-commit
ricardoV94 Jan 4, 2021
dde5c45
Keep standard order of methods random and logp
ricardoV94 Jan 4, 2021
49b432d
Update docstrings for valid input types.
ricardoV94 Jan 4, 2021
83fbda6
Add new test to ensure DM matches BetaBinom
ricardoV94 Jan 4, 2021
9748a9d
Change DM alpha -> a in docstrings.
bsmith89 Jan 4, 2021
7b20680
Test two additional parameterization shapes in `test_dirichlet_multin…
bsmith89 Jan 4, 2021
66c83b0
Revert debugging comments.
bsmith89 Jan 4, 2021
672ef56
Revert unrelated changes.
bsmith89 Jan 4, 2021
2d5d20e
Fix minor Black inconsistency.
bsmith89 Jan 4, 2021
922515b
Drop no-longer-functional reshaping code.
bsmith89 Jan 5, 2021
aa89d0a
Assert shape of random samples is as expected.
bsmith89 Jan 5, 2021
2343004
Explicitly test random sample shapes, including batch dimensions.
bsmith89 Jan 5, 2021
a08bc51
Sort imports.
bsmith89 Jan 5, 2021
22beead
Simplify _random
ricardoV94 Jan 6, 2021
7bad831
Reorder tests more logically
ricardoV94 Jan 6, 2021
9bbddba
Refactor tests
ricardoV94 Jan 6, 2021
086459f
Require shape argument
ricardoV94 Jan 6, 2021
f8499d3
Remove unused import `to_tuple`
ricardoV94 Jan 6, 2021
1cd2a9f
Simplify logic to handle list as input for `a`
ricardoV94 Jan 6, 2021
ef00fe1
Raise ShapeError in random()
ricardoV94 Jan 10, 2021
f2ac8e9
Finish batch and repr unittests
ricardoV94 Jan 10, 2021
f5dcdc3
Add note about mode
ricardoV94 Jan 10, 2021
c4e017a
Tiny rewording
ricardoV94 Jan 10, 2021
d46dd50
Change mode to _defaultval
ricardoV94 Jan 12, 2021
3ab518d
Revert comment for Multinomial mode
ricardoV94 Jan 12, 2021
cdd6d27
Update shape check logic
ricardoV94 Jan 12, 2021
24447a4
Add DM to release notes.
bsmith89 Jan 12, 2021
c5e9b67
Merge branch 'master' into dirichlet_multinomial_fork
bsmith89 Jan 12, 2021
0bd6c3d
Minor docstring revisions as suggested by @AlexAndorra.
bsmith89 Jan 14, 2021
f919456
Revise the revision.
bsmith89 Jan 14, 2021
c082f00
Add comment clarifying bounds checking in logp()
bsmith89 Jan 14, 2021
ea0ae59
Address review suggestions
ricardoV94 Jan 15, 2021
b451967
Update `matches_beta_binomial` to take into consideration float preci…
ricardoV94 Jan 15, 2021
128d5cf
Add DM to multivariate distributions docs.
bsmith89 Jan 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from pymc3.distributions.mixture import Mixture, MixtureSameFamily, NormalMixture
from pymc3.distributions.multivariate import (
Dirichlet,
DirichletMultinomial,
KroneckerNormal,
LKJCholeskyCov,
LKJCorr,
Expand Down Expand Up @@ -154,6 +155,7 @@
"MvStudentT",
"Dirichlet",
"Multinomial",
"DirichletMultinomial",
"Wishart",
"WishartBartlett",
"LKJCholeskyCov",
Expand Down
133 changes: 133 additions & 0 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"MvStudentT",
"Dirichlet",
"Multinomial",
"DirichletMultinomial",
"Wishart",
"WishartBartlett",
"LKJCorr",
Expand Down Expand Up @@ -690,6 +691,138 @@ def logp(self, x):
)


class DirichletMultinomial(Discrete):
R"""Dirichlet Multinomial log-likelihood.

Dirichlet mixture of multinomials distribution, with a marginalized PMF.
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved

AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
.. math::

f(x \mid n, a) = \frac{\Gamma(n + 1)\Gamma(\sum a_k)}
{\Gamma(\n + \sum a_k)}
\prod_{k=1}^K
\frac{\Gamma(x_k + a_k)}
{\Gamma(x_k + 1)\Gamma(a_k)}

========== ===========================================
Support :math:`x \in \{0, 1, \ldots, n\}` such that
:math:`\sum x_i = n`
Mean :math:`n \frac{a_i}{\sum{a_k}}`
========== ===========================================

Parameters
----------
n : int or array
Total counts in each replicate. If n is an array its shape must be (N,)
with N = a.shape[0]

a : one- or two-dimensional array
Dirichlet parameter. Elements must be non-negative.
Dimension of each element of the distribution is the length
of the second dimension of *a*.
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

shape : numerical tuple
Describes shape of distribution. For example if n=array([5, 10]), and
p=array([1, 1, 1]), shape should be (2, 3).
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, n, a, shape, *args, **kwargs):
super().__init__(shape, *args, **kwargs)

if len(self.shape) > 1:
self.n = tt.shape_padright(n)
self.a = tt.as_tensor_variable(a) if np.ndim(a) > 1 else tt.shape_padleft(a)
else:
# n is a scalar, p is a 1d array
self.n = tt.as_tensor_variable(n)
self.a = tt.as_tensor_variable(a)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

p = self.a / self.a.sum(-1, keepdims=True)

self.mean = self.n * p
mode = tt.cast(tt.round(self.mean), "int32")
diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
inc_bool_arr = tt.abs_(diff) > 0
mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
self.mode = mode

def _random(self, n, a, size=None, raw_size=None):
Sayam753 marked this conversation as resolved.
Show resolved Hide resolved
# numpy will cast dirichlet and multinomial samples to float64 by default
original_dtype = a.dtype

# Thanks to the default shape handling done in generate_values, the last
# axis of n is a dummy axis that allows it to broadcast well with `a`
n = np.broadcast_to(n, size)
a = np.broadcast_to(a, size)
n = n[..., 0]

# np.random.multinomial needs `n` to be a scalar int and `a` a
# sequence so we semi flatten them and iterate over them
n_ = n.reshape([-1])
a_ = a.reshape([-1, a.shape[-1]])
p_ = np.array([np.random.dirichlet(aa) for aa in a_])
samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)])
samples = samples.reshape(a.shape)
Sayam753 marked this conversation as resolved.
Show resolved Hide resolved

# We cast back to the original dtype
return samples.astype(original_dtype)

def random(self, point=None, size=None):
"""
Draw random values from Dirichlet-Multinomial distribution.

Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
size: int, optional
Desired size of random sample (returns one sample if not
specified).

Returns
-------
array
"""
n, a = draw_values([self.n, self.a], point=point, size=size)
samples = generate_samples(
self._random,
n,
a,
dist_shape=self.shape,
not_broadcast_kwargs={"raw_size": size},
size=size,
)

if size is not None:
expect_shape = (size, *self.shape)
else:
expect_shape = self.shape
assert tuple(samples.shape) == tuple(expect_shape)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this here? Is any other distribution doing the same?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it is unnecessary during runtime, since we are already testing this quite a lot in the unittests

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Asserts are not evil, but also not necessarily the best option. (Also see https://stackoverflow.com/a/13534633)

To validate external inputs, use

if input_shape != expected_shape:
    raise ShapeError("Your fault!")

This has the advantage that a coverage check can tell you if the tests cover the exception.

If you want to validate an internal assumption and make sure that your code did not mess up, the assert is the right thing to do. It can help tremendously to debug or understand code.

In https://github.com/michaelosthege/pyrff/blob/master/pyrff/rff.py#L230-L261 I did both, because it took me weeks to understand the shapes of the code I was re-implementing there...

In this case if you're already testing it a lot maybe a comment instead of an assert is more appropriate.

Copy link
Contributor

@bsmith89 bsmith89 Jan 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case if you're already testing it a lot maybe a comment instead of an assert is more appropriate.

Yes, I think this is being sufficiently tested for reasonable inputs; I believe @ricardoV94 demonstrated already that these asserts were passed for a large variety of shapes (...even when some of the actual outputs were somewhat unintuitive).

I'm more worried about users "abusing" weirdly shaped inputs, in which case I like your explicit ShapeError to catch corner cases we didn't even think of.

On the other hand, Multinomial doesn't do this now. If we wanted to add this check I think we should do it for both distributions in parallel.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the other hand, Multinomial doesn't do this now. If we wanted to add this check I think we should do it for both distributions in parallel.

I agree with this


return samples

def logp(self, x):
a = self.a
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
n = self.n
sum_a = a.sum(axis=-1, keepdims=True)

const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a)
series = gammaln(x + a) - (gammaln(x + 1) + gammaln(a))
result = const + series.sum(axis=-1, keepdims=True)
return bound(
result,
tt.all(tt.ge(x, 0)),
tt.all(tt.gt(a, 0)),
tt.all(tt.ge(n, 0)),
tt.all(tt.eq(x.sum(axis=-1, keepdims=True), n)),
broadcast_conditions=False,
)
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved

def _distr_parameters_for_repr(self):
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
return ["n", "a"]
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved


def posdef(AA):
try:
linalg.cholesky(AA)
Expand Down
152 changes: 152 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
Constant,
DensityDist,
Dirichlet,
DirichletMultinomial,
DiscreteUniform,
DiscreteWeibull,
ExGaussian,
Expand Down Expand Up @@ -265,6 +266,21 @@ def multinomial_logpdf(value, n, p):
return -inf


def dirichlet_multinomial_logpmf(value, n, a):
value, n, a = [np.asarray(x) for x in [value, n, a]]
assert value.ndim == 1
assert n.ndim == 0
assert a.shape == value.shape
gammaln = scipy.special.gammaln
if value.sum() == n and (0 <= value).all() and (value <= n).all():
sum_a = a.sum(axis=-1)
const = gammaln(n + 1) + gammaln(sum_a) - gammaln(n + sum_a)
series = gammaln(value + a) - gammaln(value + 1) - gammaln(a)
return const + series.sum(axis=-1)
else:
return -inf


def beta_mu_sigma(value, mu, sigma):
kappa = mu * (1 - mu) / sigma ** 2 - 1
if kappa > 0:
Expand Down Expand Up @@ -1724,6 +1740,142 @@ def test_batch_multinomial(self):
sample = dist.random(size=2)
assert_allclose(sample, np.stack([vals, vals], axis=0))

@pytest.mark.parametrize("n", [2, 3])
def test_dirichlet_multinomial(self, n):
self.pymc3_matches_scipy(
DirichletMultinomial,
Vector(Nat, n),
{"a": Vector(Rplus, n), "n": Nat},
dirichlet_multinomial_logpmf,
)

def test_dirichlet_multinomial_matches_beta_binomial(self):
a, b, n = 2, 1, 5
ns = np.arange(n + 1)
ns_dm = np.vstack((ns, n - ns)).T # covert ns=1 to ns_dm=[1, 4], for all ns...
bb_logp = pm.BetaBinomial.dist(n=n, alpha=a, beta=b).logp(ns).tag.test_value
dm_logp = (
pm.DirichletMultinomial.dist(n=n, a=[a, b], shape=(1, 2)).logp(ns_dm).tag.test_value
)
dm_logp = dm_logp.ravel()
assert_allclose(bb_logp, dm_logp)

@pytest.mark.parametrize(
"a, n, shape",
[
[[0.25, 0.25, 0.25, 0.25], 1, (1, 4)],
[[0.3, 0.6, 0.05, 0.05], 2, (1, 4)],
[[0.3, 0.6, 0.05, 0.05], 10, (1, 4)],
[[0.25, 0.25, 0.25, 0.25], 1, (2, 4)],
[[0.3, 0.6, 0.05, 0.05], 2, (3, 4)],
[[[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]], [1, 10], (2, 4)],
],
)
def test_dirichlet_multinomial_mode(self, a, n, shape):
a = np.asarray(a)
with Model() as model:
m = DirichletMultinomial("m", n=n, a=a, shape=shape)
assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)

def test_dirichlet_multinomial_vec(self):
vals = np.array([[2, 4, 4], [3, 3, 4]])
a = np.array([0.2, 0.3, 0.5])
n = 10

with Model() as model_single:
DirichletMultinomial("m", n=n, a=a, shape=len(a))

with Model() as model_many:
DirichletMultinomial("m", n=n, a=a, shape=vals.shape)

assert_almost_equal(
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
decimal=4,
)

assert_almost_equal(
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
model_many.free_RVs[0].logp_elemwise({"m": vals}).squeeze(),
decimal=4,
)

assert_almost_equal(
sum([model_single.fastlogp({"m": val}) for val in vals]),
model_many.fastlogp({"m": vals}),
decimal=4,
)

def test_dirichlet_multinomial_vec_1d_n(self):
vals = np.array([[2, 4, 4], [4, 3, 4]])
a = np.array([0.2, 0.3, 0.5])
ns = np.array([10, 11])

with Model() as model:
DirichletMultinomial("m", n=ns, a=a, shape=vals.shape)

assert_almost_equal(
sum([dirichlet_multinomial_logpmf(val, n, a) for val, n in zip(vals, ns)]),
model.fastlogp({"m": vals}),
decimal=4,
)

def test_dirichlet_multinomial_vec_1d_n_2d_a(self):
vals = np.array([[2, 4, 4], [4, 3, 4]])
as_ = np.array([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]])
ns = np.array([10, 11])

with Model() as model:
DirichletMultinomial("m", n=ns, a=as_, shape=vals.shape)

assert_almost_equal(
sum([dirichlet_multinomial_logpmf(val, n, a) for val, n, a in zip(vals, ns, as_)]),
model.fastlogp({"m": vals}),
decimal=4,
)

def test_dirichlet_multinomial_vec_2d_a(self):
vals = np.array([[2, 4, 4], [3, 3, 4]])
as_ = np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]])
n = 10

with Model() as model:
DirichletMultinomial("m", n=n, a=as_, shape=vals.shape)

assert_almost_equal(
sum([dirichlet_multinomial_logpmf(val, n, a) for val, a in zip(vals, as_)]),
model.fastlogp({"m": vals}),
decimal=4,
)

def test_batch_dirichlet_multinomial(self):
# Test that DM can handle a 3d array for `a`
n = 10
# Create an almost deterministic DM by setting a to 0.001, everywehere
# except for one category / dimensions which is given the value fo 100
vals = np.zeros((4, 5, 3), dtype="int32")
a = np.zeros_like(vals, dtype=theano.config.floatX) + 0.001
inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
np.put_along_axis(vals, inds, n, axis=-1)
np.put_along_axis(a, inds, 100, axis=-1)

dist = DirichletMultinomial.dist(n=n, a=a, shape=vals.shape)

# TODO: Test logp is as expected (not as simple as the Multinomial case)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
# value = tt.tensor3(dtype="int32")
# value.tag.test_value = np.zeros_like(vals, dtype="int32")
# logp = tt.exp(dist.logp(value))
# f = theano.function(inputs=[value], outputs=logp)
# assert_almost_equal(
# f(vals),
# np.ones(vals.shape[:-1] + (1,)),
# decimal=select_by_precision(float64=6, float32=3),
# )

# Samples should be equal given the almost deterministic DM
sample = dist.random(size=2)
assert_allclose(sample, np.stack([vals, vals], axis=0))

def test_categorical_bounds(self):
with Model():
x = Categorical("x", p=np.array([0.2, 0.3, 0.5]))
Expand Down
50 changes: 50 additions & 0 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,56 @@ def ref_rand(size, a):
ref_rand=ref_rand,
)

def test_dirichlet_multinomial(self):
def ref_rand(size, a, n):
k = a.shape[-1]
out = np.empty((size, k), dtype=int)
for i in range(size):
p = nr.dirichlet(a)
x = nr.multinomial(n=n, pvals=p)
out[i, :] = x
return out
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

for n in [2, 3]:
pymc3_random_discrete(
pm.DirichletMultinomial,
{"a": Vector(Rplus, n), "n": Nat},
valuedomain=Vector(Nat, n),
size=1000,
ref_rand=ref_rand,
)

@pytest.mark.parametrize(
"a, shape, n",
[
[[0.25, 0.25, 0.25, 0.25], 4, 2],
[[0.25, 0.25, 0.25, 0.25], (1, 4), 3],
[[0.25, 0.25, 0.25, 0.25], (10, 4), [2] * 10],
[[0.25, 0.25, 0.25, 0.25], (10, 1, 4), 5],
[[[0.25, 0.25, 0.25, 0.25]], (2, 4), [7, 11]],
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), 13],
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (1, 2, 4), [23, 29]],
[
[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]],
(10, 2, 4),
[31, 37],
],
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), [17, 19]],
],
)
def test_dirichlet_multinomial_shapes(self, a, shape, n):
a = np.asarray(a)
with pm.Model() as model:
m = pm.DirichletMultinomial("m", n=n, a=a, shape=shape)
samp0 = m.random()
samp1 = m.random(size=1)
samp2 = m.random(size=2)

shape_ = to_tuple(shape)
assert to_tuple(samp0.shape) == shape_
assert to_tuple(samp1.shape) == (1, *shape_)
assert to_tuple(samp2.shape) == (2, *shape_)

def test_multinomial(self):
def ref_rand(size, p, n):
return nr.multinomial(pvals=p, n=n, size=size)
Expand Down