Skip to content

chore: adding NoneConst and None check #7764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
is_basic_idx,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
Expand Down Expand Up @@ -289,9 +289,12 @@ def find_measurable_index_mixture(fgraph, node):
# We don't support (non-scalar) integer array indexing as it can pick repeated values,
# but the Mixture logprob assumes all mixture values are independent
if any(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole check is a bit of a mess we should think a bit more. It could also be a non-constant slice (see if not)

Copy link
Author

@Hashcode-Ankit Hashcode-Ankit Apr 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 by this line it could also be a non-constant slice do you mean to say that dynamic slices (non-constant slice) are also need to be skipped?

Also what is your view if we check for tensor variable and only in that case graph node get modified? (solution for mess)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check should be just something like

if any(
            (
                    isinstance(indices, TensorVariable)
                    and indices.dtype.startswith("int")
                    and any(not b for b in indices.type.broadcastable)
                )
            )
            for indices in mixing_indices
        ):

Basically rulling out potentially repeated integer indices.

Copy link
Author

@Hashcode-Ankit Hashcode-Ankit Apr 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 are you saying to modify graph node in above condition only,
if it is true (which i understand) then we have to use all not any:

if all( // all not any
            (
                    isinstance(indices, TensorVariable)
                    and indices.dtype.startswith("int")
                    and any(not b for b in indices.type.broadcastable)
                )
            )
            for indices in mixing_indices
        ):
        // modify the nodes 
        
 else :
     None

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that if any I suggested is the one leading to return None. No need for a new branch

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then in None and NoneConst cases it is going to modify node which is not recommended in issue description :

I am running below code

from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
from pytensor.tensor import TensorVariable
mixing_indices = [None, NoneConst]

if any(
   (
      isinstance(indices, TensorVariable)
      and indices.dtype.startswith("int")
      and any(not b for b in indices.type.broadcastable)
   )
   for indices in mixing_indices
  ):
  print("return none")
else:
  print("modify node")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing we should have to worry about is repeated indices, None/NoneConst don't introduce repeated indices, just a dummy dimension

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in issue:

The [line that I quoted above](https://github.com/pymc-devs/pymc/blob/main/pymc/logprob/mixture.py#L294) needs to also check if the indices are None or [NoneConst](https://github.com/pymc-devs/pytensor/blob/main/pytensor/tensor/type_other.py#L132). That way, the rewrite will return None when it has a mixture of integer indexes, and slices or new axis.

Copy link
Author

@Hashcode-Ankit Hashcode-Ankit Apr 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ricardoV94 i have done the changes, I am new to codebase (hope you understand), that is why having that much of questions.
Thanks

indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
(
isinstance(indices, TensorVariable)
and indices.dtype.startswith("int")
and any(not b for b in indices.type.broadcastable)
)
for indices in mixing_indices
if not isinstance(indices, SliceConstant)
):
return None

Expand Down