Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RandomVariable static type shape bug #475

Merged
merged 1 commit into from
Oct 11, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Oct 11, 2023

This bug showed up in pymc-devs/pymc#6947

import pytensor.tensor as pt

p = pt.ones(3) / 3
x = pt.random.categorical(p=pt.stack([p, 1-p], axis=-1))
assert x.type.shape == (3,)  # AssertionError

Which would later lead to a rewrite error during compilation.

It was caused by the presence of np.int in the static shape of Join, (and corresponding np.bool in the broadcastable), which would then be overlooked by an explicit check broadcastable is False in the RandomVariable.infer_shape.

@codecov-commenter
Copy link

Codecov Report

Merging #475 (f236f96) into main (36df379) will increase coverage by 0.00%.
Report is 1 commits behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #475   +/-   ##
=======================================
  Coverage   80.66%   80.66%           
=======================================
  Files         160      160           
  Lines       46025    46029    +4     
  Branches    11266    11268    +2     
=======================================
+ Hits        37124    37128    +4     
  Misses       6668     6668           
  Partials     2233     2233           
Files Coverage Δ
pytensor/tensor/random/op.py 96.25% <ø> (ø)
pytensor/tensor/type.py 94.52% <100.00%> (+0.04%) ⬆️

@michaelosthege michaelosthege merged commit 6834740 into pymc-devs:main Oct 11, 2023
52 checks passed
@ricardoV94 ricardoV94 changed the title Fix static type shape bug Fix RandomVariable static type shape bug Oct 12, 2023
@ricardoV94 ricardoV94 deleted the fix_static_type_shape_bug branch October 12, 2023 07:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working shape inference
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants