Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ZeroSumNormal distribution #6121

Merged
merged 53 commits into from
Oct 7, 2022
Merged

Add ZeroSumNormal distribution #6121

merged 53 commits into from
Oct 7, 2022

Conversation

AlexAndorra
Copy link
Contributor

@AlexAndorra AlexAndorra commented Sep 12, 2022

This PR introduces the world famous ZeroSumNormal distribution, i.e a Normal distribution where one or several axes are constrained to sum to zero. By default, the last axis is constrained to sum to zero.

The zerosum_axes are always in the rightmost position, i.e zerosum_axes=2 means the two rightmost axes will be constrained to sum to zero (see examples below for more details).

⚠️ sigma has to be a scalar, to ensure the zero-sum constraint. The ability to specifiy a vector of sigma may be added in future versions.

Checklist


Examples:

COORDS = {
    "regions": ["a", "b", "c"],
     "answers": ["yes", "no", "whatever", "don't understand question"],
}
with pm.Model(coords=COORDS) as m:
    # the zero sum axis will be 'answers'
    v = pm.ZeroSumNormal("v", dims=("regions", "answers"))

with pm.Model(coords=COORDS) as m:
    # the zero sum axes will be 'answers' and 'regions'
    v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)

with pm.Model(coords=COORDS) as m:
    # the zero sum axes will be the last two
    v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)

Major / Breaking Changes

  • ...

Bugfixes / New features

  • Implemented ZeroSumNormal distribution

Docs / Maintenance

  • ...

@AlexAndorra AlexAndorra added enhancements major Include in major changes release notes section labels Sep 12, 2022
@AlexAndorra AlexAndorra self-assigned this Sep 12, 2022

return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs)

# TODO: This is if we want ZeroSum constraint on other dists than Normal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lucianopaz @aseyboldt does any of the math require the summed distribution to be a Normal or could it be something else (e.g., StudentT)?

@codecov
Copy link

codecov bot commented Sep 12, 2022

Codecov Report

Merging #6121 (5954e65) into main (e419d53) will increase coverage by 0.35%.
The diff coverage is 97.34%.

❗ Current head 5954e65 differs from pull request most recent head 3e72922. Consider uploading reports for the commit 3e72922 to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6121      +/-   ##
==========================================
+ Coverage   93.05%   93.40%   +0.35%     
==========================================
  Files          91      100       +9     
  Lines       20804    22138    +1334     
==========================================
+ Hits        19360    20679    +1319     
- Misses       1444     1459      +15     
Impacted Files Coverage Δ
pymc/distributions/multivariate.py 92.24% <91.54%> (-0.07%) ⬇️
pymc/distributions/shape_utils.py 98.66% <97.14%> (-0.29%) ⬇️
pymc/tests/distributions/test_multivariate.py 99.44% <98.97%> (-0.06%) ⬇️
pymc/distributions/timeseries.py 82.56% <100.00%> (-0.96%) ⬇️
pymc/distributions/transforms.py 100.00% <100.00%> (ø)
pymc/tests/distributions/test_shape_utils.py 99.76% <100.00%> (+0.03%) ⬆️
pymc/tests/distributions/test_timeseries.py 95.45% <100.00%> (-0.34%) ⬇️
pymc/parallel_sampling.py 85.80% <0.00%> (-1.00%) ⬇️
pymc/tests/distributions/test_truncated.py 99.48% <0.00%> (-0.52%) ⬇️
pymc/data.py 80.08% <0.00%> (ø)
... and 13 more

@AlexAndorra
Copy link
Contributor Author

Thanks for the first review @ricardoV94 !
I added some tests. How do they look like?

pymc/tests/distributions/test_continuous.py Outdated Show resolved Hide resolved
pymc/tests/distributions/test_continuous.py Outdated Show resolved Hide resolved
pymc/tests/distributions/test_continuous.py Outdated Show resolved Hide resolved
pymc/tests/distributions/test_continuous.py Outdated Show resolved Hide resolved
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would make more sense in multivariate even though we are treating it as a scalar as a hack for the time being.

pymc/distributions/continuous.py Outdated Show resolved Hide resolved
pymc/distributions/continuous.py Outdated Show resolved Hide resolved
@AlexAndorra
Copy link
Contributor Author

This would make more sense in multivariate even though we are treating it as a scalar as a hack for the time being.

Good point. I moved it to multivariate.py.

I also added all the tests mentioned and pushed everything. They pass locally. Let's see if they pass here 🤞

Copy link
Contributor

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexAndorra i left a bunch of comments all around the place. I think that I’m missing the part where the RV asserts that the zero sum axis are all negative and non repeating. I also think that you need to add some tests for the logp of the distribution and for the variance of the draws

pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/transforms.py Outdated Show resolved Hide resolved


def extend_axis(array, axis):
n = array.shape[axis] + 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could maybe add a comment here saying that this is using a householder reflection plus a projection operator to move forward from the constrained space onto the zero sum manifold. I’ll look up our notes and write something here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you find your notes @lucianopaz ?

pymc/tests/distributions/test_multivariate.py Outdated Show resolved Hide resolved
pymc/tests/distributions/test_multivariate.py Outdated Show resolved Hide resolved
pymc/tests/distributions/test_multivariate.py Show resolved Hide resolved
pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/tests/distributions/test_multivariate.py Outdated Show resolved Hide resolved
pymc/tests/distributions/test_multivariate.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 changed the title Add ZeroSumNormal distribution 🔥 Add ZeroSumNormal distribution Oct 5, 2022
@AlexAndorra AlexAndorra requested a review from ricardoV94 October 5, 2022 17:40
pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/tests/distributions/test_multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/__init__.py Show resolved Hide resolved
pymc/distributions/transforms.py Show resolved Hide resolved
pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/transforms.py Outdated Show resolved Hide resolved
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
@AlexAndorra AlexAndorra requested a review from ricardoV94 October 7, 2022 14:08
@ricardoV94 ricardoV94 removed the major Include in major changes release notes section label Oct 7, 2022
@ricardoV94 ricardoV94 merged commit 9aeb6b5 into main Oct 7, 2022
@ricardoV94 ricardoV94 deleted the add-zerosumnormal branch November 2, 2022 12:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants