
Implement logp derivation for division, subtraction and negation #6371

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Dec 5, 2022

This implements measurable rewrites for the following type of graphs:

import pymc as pm

x = pm.Gamma.dist(2, 1)
y = 1 / x
pm.logp(y, 0.5).eval()

As well as:

y = 5 / x
y = -x
y = 5 - x

It works by canonicalizing such operations into a form already understood by find_measurable_transforms (only a new condition for Reciprocals was added):

5 / x  ->  5 * reciprocal(x)
-x     ->  x * (-1)
5 - x  ->  5 + (x * (-1))

These canonicalizations are only applied when MeasurableVariables are involved, to minimize disruption of the underlying graph.
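The change-of-variables rule these rewrites rely on can be sanity-checked independently of PyMC. The sketch below uses scipy (an assumption of this note, not part of the PR) to confirm that the derived logp for the headline example, y = 1/x with x ~ Gamma(2, 1), matches the closed-form InverseGamma(2, 1) density:

```python
import numpy as np
from scipy import stats

# x ~ Gamma(alpha=2, beta=1); y = 1/x is canonicalized to reciprocal(x).
# Change of variables: log p_y(v) = log p_x(1/v) + log|d(1/v)/dv|
#                                 = log p_x(1/v) - 2*log(v)
v = 0.5
via_transform = stats.gamma(a=2, scale=1.0).logpdf(1 / v) - 2 * np.log(v)

# The reciprocal of a Gamma(2, 1) variable is exactly InverseGamma(2, 1)
direct = stats.invgamma(a=2, scale=1.0).logpdf(v)

assert np.isclose(via_transform, direct)
```

The negation and subtraction cases work the same way with |Jacobian| = 1, so e.g. the logp of y = 5 - x at v reduces to the logp of x at 5 - v.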

Bugfixes / New features

  • Automatic logprob derivation for division, subtraction, and negation operations

@codecov

codecov bot commented Dec 5, 2022

Codecov Report

Merging #6371 (a819fcf) into main (f96594b) will decrease coverage by 1.77%.
The diff coverage is 97.50%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #6371      +/-   ##
==========================================
- Coverage   94.79%   93.01%   -1.78%     
==========================================
  Files         148      148              
  Lines       27488    27549      +61     
==========================================
- Hits        26058    25626     -432     
- Misses       1430     1923     +493     
Impacted Files Coverage Δ
pymc/logprob/transforms.py 97.50% <96.00%> (-0.25%) ⬇️
pymc/tests/logprob/test_transforms.py 99.45% <100.00%> (+0.04%) ⬆️
pymc/tests/step_methods/test_slicer.py 0.00% <0.00%> (-100.00%) ⬇️
pymc/tests/step_methods/hmc/test_nuts.py 0.00% <0.00%> (-100.00%) ⬇️
pymc/tests/step_methods/test_compound.py 0.00% <0.00%> (-100.00%) ⬇️
pymc/tests/step_methods/test_metropolis.py 0.00% <0.00%> (-100.00%) ⬇️
pymc/step_methods/metropolis.py 57.69% <0.00%> (-26.07%) ⬇️
pymc/backends/base.py 83.84% <0.00%> (-1.54%) ⬇️
pymc/sampling/parallel.py 87.36% <0.00%> (-0.09%) ⬇️
pymc/variational/inference.py 84.97% <0.00%> (ø)
... and 2 more

@ricardoV94 ricardoV94 force-pushed the logprob_implement_division_subtraction_logp branch from 40d889e to 8fd177f Compare December 5, 2022 14:10
pymc/logprob/transforms.py (review comment, outdated, resolved)
@ricardoV94
Member Author

CC @tomicapretto

@ricardoV94 ricardoV94 force-pushed the logprob_implement_division_subtraction_logp branch from 8fd177f to 6959886 Compare December 7, 2022 15:20
@Armavica
Member

Armavica commented Dec 8, 2022

Sorry for the very naive question, but I think I am still missing some pieces of the puzzle. I tried to understand what new models this will allow us to write, but it looks like

with pm.Model():
    x = pm.Normal("x")
    y = pm.Normal("y", 1/x, 1)

was already possible before. Why is the 1/x working here? Am I completely off track?

@ricardoV94
Member Author

ricardoV94 commented Dec 8, 2022

Sorry for the very naive question, but I think I am still missing some pieces of the puzzle. I tried to understand what new models this will allow us to write, but it looks like

with pm.Model():
    x = pm.Normal("x")
    y = pm.Normal("y", 1/x, 1)

was already possible before. Why is the 1/x working here? Am I completely off track?

No, you're not completely off track. The difference is between a parameter transformation (what your snippet does) and a variable transformation (what my snippet does).

The difference between the two is easier to see if you think about the logp you are requesting:

some_value = 1.0

x = pm.Normal.dist()
y = pm.Normal.dist(1/x)
pm.logp(y, some_value)  # Fine before

z = 1 / x
pm.logp(z, some_value)  # Failed before, now it works!
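To make the parameter-vs-variable distinction concrete, here is a scipy analogue of the two quantities (scipy and the chosen numeric values are this note's assumptions, not part of the original comment):

```python
import numpy as np
from scipy import stats

some_value = 1.0
x_value = 2.0  # a fixed conditioning value for x, chosen for illustration

# Parameter transformation: 1/x merely parametrizes y's mean, so
# logp(y = some_value | x = x_value) involves no Jacobian term.
logp_param = stats.norm(loc=1 / x_value, scale=1.0).logpdf(some_value)

# Variable transformation: z = 1/x is itself the random variable, so its
# logp picks up the Jacobian of the inverse map, -2*log|z|.
logp_var = stats.norm(0, 1).logpdf(1 / some_value) - 2 * np.log(abs(some_value))

# Different questions, different answers:
# logp_param ≈ -1.044, logp_var ≈ -1.419
```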

This is more relevant for observed variables, so that you can condition on them directly. After this and #6361 you can do this:

def inverse_normal(mu, sigma, size):
    return 1 / pm.Normal.dist(mu, sigma, size=size)

with pm.Model() as m:
    x = pm.Normal("x")
    # data is assumed to be defined elsewhere
    y = pm.CustomDist("y", x, 1, random=inverse_normal, observed=data)

PyMC will now be able to derive the logp of y. Another thing that might be confusing is that this is also not the same as pm.Normal("y", x, 1, observed=1/data).


The way this functionality trickles down to users is that they don't need us to implement InverseNormal, InverseGamma, Inverse...; they can simply use CustomDist. Alternatively, we can offer them a helper pm.InverseDist without having to implement each logp manually. There's more context in #4530.
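To illustrate why no per-distribution logp is needed, here is a hedged sketch of what such a generic inverse helper boils down to, written in scipy terms; `inverse_logpdf` is a hypothetical name for illustration, not a PyMC API:

```python
import numpy as np
from scipy import stats

def inverse_logpdf(base, y):
    """Log-density of 1/x when x follows `base` (hypothetical helper).

    log p_y(y) = log p_x(1/y) + log|d(1/y)/dy| = log p_x(1/y) - 2*log|y|
    """
    return base.logpdf(1.0 / y) - 2.0 * np.log(np.abs(y))

# Works uniformly for any base distribution, e.g. the "InverseNormal" above:
val = inverse_logpdf(stats.norm(loc=1.0, scale=1.0), 0.5)
```

The derivation in this PR automates exactly this kind of mechanical step inside the logprob graph, which is why a single rewrite covers every base distribution at once.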

@ricardoV94 ricardoV94 changed the title Implement division and subtraction logp derivation Implement logp derivation for division, subtraction and negation Dec 9, 2022
@ricardoV94 ricardoV94 force-pushed the logprob_implement_division_subtraction_logp branch from 6959886 to 48eb532 Compare December 14, 2022 12:52
Adds rewrite that converts divisions with measurable variables to product with reciprocals, making the reciprocal measurable transform more widely applicable.
@ricardoV94 ricardoV94 force-pushed the logprob_implement_division_subtraction_logp branch from 48eb532 to a819fcf Compare December 14, 2022 14:25
@ricardoV94 ricardoV94 requested a review from twiecki December 14, 2022 15:24
Member

@michaelosthege michaelosthege left a comment


Elementary school math. I feel confident about this 👍

@ricardoV94 ricardoV94 merged commit 4c54b7d into pymc-devs:main Dec 14, 2022
@ricardoV94 ricardoV94 deleted the logprob_implement_division_subtraction_logp branch June 6, 2023 03:02