Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix measurable stack and join with interdependent inputs #6342

Merged
merged 2 commits into from
Dec 7, 2022

Conversation

ricardoV94
Copy link
Member

This fixes measurable stacks and joins that have interdependent variables.

The main loop of factorized_joint_logprob replaces all the measurable variables already seen by their values before calling the logp of a new measurable node, but it does not replace said node by its value variable, as it assumes the returned logprob expression never depends on the original measurable (stochastic) node.

However this can happen in Stack and Join whenever the base measurable RVs depend on one another. The new tests show this clearly. The solution proposed in this PR is to do the additional replacements needed in the respective logp functions.

Major / Breaking Changes

  • ...

Bugfixes / New features

  • Fix logprob of stack and join when base variables are interdependent

Docs / Maintenance

  • ...

@codecov
Copy link

codecov bot commented Nov 28, 2022

Codecov Report

Merging #6342 (ddc6b65) into main (ddc6b65) will not change coverage.
The diff coverage is n/a.

❗ Current head ddc6b65 differs from pull request most recent head 51fbdc9. Consider uploading reports for the commit 51fbdc9 to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #6342   +/-   ##
=======================================
  Coverage   94.71%   94.71%           
=======================================
  Files         132      132           
  Lines       26695    26695           
=======================================
  Hits        25284    25284           
  Misses       1411     1411           

Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

damn, I forgot to submit my review comments

logps = replace_rvs_by_values(
logps,
rvs_to_values=base_rvs_to_values,
rvs_to_transforms={rv: None for rv in base_rvs},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know that there are no transforms?

Copy link
Member Author

@ricardoV94 ricardoV94 Dec 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The transforms are applied before the RV-specific logp function is called. The value received is already the one in the constrained space.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@_logprob.register(new_op_type)
def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
"""Compute the log-likelihood graph for a `TransformedRV`.
We assume that the value variable was back-transformed to be on the natural
support of the respective random variable.
"""
(value,) = values
logprob = _logprob(rv_op, values, *inputs, **kwargs)
if use_jacobian:
assert isinstance(value.owner.op, TransformedVariable)
original_forward_value = value.owner.inputs[1]
jacobian = op.transform.log_jac_det(original_forward_value, *inputs)
logprob += jacobian
return logprob

@ricardoV94
Copy link
Member Author

Thanks for the review @michaelosthege. Please don't merge this, as I will go through it in the hackathon

@ricardoV94 ricardoV94 force-pushed the fix_interdependent_join_stack branch 3 times, most recently from 5a9c755 to 6c225de Compare December 5, 2022 14:31
@ricardoV94 ricardoV94 force-pushed the fix_interdependent_join_stack branch from 6c225de to 51fbdc9 Compare December 5, 2022 15:10
@ricardoV94
Copy link
Member Author

CC @tomicapretto

@ricardoV94 ricardoV94 merged commit 5b6c804 into pymc-devs:main Dec 7, 2022
@ricardoV94 ricardoV94 deleted the fix_interdependent_join_stack branch June 6, 2023 02:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants