-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Port Truncated Normal and Wald Distributions to V4 #4711
Port Truncated Normal and Wald Distributions to V4 #4711
Conversation
I might have confused it with the truncexpon |
pymc3/distributions/continuous.py
Outdated
lower, lower_check, upper, upper_check = _truncated_normal_prepare_lower_and_upper( | ||
lower, upper | ||
) | ||
print(lower.eval()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print statement
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was it considered to add to the pre-commit checks also a check for print statements?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MarcoGorelli suggested a way this could be done on the slack, following your suggestion. Would either one of you be interested in implementing this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm gonna be pretty busy in the next two weeks, but this script should work as a local hook:
import ast
import sys
class Visitor(ast.NodeVisitor):
def __init__(self, file):
self.file = file
def visit_Call(self, node: ast.Call) -> None:
if isinstance(node.func, ast.Name) and node.func.id == 'print':
sys.stdout.write(f'{self.file}:{node.lineno}:{node.col_offset} found print statement\n')
sys.exit(1)
if __name__ == '__main__':
for file in sys.argv[1:]:
with open(file) as fd:
content = fd.read()
tree = ast.parse(content)
visitor = Visitor(file)
visitor.visit(tree)
@DRabbit17 if you wanted to submit this as a separate PR, I'll review it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy to give it a stab
Kind of, we still need to refactor the testval/ initialization point logic for V4 as discussed in #4567 |
I am not sure what you mean. The BoundedContinuous class takes care of default transforms. It is possible for a user to specify another transform which will overwrite the default one. The I am not sure which of these you are referring to. |
That test probably needs to be slightly refactored for V4. The dlogp call might need to be tweaked. What error are you seeing? |
I thought that the
|
Only the first argument (the list of parameters) and |
@DRabbit17 I pushed a tiny change for the failing test. The issue was that we were passing the RandomVariable to the dlogp function instead of the logp "value" variable. It is an expected V3->V4 refactoring |
Hi @DRabbit17, any progress on this PR? We merged Let me know if you need any help. |
Thanks for the update and congrats!
Sorry for the lack of progress here. During the last 2-3 weeks work left me with little/no mental bandwidth. I should be able to pick it up again on the weekend. I would like to, but please feel free to re-assign the issue to someone else in case I am being a blocker, or someone else is keen to pick it earlier than that. |
@DRabbit17 There is no rush, just wanted to check what was your status. Also, do I understand correctly from the PR title that you intended to refactor the |
Yes |
3678d71
to
e5cc1f4
Compare
e5cc1f4
to
7756579
Compare
7756579
to
fe86dcd
Compare
7fc3841
to
72a9798
Compare
b1aab03
to
e892986
Compare
af6a560
to
44343e3
Compare
pymc3/tests/test_distributions.py
Outdated
assert lower_interval.value == -1 | ||
assert upper_interval is None | ||
|
||
def test_rich_context(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove this one. Does not seem to test anything extra
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error trace you shared below is coming from that test, Originally I left it because I was expecting it to pass and it didn't, so I wanted to investigate why. Judging by this thread I would say that the warning is unrelated.
The error seems to be coming from pymc3.model.Model.set_initval
. When running the Truncated with lower=None there is a mismatch in the number of dimensions between the rv_value_var
(which is a scalar) and the initval
(which is an array). The second is generated by initval_fn(), but its ndim is due to
transform.forward(rv_var, value)`, I think. So, it may be possible that the interval is returning a wrong value due to the refactoring. But I haven't been able to replicate the issue with simpler tests, so I may be simply wrong. I may have written the test incorrectly, or there may be a bug. I think it's worth keeping the test at least until we cannot make it pass. For now I removed it though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also saw that thread about not being harmful... but it's suspicious that the warning appears just after the failure, which comes from an Aesara check that ndims did not change and the warning does not appear in the other successful runs of the same test.
Seeing a weird warning in the failing jobs: https://github.com/pymc-devs/pymc3/pull/4711/checks?check_run_id=2950387511#step:7:2342
Which seems to be related to the failure just before: https://github.com/pymc-devs/pymc3/pull/4711/checks?check_run_id=2951980417#step:7:2332 if self.ndim != data.ndim:
raise TypeError(
> f"Wrong number of dimensions: expected {self.ndim},"
f" got {data.ndim} with shape {data.shape}."
)
E TypeError: Wrong number of dimensions: expected 0, got 1 with shape (1,). |
470f1e0
to
a68aee1
Compare
@DRabbit17 check my last commit. I removed some tests (or I tested creating a bunch of TruncatedNormals with random sizes and lower/ upper parameters in a single model and I did not find any issues with the |
Agreed, I did a very small change. sorry for dragging the PR for so long and thanks for the support (it would have been way faster for you to simply do the whole thing yourself :-) ).
our side as this PR or PyMC? I have been trying to replicate the test failure for |
I disagree. The dynamic interval thing is something that we needed to figure out, and will be used in other places as well. It was really great that you dived in and started figuring it out.
Both. I am pretty confident it was an issue with incompatible numpy / aesara binaries, that emerged on that specific environment.
The I can't find a explicit error message in the logs related to the "failing tests" (@michaelosthege any ideas?) |
I added the |
- remove upper and lower checks and check for lack of bounds in logp and intervals directly - pass default value for dist as `testval` to Distribution dist method
42fef02
to
58b6158
Compare
Rebased and removed that |
I was convinced that I had removed it xD |
@ricardoV94 the rebase was tricky too. You probably ended up with another tests Both test failures seem to be related to the known problem with the |
The MLE one will be adjusted in #4833. |
Great work @DRabbit17 This was a fun one to crack. Looking forward to your next PR :) |
Port Truncated Normal and Wald to V4 as per #4686 guidelines
Still need to do/check the followings:
pymc3.tests.test_model.TestValueGradFunction.test_aesara_switch_broadcast_edge_cases_2
is failingtransform
as argument indist
?_defaultval
deprecated? I haven't been able to find any use of itpymc3.tests.test_distributions_random.TestWaldAlpha
is failing.