From afdef12e851f0c5d21f80208d387d9c4918d0fb7 Mon Sep 17 00:00:00 2001 From: hashcode-ankit <2698ankitsharma@gmail.com> Date: Sun, 27 Apr 2025 13:53:33 +0530 Subject: [PATCH 1/4] chore: adding NoneConst and None check --- pymc/logprob/mixture.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index ce6a11d208..1a5a9ac1e2 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -289,7 +289,14 @@ def find_measurable_index_mixture(fgraph, node): # We don't support (non-scalar) integer array indexing as it can pick repeated values, # but the Mixture logprob assumes all mixture values are independent if any( - indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0 + ( + isinstance(indices, (type(NoneConst), type(None))) + or + ( + indices.dtype.startswith("int") and + sum(1 - b for b in indices.type.broadcastable) > 0 + ) + ) for indices in mixing_indices if not isinstance(indices, SliceConstant) ): From 8b66a9abd5d4ace5bfe8c5b0ce5ad5d15eb7dab2 Mon Sep 17 00:00:00 2001 From: hashcode-ankit <2698ankitsharma@gmail.com> Date: Sun, 27 Apr 2025 13:57:11 +0530 Subject: [PATCH 2/4] chore: fixing pre-commit --- pymc/logprob/mixture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 1a5a9ac1e2..0ef6a7cd5e 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -290,7 +290,7 @@ def find_measurable_index_mixture(fgraph, node): # but the Mixture logprob assumes all mixture values are independent if any( ( - isinstance(indices, (type(NoneConst), type(None))) + isinstance(indices, (type(NoneConst) | type(None))) or ( indices.dtype.startswith("int") and From 145d28ded1927a0f38f1c858b5a10e66028a8fd4 Mon Sep 17 00:00:00 2001 From: hashcode-ankit <2698ankitsharma@gmail.com> Date: Sun, 27 Apr 2025 14:08:04 +0530 Subject: [PATCH 3/4] chore: proper fix of pre-commit --- pymc/logprob/mixture.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 0ef6a7cd5e..9a0cfacc37 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -291,10 +291,9 @@ def find_measurable_index_mixture(fgraph, node): if any( ( isinstance(indices, (type(NoneConst) | type(None))) - or - ( - indices.dtype.startswith("int") and - sum(1 - b for b in indices.type.broadcastable) > 0 + or ( + indices.dtype.startswith("int") + and sum(1 - b for b in indices.type.broadcastable) > 0 ) ) for indices in mixing_indices From 675231797e499a6450477670626fbafcd4a91172 Mon Sep 17 00:00:00 2001 From: hashcode-ankit <2698ankitsharma@gmail.com> Date: Mon, 28 Apr 2025 18:06:38 +0530 Subject: [PATCH 4/4] chore: as per suggestion of @ricardoV94 --- pymc/logprob/mixture.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 9a0cfacc37..3709ef756e 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -62,7 +62,7 @@ is_basic_idx, ) from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType +from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType from pytensor.tensor.variable import TensorVariable from pymc.logprob.abstract import ( @@ -290,14 +290,11 @@ def find_measurable_index_mixture(fgraph, node): # but the Mixture logprob assumes all mixture values are independent if any( ( - isinstance(indices, (type(NoneConst) | type(None))) - or ( - indices.dtype.startswith("int") - and sum(1 - b for b in indices.type.broadcastable) > 0 - ) + isinstance(indices, TensorVariable) + and indices.dtype.startswith("int") + and any(not b for b in indices.type.broadcastable) ) for indices in mixing_indices - if not isinstance(indices, SliceConstant) ): return None