Allow for batched alpha in StickBreakingWeights #6042

Merged
merged 7 commits into from
Aug 31, 2022

Conversation

purna135
Member

@purna135 purna135 commented Aug 9, 2022

What is this PR about?
Addressing #5383
This enables the alpha parameter of StickBreakingWeights to accept batched data (>2D), makes infer_shape work with batched data, and fixes rng_fn by broadcasting alpha to K.
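
For readers unfamiliar with the idea, here is a minimal NumPy sketch of what drawing stick-breaking weights with a batched alpha can look like. It is not the PyMC implementation: the function name `stick_breaking_draw` and the handling of `size` are assumptions made for illustration only. The key step is giving alpha a trailing axis so it broadcasts against the K beta draws.

```python
import numpy as np


def stick_breaking_draw(rng, alpha, K, size=None):
    """Illustrative sketch only (hypothetical helper, not PyMC's rng_fn)."""
    alpha = np.asarray(alpha, dtype=float)
    # Assumption: when size is None, the batch shape is alpha's own shape.
    batch_shape = alpha.shape if size is None else tuple(size)
    # Add a trailing axis so every alpha broadcasts against K beta draws.
    betas = rng.beta(1.0, alpha[..., None], size=batch_shape + (K,))
    # Fraction of the stick still unbroken after each of the K breaks.
    remaining = np.cumprod(1.0 - betas, axis=-1)
    # Stick left before each break: 1 for the first, then the running product.
    before = np.concatenate(
        [np.ones(batch_shape + (1,)), remaining[..., :-1]], axis=-1
    )
    # K weights from the breaks plus one final weight for the leftover stick.
    return np.concatenate([betas * before, remaining[..., -1:]], axis=-1)


rng = np.random.default_rng(0)
alpha = np.arange(1, 7, dtype="float64").reshape(2, 3)
draws = stick_breaking_draw(rng, alpha, K=5)
print(draws.shape)  # batch shape of alpha plus a trailing axis of length K + 1
```

Each draw has K + 1 entries along the last axis and sums to one, whatever the batch shape of alpha.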

Checklist

Major / Breaking Changes

  • ...

Bugfixes / New features

  • StickBreakingWeights now supports batched alpha parameters

Docs / Maintenance

  • ...

codecov bot commented Aug 9, 2022

Codecov Report

Merging #6042 (2a5df64) into main (ad16bf4) will decrease coverage by 1.82%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #6042      +/-   ##
==========================================
- Coverage   89.27%   87.44%   -1.83%     
==========================================
  Files          72       72              
  Lines       12890    12946      +56     
==========================================
- Hits        11507    11321     -186     
- Misses       1383     1625     +242     
Impacted Files Coverage Δ
pymc/distributions/multivariate.py 91.25% <100.00%> (-0.75%) ⬇️
pymc/distributions/timeseries.py 43.36% <0.00%> (-35.28%) ⬇️
pymc/model_graph.py 65.66% <0.00%> (-29.80%) ⬇️
pymc/model.py 76.14% <0.00%> (-12.06%) ⬇️
pymc/step_methods/hmc/quadpotential.py 73.76% <0.00%> (-6.94%) ⬇️
pymc/util.py 75.29% <0.00%> (-2.36%) ⬇️
pymc/distributions/discrete.py 97.65% <0.00%> (-1.57%) ⬇️
pymc/step_methods/hmc/base_hmc.py 89.76% <0.00%> (-0.79%) ⬇️
pymc/gp/gp.py 92.73% <0.00%> (-0.45%) ⬇️
... and 9 more

Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com>
@purna135 purna135 force-pushed the generalize_StickBreaking branch from ab3d6e2 to 7fde39a Compare August 9, 2022 20:22
Member

@ricardoV94 ricardoV94 left a comment

Good start. I found some issues with the interpretation of size in conjunction with batched alphas.

@ricardoV94 ricardoV94 changed the title allow alpha to take batched data for StickBreakingWeights Allow for batched alpha in StickBreakingWeights Aug 10, 2022
Member

@larryshamalama larryshamalama left a comment

Great work @purna135! :)

Member

larryshamalama commented Aug 16, 2022

Would the moment and logp methods in the distribution class need appropriate broadcasting for vector-valued alphas? I am just reading over this PR and might have missed previous discussions

@ricardoV94
Member

Would the moment and logp methods in the distribution class need appropriate broadcasting for vector-valued alphas? I am just reading over this PR and might have missed previous discussions

We have tests for batched alpha, but not moment (we should)

Comment on lines 2296 to 2300
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
with Model() as model:
def test_stickbreakingweights_logp(self, alpha, K, _compile_stickbreakingweights_logpdf):
Member Author

Is it acceptable to combine test_stickbreakingweights_logp and test_stickbreakingweights_vectorized in a single test?

Member

If test_stickbreakingweights_vectorized fails, it would point to shapes not being handled properly to form batches. So let's keep them separate for better test isolation.

Member Author

I have now separated out the test for batched alpha.

@purna135
Member Author

We have tests for batched alpha, but not moment (we should)

I need some assistance calculating expected in test_stickbreakingweights_moment.

Member

ricardoV94 commented Aug 25, 2022

Left a comment above about the fixture. Also, don't forget @larryshamalama's remark above that we should test that the moment function works for batched alpha as well. The existing tests are here:

def test_stickbreakingweights_moment(alpha, K, size, expected):
    with Model() as model:
        StickBreakingWeights("x", alpha=alpha, K=K, size=size)
    assert_moment_is_expected(model, expected)

It should be enough to test with a vector of two alphas: one that is already tested as a single alpha (reusing the same K), and the other an extreme value like alpha=1 or alpha=0 (if that's valid), which might have a very simple moment.

@purna135
Member Author

Yes, I got the test for moment but I am not sure how the expected is calculated here.
Is there any equation to determine the expected?

@pytest.mark.parametrize(
    "alpha, K, size, expected",
    [
        (3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
        (5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),

@ricardoV94
Member

Yes, I got the test for moment but I am not sure how the expected is calculated here. Is there any equation to determine the expected?

@pytest.mark.parametrize(
    "alpha, K, size, expected",
    [
        (3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
        (5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),

You can check what the moment is for two distinct single alphas, and it should be the same for a batched alpha that has those two values.
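
To make that concrete, here is a rough NumPy sketch. The closed form is only a generalization of the two parametrized cases quoted above (alpha=3 gives ratio 3/4 and factor 1/4, alpha=5 gives 5/6 and 1/6), and the helper name `expected_moment` is hypothetical, not something in the PyMC test suite. The point is that the expected value for a batched alpha is just the single-alpha expected values stacked along a new leading axis.

```python
import numpy as np


def expected_moment(alpha, K):
    """Hypothetical helper: expected value generalized from the quoted cases."""
    # Trailing axis lets each alpha broadcast against the K + 1 entries.
    alpha = np.asarray(alpha, dtype=float)[..., None]
    ratio = alpha / (1.0 + alpha)
    # First K entries: ratio**i / (1 + alpha), for i = 0, ..., K - 1.
    head = ratio ** np.arange(K) / (1.0 + alpha)
    # Last entry: whatever is left of the stick, ratio**K.
    tail = ratio**K
    return np.concatenate([head, tail], axis=-1)


# Single-alpha value copied from the existing parametrization (alpha=3, K=11).
single_3 = np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)

# A batched alpha holding 3 and 5 (same K) should give the stacked rows.
batched = expected_moment(np.array([3, 5]), K=11)
print(np.allclose(batched[0], single_3))
```

So an `expected` entry for a batched case can be built by stacking two expressions of the kind already used for single alphas, as long as both use the same K.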

@purna135
Member Author

Ok got it now, do I need to create a separate test for batched alpha as we did in TestStickBreakingWeights_1D_alpha ?

@ricardoV94
Member

Ok got it now, do I need to create a separate test for batched alpha as we did in TestStickBreakingWeights_1D_alpha ?

Nope, you can just add it as an extra case in the existing tests. The moment is less sensitive than logp, so we can keep them bundled together.

@ricardoV94
Member

Looks complete to me. The failing test is unrelated. Just asked if @larryshamalama could leave a review as well.

Member

@larryshamalama larryshamalama left a comment

Looks great @purna135! Just one question for myself 😅

(np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
],
)
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
Member

Is stickbreakingweights_logpdf passed in as an argument here because a fixture with the same name is defined through the fixture decorator?

Member

Yes, that's how fixtures work
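
For anyone else reading along, a small generic sketch of the mechanism. All names here (`make_scaled_sum`, `TestScaledSum`, `scaled_sum`) are hypothetical and unrelated to the PyMC test suite. pytest matches a test's parameter names against the fixtures it knows about and passes in the fixture's return value; with `scope="class"` the value is built once and shared by every test in the class.

```python
import pytest


def make_scaled_sum(scale):
    """Hypothetical stand-in for something expensive, e.g. a compiled function."""

    def scaled_sum(values):
        return scale * sum(values)

    return scaled_sum


class TestScaledSum:
    @pytest.fixture(scope="class")
    def scaled_sum(self):
        # Built once per class; any test in the class that lists
        # `scaled_sum` as a parameter receives this return value.
        return make_scaled_sum(2.0)

    @pytest.mark.parametrize("values, expected", [([1, 2], 6.0), ([0.5], 1.0)])
    def test_value(self, values, expected, scaled_sum):
        # `values` and `expected` come from parametrize,
        # `scaled_sum` comes from the fixture above.
        assert scaled_sum(values) == expected
```

Parametrized arguments and fixture arguments can be mixed freely in one signature, which is what the test signature above relies on.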

@ricardoV94 ricardoV94 merged commit 0b191ad into pymc-devs:main Aug 31, 2022