File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change 8585 inc_subtensor ,
8686 indices_from_subtensor ,
8787)
88- from pytensor .tensor .type import TensorType
88+ from pytensor .tensor .type import TensorType , integer_dtypes
8989from pytensor .tensor .type_other import NoneTypeT , SliceConstant , SliceType
9090from pytensor .tensor .variable import TensorConstant , TensorVariable
9191
@@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node):
19811981
19821982 if any (
19831983 (
1984- (isinstance (idx .type , TensorType ) and idx .type .dtype . startswith ( "int" ) )
1984+ (isinstance (idx .type , TensorType ) and idx .type .dtype in integer_dtypes )
19851985 or isinstance (idx .type , NoneTypeT )
19861986 )
19871987 for idx in idxs
@@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node):
20522052 int_idxs = [
20532053 (i , idx )
20542054 for i , idx in enumerate (idxs )
2055- if (isinstance (idx .type , TensorType ) and idx .dtype . startswith ( "int" ) )
2055+ if (isinstance (idx .type , TensorType ) and idx .dtype in integer_dtypes )
20562056 ]
20572057
20582058 if len (int_idxs ) != 1 :
You can’t perform that action at this time.
0 commit comments