Implement censored log-probabilities via the Clip Op #22
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##             main      #22      +/-  ##
==========================================
- Coverage   94.92%   94.84%   -0.08%
==========================================
  Files           9        8       -1
  Lines        1260     1106     -154
  Branches      164      133      -31
==========================================
- Hits         1196     1049     -147
+ Misses         31       27       -4
+ Partials       33       30       -3
```

Continue to review full report at Codecov.
Force-pushed from c5e7da1 to 46fac13
To make broadcasting work, you should be able to use something like …
That's a very interesting idea! Can it be combined with other indices (e.g. …)? Regardless, don't make this PR conditional on such extensions. Let's get …
Tip: if you assign names to your test … Also, don't forget about test values! They will cause errors to arise during graph construction (i.e. where the symbolic objects themselves are defined). That combined with …
Still couldn't fix my …

```python
lb_rv = at.random.uniform(0, 1, name="lb_rv")
x_rv = at.random.normal(0, 2, name="x_rv")
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
cens_x_rv.name = "cens_x_rv"

lb = lb_rv.type()
lb.name = "lb"
cens_x = cens_x_rv.type()
cens_x.name = "cens_x"

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)
```

These are the fgraph printouts before and after the optimization phase in …
That `lower_bound` …

```python
lower_bound
# uniform_rv.out
lower_bound in rv_map_feature.rv_values
# False
lower_bound.owner.tag
# scratchpad{'imported_by': ['local_dimshuffle_rv_lift']}
```

Also, that opt seems not to propagate the variable name:

```python
lower_bound.name
# None
```
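For background, the log-probability that a clip-based censoring rewrite targets can be sketched in plain NumPy/SciPy. This is a hypothetical standalone helper (not aeppl code, and the name `censored_normal_logp` is made up): probability mass piles up at the bounds, so the density at a bound becomes the CDF (lower) or survival function (upper) evaluated there.

```python
import numpy as np
from scipy import stats


def censored_normal_logp(value, mu, sigma, lower, upper):
    """Log-density of a normal RV censored (clipped) to [lower, upper].

    Hypothetical reference implementation for illustration only.
    """
    value = np.asarray(value, dtype=float)
    dist = stats.norm(mu, sigma)
    logp = dist.logpdf(value)  # interior points keep the usual density
    # at the bounds, the point mass equals P(X <= lower) / P(X >= upper)
    logp = np.where(value == lower, dist.logcdf(lower), logp)
    logp = np.where(value == upper, dist.logsf(upper), logp)
    # values outside [lower, upper] are impossible after clipping
    logp = np.where((value < lower) | (value > upper), -np.inf, logp)
    return logp
```

For example, with `mu=0, sigma=2` and bounds `[-1, 1]`, the logp at `1.0` equals `stats.norm(0, 2).logsf(1.0)`, while interior values match the plain `logpdf`.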
If I set the flag to "warn", even a simple uniform logp raises a lot of "Cannot compute test value..." warnings for every node in the logp. Is this something we need to address?

```python
@aesara.config.change_flags(compute_test_value='warn')
def test_compute_test_value():
    x_rv = at.random.uniform(-1, 1)
    x = x_rv.type()
    logp = joint_logprob(x_rv, {x_rv: x})
```

=============================== warnings summary ===============================
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{ge,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{-1}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{le,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{1}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{ge,no_inplace}.0) of Op Elemwise{and_,no_inplace}(Elemwise{ge,no_inplace}.0, Elemwise{le,no_inplace}.0) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{second,no_inplace}(<TensorType(float64, scalar)>, Elemwise{neg,no_inplace}.0) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{and_,no_inplace}.0) of Op Elemwise{switch,no_inplace}(Elemwise{and_,no_inplace}.0, Elemwise{second,no_inplace}.0, TensorConstant{-inf}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{second,no_inplace}(<TensorType(float64, scalar)>, Elemwise{neg,no_inplace}.0) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{le,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{1}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{ge,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{-1}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{ge,no_inplace}.0) of Op Elemwise{and_,no_inplace}(Elemwise{ge,no_inplace}.0, Elemwise{le,no_inplace}.0) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{and_,no_inplace}.0) of Op Elemwise{switch,no_inplace}(Elemwise{and_,no_inplace}.0, Elemwise{second,no_inplace}.0, TensorConstant{-inf}) missing default value
compute_test_value(node)
-- Docs: https://docs.pytest.org/en/stable/warnings.html
======================== 1 passed, 10 warnings in 0.97s ========================
Process finished with exit code 0
PASSED [100%]

Force-pushed from 9dd9052 to cfc3fbe
It looks like you need to set a test value for …
I added tests for the logcdf methods and checked whether the new opts work with … The only thing missing is the broadcasting / …
Force-pushed from 3600f40 to 3300159
Is the broadcasting still the blocking issue/change here?
No, I was trying an alternative that did not involve subclassing from RandomVariable. I'll try to get this back on track soon.
Force-pushed from 38f6d75 to d3fe4c0
Aside from the use of `tag.ignore_logprob`, this looks great.
Let's find a way to avoid using that feature, especially since nothing should be depending on it now and it's slated to be removed entirely.
Otherwise, if you want, submit the `logcdf` additions as a separate PR and we can push those through sooner.
Force-pushed from d3fe4c0 to ccbfb22
aeppl/joint_logprob.py (Outdated)

```python
# Filter out missing terms of variables with ignore_logprob
value_rvs = {v: k for k, v in updated_rv_values.items()}
for missing in tuple(missing_value_terms):
    rv_of_missing = value_rvs.get(missing, None)
    if rv_of_missing and getattr(rv_of_missing.tag, "ignore_logprob", False):
        missing_value_terms.remove(missing)
```
If/when we abandon the `ignore_logprob` flag, this section can be removed. I had to add it for backwards compatibility with some tests in `test_joint_logprob`.
It looks like you added two tests in a different commit that explicitly require this functionality: `test_fail_multiple_censored_single_base` and `test_fail_base_and_censored_have_values`. If you remove those, nothing will depend on this functionality and the commit can be removed.
This needs to be done before merging, if only because the changes are unrelated to censoring.
When I said something could be removed later, I meant the specific `ignore_logprob` flag logic, not the raising of a `RuntimeError` when a variable is missing.
While developing the censored variables, it would often fail silently and just return a graph with the Aesara clips unchanged and/or fewer terms than requested. This seems like a good way to catch such failures.
Is there a reason why you don't want this type of check at the end of `factorized_joint_logprob`?
It's trivial to manually do the same check in those new tests I added, but this explicit check might be valuable enough to have as a default.
> When I said something could be removed later, I meant the specific `ignore_logprob` flag logic, not the raising of a `RuntimeError` when a variable is missing. While developing the censored variables, it would often fail silently and just return a graph with the Aesara clips unchanged and/or fewer terms than requested. This seems like a good way to catch such failures.
From a simple development and design perspective, it sounds like you're addressing a testing-specific issue within a feature implementation, and that's generally not good.
Otherwise, if something is failing silently, the first question is "What's failing?". Is it the `factorized_joint_logprob` loop? If not, the failure should be addressed closer to where its primary logic/code resides, and that doesn't appear to be here.
> Is there a reason why you don't want this type of check at the end of `factorized_joint_logprob`? It's trivial to manually do the same check in those new tests I added, but this explicit check might be valuable enough to have as a default.
The reason why I don't want these kinds of unrelated changes is that their inclusion makes a PR contingent on additional review work and discussions.
It takes extra time and effort to go through logic like this and determine its relevance, risk, etc. These are things that need to be done within issues and/or at the outset of a PR (e.g. the premise/description of a PR) in order to avoid delaying the inclusion of any agreed upon changes.
> From a simple development and design perspective, it sounds like you're addressing a testing-specific issue within a feature implementation, and that's generally not good.
What happened was that the test revealed that such a call to `factorized_joint_logprob` would return with a missing logp term and zero complaints, so I decided to add an explicit check there.
It was not for the sake of the test, as that had already been solved and could be tested explicitly inside the test itself.
It was meant for further development, when we introduce rewrites for Ops that are otherwise valid in logp graphs. It's also a conceptually obvious check for me: a user requested a dictionary of rv-value pairs, and we make sure we are returning a dictionary with an item for each original pair.
I don't mind splitting this into another PR.
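The completeness check being debated can be sketched generically in plain Python. The function name, parameter names, and the `ignored` escape hatch are hypothetical (the real logic lives inside aeppl's `joint_logprob` module); the sketch only illustrates "every requested rv-value pair must get a logp term, or raise":

```python
def check_all_values_used(rv_values, logp_terms, ignored=()):
    """Raise RuntimeError if any requested rv-value pair lacks a logp term.

    rv_values: dict mapping random variables to their value variables.
    logp_terms: dict of value variable -> logp graph, as a factorized
        joint-logprob routine would return (hypothetical interface).
    ignored: value variables deliberately excluded (e.g. via a tag).
    """
    missing = [
        vv for vv in rv_values.values()
        if vv not in logp_terms and vv not in ignored
    ]
    if missing:
        raise RuntimeError(f"No log-probability term was derived for: {missing}")
```

With this default in place, a silently dropped term (like the censored-variable case above) surfaces as an immediate error instead of a smaller-than-requested result dictionary.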
I removed the commit, but I am not very happy with the fact that this does not raise an error inside `factorized_joint_logprob`:

```python
def test_fail_base_and_censored_have_values():
    """Test failure when both base_rv and clipped_rv are given value vars"""
    x_rv = at.random.normal(0, 1)
    cens_x_rv = at.clip(x_rv, x_rv, 1)
    cens_x_rv.name = "cens_x"

    x_vv = x_rv.clone()
    cens_x_vv = cens_x_rv.clone()

    logp_terms = factorized_joint_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})
    assert cens_x_vv not in logp_terms
```
Again, why should the conditions be checked and the error be raised in `factorized_joint_logprob` specifically?
The two terms involved are very specific to the censored-variable logic, and it looks like the error could've been initiated in `find_censored_rvs`, i.e. where all the relevant terms are identified and used. This approach could also short-circuit all the unnecessary downstream logic, no?
You already have a warning there to that effect, so what do we gain by having an exception in `factorized_joint_logprob`?
I don't think it should be the rewrite's responsibility to raise a failure. It may be that there is another rewrite (e.g., one added by users of the library) that can handle the conversion.
The bigger problem is that nothing happens if you ask for a graph that we don't know how to handle. That is not specific to censored RVs; we just haven't tested it. For instance, this snippet does not complain at all:

```python
import aesara.tensor as at
import aeppl

x_rv = at.random.normal(name='x')
y_rv = at.cos(x_rv)

x_vv = x_rv.clone()
y_vv = y_rv.clone()

logprob_dict = aeppl.factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
logprob_dict
# {x: x_logprob}
```

This snippet would be more realistic about what a user may try, but it is now failing for a different reason (#87):

```python
logprob_dict = aeppl.factorized_joint_logprob({y_rv: y_vv})
```
> The bigger problem is that nothing happens if you ask for a graph that we don't know how to handle. That is not specific to censored RVs; we just haven't tested it. For instance, this snippet does not complain at all:
If our rewrites don't know how to handle something, that's not necessarily a problem. As a matter of fact, we expect that they won't know how to handle more things than they do.
The assumption underlying your statements and example seems to be that you know what should be done relative to specific rewrites, and this is what makes it reasonable to handle rewrite-relevant errors in the rewrite logic.
In other words, if you know an error/warning should be raised because a value variable specification is redundant, you only really know that because you also know that there's a specific rewrite that determines which value variables are and aren't relevant.
Otherwise, a generic warning for "unused" variable/value mappings is simply an interface choice that might help inform people of issues elsewhere and/or bad assumptions (e.g. that the resulting graph will depend on certain terms), but that's all.
All this relates directly to #85.
Force-pushed from ccbfb22 to d760200
Force-pushed from d760200 to 41aa447
Force-pushed from 41aa447 to 6a16fb6
Force-pushed from acf3615 to d7745ac
Force-pushed from d7745ac to 958ec44
You can use the same tests as before, but change the `pytest.raises` to look for the warnings instead.
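The suggested switch from "expect an exception" to "expect a warning" can be sketched with only the stdlib `warnings` module (the actual tests would use `pytest.warns`; `derive_logprob` here is a made-up stand-in for the rewrite under test):

```python
import warnings


def derive_logprob(supported):
    """Stand-in for a routine that warns instead of raising (hypothetical)."""
    if not supported:
        warnings.warn("Could not derive a log-probability term", UserWarning)
        return None
    return 0.0


# capture warnings instead of letting them print
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = derive_logprob(supported=False)

assert result is None
assert any(issubclass(w.category, UserWarning) for w in caught)
```

In a pytest suite, the `catch_warnings` block would collapse to `with pytest.warns(UserWarning): ...`.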
Force-pushed from 958ec44 to 248d9f5
Force-pushed from 248d9f5 to 4830635
Implement censored log-probabilities via the Clip Op

This PR implements logprob for censored (clipped) RVs.
I placed the new methods and tests inside `truncation.py`, expecting this file will also contain the methods for truncated RVs in the future.

Some things are still not working well / missing:

- Canonicalize set_subtensors to clip (`x[x > ub] = ub -> clip(x, x, ub)`): will do in another PR
- `logcdf` methods
- Compute test values for new nodes: seems to not be necessary; tests pass with `compute_test_value="raise"`
- Explore if CensoredRVs should be created even when they don't have a direct value variable (e.g., so that they can work as input to other derived RVs): postponed
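The first pending item rests on an identity that can be checked numerically with NumPy. This is a sketch of the identity itself, not of the symbolic rewrite: setting `x[x > ub] = ub` is equivalent to clipping with the variable as its own lower bound, since `clip(x, x, ub) = min(max(x, x), ub) = min(x, ub)`.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
ub = 0.5

# x[x > ub] = ub, written without in-place mutation
set_subtensor_form = np.where(x > ub, ub, x)

# the proposed canonical form: clip with the variable itself as lower bound
clip_form = np.clip(x, x, ub)

assert np.array_equal(set_subtensor_form, clip_form)
```

The same reasoning applies symmetrically to a lower-bound assignment, `x[x < lb] = lb -> clip(x, lb, x)`.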