-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add constant_data into sample_numpyro_nuts() #5807
Add constant_data into sample_numpyro_nuts() #5807
Conversation
Codecov Report
@@ Coverage Diff @@
## main #5807 +/- ##
==========================================
- Coverage 89.50% 89.50% -0.01%
==========================================
Files 73 73
Lines 13276 13274 -2
==========================================
- Hits 11883 11881 -2
Misses 1393 1393
|
It seems that On my local machine, when run tests with jax flag on However, when run with
It seems a down stream error of |
Also, any ideas why the |
If you rebase from main, the failing blackjax tests should go away |
Hi yes, agree. I will refactor the codes. |
a1fab39
to
034ca36
Compare
Hopefully it is alright now :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Left some suggestions to simplify / tidy up the original functions
034ca36
to
7821d13
Compare
Hi, it has been updated :) |
Hi @ricardoV94 I think it is alright for this pr to be merged :) Thanks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, but we need to test that constant_data is showing up in test_sampling_jax.py, probably in the same test that already checks for other InferenceData groups
"""If there are observations available, return them as a dictionary.""" | ||
if model is None: | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need the model is none cases. Whomever calls this function should always pass a model. There's no good reason why they would call the function with None
pymc/backends/arviz.py
Outdated
def find_constants(model: Optional["Model"]) -> Dict[str, Var]: | ||
"""If there are constants available, return them as a dictionary.""" | ||
# The constant data vars must be either pm.Data or TensorConstant or SharedVariable | ||
if model is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I think it better than the previous codes now. Why we did not remove it in previous review here? But it could be good to have some defensive codes here, since we not 100% sure people always call it with a model. Any way, this is an easy fix, and just a minor thing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why we did not remove it in previous review here?
We just missed it then.
But it could be good to have some defensive codes here, since we not 100% sure people always call it with a model.
I don't see why you would call this function if not with a model. It just doesn't make sense. And if someone does it that's probably a mistake and the best thing is to let it naturally crash.
Adding special cases like this can make it more difficult to spot errors instead of helping.
Hi, I cannot find where these codes though? Can you point it out? I have already run these following codes and check to make sure the constant_data is shown in
|
I don't see an existing test that checks pymc/pymc/tests/test_sampling_jax.py Line 160 in ab05d44
But please make sure to set something like |
Yes, I will add the test for both Just need some more time to do that :) Thanks |
@danhphan let us know if someone should take over the rest here. This is one of currently two open items in the 4.0.1 milestone.. |
7821d13
to
6f50e37
Compare
Hi @michaelosthege Sorry I'm quite busy with other more urgent stuffs right now, so please feel free to go ahead with other steps. I will not able to do it in early this week. If no-one updates this, I will try to spend sometime this weekend to complete it. Thanks you. |
Thanks @danhphan and @michaelosthege! |
This PR add constant_data into
pymc.sampling_jax.sample_numpyro_nuts()
. It addresses the issue #5781cc @ricardoV94 :)
Let me know if it needs to be changed. Thank you.