Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with sampling of PartialObservedRVs #7071

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Dec 20, 2023

Description

There was issue with PartialObservedRVs in that we used the same RNG in two places: in one of the masked RVs and in the shape of the empty tensor where the two masked RVs are stored as a deterministic.

This would further raise an error in JAX when computing the joined deterministic, because we don't allow RNGs in jaxified graphs.

Constant folding will ensure the shape graph will be introduced instead of the default ShapeOp(RV) so the RNG shouldn't be part of that deterministic graph.

Related Issue

Issue reported in https://discourse.pymc.io/t/shared-randomtype-issue-using-nuts-numpyro-w-r-t-data-containing-nans/13503?u=ricardov94

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7071.org.readthedocs.build/en/7071/

Not doing this would lead to no default updates for draws of such RVs, due to the same RNG being reused twice.
@ricardoV94 ricardoV94 changed the title Constant fold original RV.shape in graph of joined PartialObservedRV Fix issue with sampling of PartialObservedRVs Dec 20, 2023
Copy link

codecov bot commented Dec 20, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (35cd657) 92.20% compared to head (24527c2) 92.20%.
Report is 3 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #7071   +/-   ##
=======================================
  Coverage   92.20%   92.20%           
=======================================
  Files         101      101           
  Lines       16895    16901    +6     
=======================================
+ Hits        15578    15584    +6     
  Misses       1317     1317           
Files Coverage Δ
pymc/distributions/distribution.py 96.26% <100.00%> (+<0.01%) ⬆️
pymc/distributions/multivariate.py 93.53% <100.00%> (+0.04%) ⬆️
pymc/logprob/transform_value.py 93.75% <100.00%> (ø)
pymc/pytensorf.py 91.26% <100.00%> (-0.03%) ⬇️

@ricardoV94 ricardoV94 requested a review from twiecki December 20, 2023 15:59
@ricardoV94 ricardoV94 merged commit 72534c7 into pymc-devs:main Dec 21, 2023
23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants