From 07872433a5cdc3d7cfe1c5913709fde9a4b8a314 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 26 Dec 2024 12:27:31 +0000 Subject: [PATCH] More robust check for multiple integer indices in numba ravel_multidimensional_idx rewrites --- pytensor/tensor/rewriting/subtensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 7ba1908e60..572d2bcab6 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -85,7 +85,7 @@ inc_subtensor, indices_from_subtensor, ) -from pytensor.tensor.type import TensorType +from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node): if any( ( - (isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")) + (isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes) or isinstance(idx.type, NoneTypeT) ) for idx in idxs @@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node): int_idxs = [ (i, idx) for i, idx in enumerate(idxs) - if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int")) + if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes) ] if len(int_idxs) != 1: