-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix measurable stack and join with interdependent inputs #6342
Fix measurable stack and join with interdependent inputs #6342
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
damn, I forgot to submit my review comments
pymc/logprob/tensor.py
Outdated
logps = replace_rvs_by_values( | ||
logps, | ||
rvs_to_values=base_rvs_to_values, | ||
rvs_to_transforms={rv: None for rv in base_rvs}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do we know that there are no transforms?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The transforms are applied before the RV-specific logp function is called. The value received is already the one in the constrained space.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pymc/pymc/logprob/transforms.py
Lines 591 to 608 in ddc6b65
@_logprob.register(new_op_type) | |
def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs): | |
"""Compute the log-likelihood graph for a `TransformedRV`. | |
We assume that the value variable was back-transformed to be on the natural | |
support of the respective random variable. | |
""" | |
(value,) = values | |
logprob = _logprob(rv_op, values, *inputs, **kwargs) | |
if use_jacobian: | |
assert isinstance(value.owner.op, TransformedVariable) | |
original_forward_value = value.owner.inputs[1] | |
jacobian = op.transform.log_jac_det(original_forward_value, *inputs) | |
logprob += jacobian | |
return logprob |
Thanks for the review @michaelosthege. Please don't merge this, as I will go through it in the hackathon |
5a9c755
to
6c225de
Compare
6c225de
to
51fbdc9
Compare
This fixes measurable stacks and joins that have interdependent variables.
The main loop of
factorized_joint_logprob
replaces all the measurable variables already seen by their values before calling the logp of a new measurable node, but it does not replace said node by its value variable, as it assumes the returned logprob expression never depends on the original measurable (stochastic) node.However this can happen in Stack and Join whenever the base measurable RVs depend on one another. The new tests show this clearly. The solution proposed in this PR is to do the additional replacements needed in the respective logp functions.
Major / Breaking Changes
Bugfixes / New features
Docs / Maintenance