-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revert size
kwarg behaviour and make it work with Ellipsis
#4667
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1986,29 +1986,33 @@ def test_multinomial_mode(self, p, n): | |
@pytest.mark.parametrize( | ||
"p, size, n", | ||
[ | ||
[[0.25, 0.25, 0.25, 0.25], (4,), 2], | ||
[[0.25, 0.25, 0.25, 0.25], (1, 4), 3], | ||
[[0.25, 0.25, 0.25, 0.25], (7,), 2], | ||
[[0.25, 0.25, 0.25, 0.25], (1, 7), 3], | ||
# 3: expect to fail | ||
# [[.25, .25, .25, .25], (10, 4)], | ||
[[0.25, 0.25, 0.25, 0.25], (10, 1, 4), 5], | ||
# [[.25, .25, .25, .25], (10, 7)], | ||
[[0.25, 0.25, 0.25, 0.25], (10, 1, 7), 5], | ||
# 5: expect to fail | ||
# [[[.25, .25, .25, .25]], (2, 4), [7, 11]], | ||
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), 13], | ||
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (1, 2, 4), [23, 29]], | ||
# [[[.25, .25, .25, .25]], (2, 5), [7, 11]], | ||
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (5, 2), 13], | ||
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (5, 7, 2), [23, 29]], | ||
[ | ||
[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], | ||
(10, 2, 4), | ||
(10, 8, 2), | ||
[31, 37], | ||
], | ||
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), [17, 19]], | ||
[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (3, 2), [17, 19]], | ||
], | ||
) | ||
def test_multinomial_random(self, p, size, n): | ||
p = np.asarray(p) | ||
with Model() as model: | ||
m = Multinomial("m", n=n, p=p, size=size) | ||
|
||
assert m.eval().shape == size + p.shape | ||
# The support has length 4 in all test parametrizations! | ||
# Broadcasting of the `p` parameter does not affect the ndim_supp | ||
# of the Op, hence the broadcasted p must be included in `size`. | ||
support_shape = (p.shape[-1],) | ||
assert support_shape == (4,) | ||
assert m.eval().shape == size + support_shape | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The parametrization The broadcasting of the |
||
|
||
@pytest.mark.skip(reason="Moment calculations have not been refactored yet") | ||
def test_multinomial_mode_with_shape(self): | ||
|
@@ -2109,7 +2113,7 @@ def test_batch_multinomial(self): | |
decimal=select_by_precision(float64=6, float32=3), | ||
) | ||
|
||
dist = Multinomial.dist(n=n, p=p, size=2) | ||
dist = Multinomial.dist(n=n, p=p, size=(2, ...)) | ||
sample = dist.eval() | ||
assert_allclose(sample, np.stack([vals, vals], axis=0)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Aesara
MultinomialRV
can do broadcasting since v2.0.4.