Port Truncated Normal and Wald Distributions to V4 #4711

Merged

Conversation

Contributor

@matteo-pallini matteo-pallini commented May 23, 2021

Port Truncated Normal and Wald to V4 as per #4686 guidelines

Still need to do/check the following:

  • TruncatedNormal
      • Need to investigate why pymc3.tests.test_model.TestValueGradFunction.test_aesara_switch_broadcast_edge_cases_2 is failing
      • Is it fine to write a new RV? According to the issue there should be one already, but I couldn't find it
      • Is it fine to pass transform as an argument in dist?
      • Is _defaultval deprecated? I haven't been able to find any use of it
  • Wald
      • Refactor as per guidelines
      • Investigate why pymc3.tests.test_distributions_random.TestWaldAlpha is failing

@ricardoV94
Member

Is it fine to write a new RV? According to the issue there should be one already, but I couldn't find it

I might have confused it with the truncexpon

@ricardoV94 ricardoV94 mentioned this pull request May 23, 2021
26 tasks
pymc3/distributions/continuous.py
lower, lower_check, upper, upper_check = _truncated_normal_prepare_lower_and_upper(
lower, upper
)
print(lower.eval())
Member

print statement

Contributor Author

Has adding a check for print statements to the pre-commit checks been considered?

Member

@MarcoGorelli suggested on Slack a way this could be done, following your suggestion. Would either one of you be interested in implementing this?

Contributor

I'm gonna be pretty busy in the next two weeks, but this script should work as a local hook:

import ast
import sys

class Visitor(ast.NodeVisitor):
    def __init__(self, file):
        self.file = file

    def visit_Call(self, node: ast.Call) -> None:
        # Flag direct calls to the bare name `print`.
        if isinstance(node.func, ast.Name) and node.func.id == 'print':
            sys.stdout.write(f'{self.file}:{node.lineno}:{node.col_offset} found print statement\n')
            sys.exit(1)
        # Keep walking so that nested calls such as foo(print(x)) are checked too.
        self.generic_visit(node)

if __name__ == '__main__':
    # The file names to check are passed as command-line arguments.
    for file in sys.argv[1:]:
        with open(file) as fd:
            content = fd.read()
        tree = ast.parse(content)
        visitor = Visitor(file)
        visitor.visit(tree)

@DRabbit17 if you wanted to submit this as a separate PR, I'll review it
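
For illustration only, here is a minimal sketch of a variant of the hook script above that collects every print call instead of exiting at the first one. The helper name find_print_calls is hypothetical and not part of this PR or of pre-commit; it relies only on the standard-library ast module.

```python
import ast
from typing import List, Tuple


def find_print_calls(source: str) -> List[Tuple[int, int]]:
    """Return sorted (line, column) pairs for every call to the bare name `print`."""
    tree = ast.parse(source)
    hits = [
        (node.lineno, node.col_offset)
        for node in ast.walk(tree)  # ast.walk visits every node, including nested calls
        if isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "print"
    ]
    # ast.walk makes no ordering guarantee, so sort for stable output.
    return sorted(hits)


# Hypothetical sample input: one top-level print and one nested inside another call.
sample = "x = 1\nprint(x)\nfoo(print(x))\n"
locations = find_print_calls(sample)
```

Returning the locations rather than calling sys.exit makes the check easy to unit-test; a hook wrapper could then exit non-zero whenever the list is non-empty.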

Contributor Author

Happy to give it a stab

@ricardoV94
Member

ricardoV94 commented May 23, 2021

Is _defaultval deprecated? I haven't been able to find any use of it

Kind of, we still need to refactor the testval/ initialization point logic for V4 as discussed in #4567

@ricardoV94
Member

Is it fine to pass transform as an argument in dist, as per TruncatedNormal?

I am not sure what you mean. The BoundedContinuous class takes care of default transforms. A user can specify another transform, which will override the default one. The TruncatedNormal.dist() initialization, on the other hand, will not do anything with transforms; they only matter for regular distributions initialized within a model.

I am not sure which of these you are referring to.

@ricardoV94
Member

ricardoV94 commented May 23, 2021

Need to investigate why pymc3.tests.test_model.TestValueGradFunction.test_aesara_switch_broadcast_edge_cases_2 is failing

That test probably needs to be slightly refactored for V4. The dlogp call might need to be tweaked. What error are you seeing?

@matteo-pallini
Contributor Author

I am not sure what you mean. The BoundedContinuous ...
Sorry, the bullet point wasn't really helpful. I wrote it as a reference for myself and didn't consider how cryptic it was. I was referring to the 2nd case you mentioned.

I thought that the transform argument, if passed to dist, could have reached the rv_op call in Distribution.dist through **kwargs, and that the logic to handle transform had been moved to Aesara. But actually that would not have worked Python-wise. So I guess the transform logic (and argument) will simply be removed with the refactoring.

The dlogp call might need to be tweaked. What error are you seeing?
m.dlogp([mu])({"mu": 0}) is consistently 0. I need to familiarize myself with the internals of Model and logp/dlogp before I can tweak the test appropriately. Thanks for letting me know that changing the test is an option worth considering.

@ricardoV94
Member

I thought that the transform argument, if passed to dist, could have reached the rv_op call in Distribution.dist through **kwargs, and that the logic to handle transform had been moved to Aesara. But actually that would not have worked Python-wise. So I guess the transform logic (and argument) will simply be removed with the refactoring.

Only the first argument (the list of parameters) and size/shape are ever passed to the rv_op. The other kwargs are intercepted in Distribution.__new__(), to be used there or forwarded to Model.register_rv().
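
As a rough illustration of the split described above, here is a toy sketch. Every name in it (toy_rv_op, toy_register_rv, toy_new) is hypothetical, and the signatures are simplified assumptions rather than the actual pymc3 or Aesara APIs.

```python
def toy_rv_op(*params, size=None):
    """Hypothetical stand-in for an rv_op: it accepts parameters and size only."""
    return {"params": params, "size": size}


def toy_register_rv(rv, name, transform=None):
    """Hypothetical stand-in for Model.register_rv: model-level options end up here."""
    return {"name": name, "rv": rv, "transform": transform}


def toy_new(name, *params, size=None, transform=None):
    """Mimics the described interception: `transform` is consumed here
    and never forwarded to the rv_op."""
    rv = toy_rv_op(*params, size=size)
    return toy_register_rv(rv, name, transform=transform)


registered = toy_new("x", 0.0, 1.0, size=3, transform="interval")

# Passing `transform` straight to the toy rv_op fails, because its
# signature has no such keyword.
try:
    toy_rv_op(0.0, 1.0, transform="interval")
    direct_call_failed = False
except TypeError:
    direct_call_failed = True
```

In this toy version the rv_op only ever sees the parameters and size, while the transform travels separately to the registration step.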

@matteo-pallini matteo-pallini marked this pull request as draft May 24, 2021 07:30
@ricardoV94
Member

ricardoV94 commented May 24, 2021

@DRabbit17 I pushed a tiny change for the failing test. The issue was that we were passing the RandomVariable to the dlogp function instead of the logp "value" variable. This is an expected part of the V3->V4 refactoring.

@ricardoV94
Member

ricardoV94 commented Jun 7, 2021

Hi @DRabbit17, any progress on this PR?

We merged V4 into main so you will have to redirect the target of this PR.

Let me know if you need any help.

@matteo-pallini
Contributor Author

matteo-pallini commented Jun 7, 2021

We merged V4 into main so you will have to redirect the target of this PR.

Thanks for the update and congrats!

Hi @DRabbit17, any progress on this PR?

Sorry for the lack of progress here. Over the last 2-3 weeks, work left me with little to no mental bandwidth. I should be able to pick it up again on the weekend. I would like to, but please feel free to re-assign the issue to someone else in case I am being a blocker, or if someone else is keen to pick it up earlier than that.

@ricardoV94
Member

ricardoV94 commented Jun 7, 2021

@DRabbit17 There is no rush, I just wanted to check what your status was.

Also, do I understand correctly from the PR title that you intended to refactor the Wald distribution?

@matteo-pallini
Contributor Author

you intended to refactor the Wald distribution?

Yes

@matteo-pallini matteo-pallini force-pushed the refactor-wald-and-truncated-normal branch from 3678d71 to e5cc1f4 Compare June 13, 2021 14:07
@matteo-pallini matteo-pallini changed the base branch from v4 to main June 13, 2021 14:07
@matteo-pallini matteo-pallini force-pushed the refactor-wald-and-truncated-normal branch from e5cc1f4 to 7756579 Compare June 13, 2021 14:20
@matteo-pallini matteo-pallini force-pushed the refactor-wald-and-truncated-normal branch from 7756579 to fe86dcd Compare June 13, 2021 16:22
@matteo-pallini matteo-pallini force-pushed the refactor-wald-and-truncated-normal branch 2 times, most recently from 7fc3841 to 72a9798 Compare June 20, 2021 21:42
@ricardoV94 ricardoV94 added this to the vNext (4.0.0) milestone Jun 22, 2021
@matteo-pallini matteo-pallini force-pushed the refactor-wald-and-truncated-normal branch 2 times, most recently from b1aab03 to e892986 Compare June 26, 2021 23:38
@matteo-pallini matteo-pallini marked this pull request as ready for review June 26, 2021 23:48
@matteo-pallini matteo-pallini changed the title from "WIP: Port Truncated Normal and Wald Distributions to V4" to "Port Truncated Normal and Wald Distributions to V4" Jun 26, 2021
@ricardoV94 ricardoV94 force-pushed the refactor-wald-and-truncated-normal branch 2 times, most recently from af6a560 to 44343e3 Compare June 30, 2021 08:44
assert lower_interval.value == -1
assert upper_interval is None

def test_rich_context(self):
Member

I think we can remove this one. Does not seem to test anything extra

Contributor Author

The error trace you shared below is coming from that test. Originally I left it in because I was expecting it to pass and it didn't, so I wanted to investigate why. Judging by this thread, I would say that the warning is unrelated.

The error seems to be coming from pymc3.model.Model.set_initval. When running the TruncatedNormal with lower=None, there is a mismatch in the number of dimensions between the rv_value_var (which is a scalar) and the initval (which is an array). The latter is generated by initval_fn(), but I think its ndim comes from transform.forward(rv_var, value). So it may be that the interval is returning a wrong value due to the refactoring. But I haven't been able to replicate the issue with simpler tests, so I may simply be wrong: I may have written the test incorrectly, or there may be a bug. I think it's worth keeping the test at least until we can make it pass. For now I removed it, though.

Member

I also saw that thread about it not being harmful, but it's suspicious that the warning appears just after the failure, which comes from an Aesara check that ndims did not change, and that the warning does not appear in the other successful runs of the same test.

@ricardoV94
Member

ricardoV94 commented Jun 30, 2021

Seeing a weird warning in the failing jobs: https://github.com/pymc-devs/pymc3/pull/4711/checks?check_run_id=2950387511#step:7:2342

/usr/share/miniconda/envs/pymc3-dev-py37/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject

Which seems to be related to the failure just before: https://github.com/pymc-devs/pymc3/pull/4711/checks?check_run_id=2951980417#step:7:2332

          if self.ndim != data.ndim:
            raise TypeError(
>               f"Wrong number of dimensions: expected {self.ndim},"
                f" got {data.ndim} with shape {data.shape}."
            )
E           TypeError: Wrong number of dimensions: expected 0, got 1 with shape (1,).

@matteo-pallini matteo-pallini force-pushed the refactor-wald-and-truncated-normal branch from 470f1e0 to a68aee1 Compare June 30, 2021 22:52
@ricardoV94
Member

ricardoV94 commented Jul 2, 2021

@DRabbit17 check my last commit. I removed some tests (or tests_to_run) that felt unnecessary. Let me know if you disagree. Otherwise I think this PR is ready to merge.

I tested creating a bunch of TruncatedNormals with random sizes and lower/upper parameters in a single model, and I did not find any issues with the initval like the one we were getting in that rich_context model. So I am pretty confident that it was not an issue on our side.

@matteo-pallini
Contributor Author

matteo-pallini commented Jul 2, 2021

Otherwise I think this PR is ready to merge

Agreed, I made a very small change. Sorry for dragging the PR out for so long, and thanks for the support (it would have been way faster for you to simply do the whole thing yourself :-) ).

So I am pretty confident that it was not an issue on our side.

Our side as in this PR, or PyMC?

I have been trying to replicate the test failure for pymc3/tests/test_distributions_random.py::TestNestedRandom::test_TruncatedNormal locally, but I haven't been able to. Can you? (The same goes for the rich_context test.) Are the GitHub tests being run in a Docker container? If so, is it possible to download it?

@ricardoV94
Member

ricardoV94 commented Jul 2, 2021

Agreed, I made a very small change. Sorry for dragging the PR out for so long, and thanks for the support (it would have been way faster for you to simply do the whole thing yourself :-) ).

I disagree. The dynamic interval thing is something that we needed to figure out, and will be used in other places as well. It was really great that you dived in and started figuring it out.

So I am pretty confident that it was not an issue on our side.

Our side as in this PR, or PyMC?

Both. I am pretty confident it was an issue with incompatible NumPy/Aesara binaries that emerged in that specific environment.

I have been trying to replicate the test failure for pymc3/tests/test_distributions_random.py::TestNestedRandom::test_TruncatedNormal locally, but I haven't been able to. Can you? (The same goes for the rich_context test.) Are the GitHub tests being run in a Docker container? If so, is it possible to download it?

The TestNestedRandom failure seems like a weird caching issue. We recently made tests marked as xfail strict (a test now fails if it unexpectedly passes), and in this case it still seems to be surprised that the tests are passing: https://github.com/pymc-devs/pymc3/pull/4711/checks?check_run_id=2973646757#step:8:682
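
For readers unfamiliar with strict xfail, the sketch below models how pytest reports a test marked xfail. It is a plain-Python illustration of pytest's documented behavior, not pytest code, and the function name xfail_outcome is made up for this example.

```python
def xfail_outcome(test_passed: bool, strict: bool) -> str:
    """Model the reported outcome of a test that carries an xfail mark."""
    if not test_passed:
        # The expected failure happened: reported as xfailed, the run stays green.
        return "xfailed"
    # The test passed although it was expected to fail.
    # With strict=True this counts as a failure; otherwise it is only reported as xpassed.
    return "failed" if strict else "xpassed"


# A stale xfail mark on a test that now passes turns the run red in strict mode.
stale_mark_strict = xfail_outcome(test_passed=True, strict=True)
stale_mark_lenient = xfail_outcome(test_passed=True, strict=False)
```

Under this model, a test that starts passing after a refactoring but still carries a strict xfail mark shows up as a failure without any traceback of its own.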

I can't find an explicit error message in the logs related to the "failing tests" (@michaelosthege any ideas?)

@michaelosthege
Member

I added the mark.xfail on test_TruncatedNormal because the test used a nonexistent API, but the distribution was not yet refactored.
This PR refactored the distribution, so now the test is passing.
👉 Just remove the xfail.

@michaelosthege michaelosthege force-pushed the refactor-wald-and-truncated-normal branch from 42fef02 to 58b6158 Compare July 2, 2021 17:59
@michaelosthege
Member

Rebased and removed that mark.xfail. Should go green now, but please check if my rebase didn't mess up anything :)

@ricardoV94
Member

ricardoV94 commented Jul 2, 2021

I added the mark.xfail on test_TruncatedNormal because the test used a nonexistent API, but the distribution was not yet refactored.
This PR refactored the distribution, so now the test is passing.
👉 Just remove the xfail.

I was convinced that I had removed it xD

@michaelosthege
Member

@ricardoV94 the rebase was tricky too. You probably ended up with another test's mark.xfail earlier.

Both test failures seem to be related to the known problem with the find_MAP #4771.
Merge away?

@ricardoV94
Member

ricardoV94 commented Jul 2, 2021

@ricardoV94 the rebase was tricky too. You probably ended up with another test's mark.xfail earlier.

Both test failures seem to be related to the known problem with the find_MAP #4771.
Merge away?

The MLE one will be adjusted in #4833. Let me check the other one quickly. Yeah, the second one looks like another MAP issue. It does not seem related to any changes in this PR.

@ricardoV94 ricardoV94 merged commit 9d90c89 into pymc-devs:main Jul 2, 2021
@ricardoV94
Member

ricardoV94 commented Jul 2, 2021

Great work @DRabbit17! This was a fun one to crack. Looking forward to your next PR :)

@MarcoGorelli MarcoGorelli mentioned this pull request Jul 25, 2021
6 tasks
4 participants