diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 9cd9870616..2956afad02 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -18,7 +18,7 @@ Split, TensorFromScalar, Tri, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import Shape_i @@ -103,7 +103,7 @@ def join(axis, *tensors): def jax_funcify_Split(op: Split, node, **kwargs): _, axis, splits = node.inputs try: - constant_axis = get_underlying_scalar_constant_value(axis) + constant_axis = get_scalar_constant_value(axis) except NotScalarConstantError: constant_axis = None warnings.warn( @@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs): try: constant_splits = np.array( [ - get_underlying_scalar_constant_value(splits[i]) + get_scalar_constant_value(splits[i]) for i in range(get_vector_length(splits)) ] ) diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 8b92e60085..dcae273aef 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -484,7 +484,7 @@ def wrap_into_list(x): n_fixed_steps = int(n_steps) else: try: - n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps) + n_fixed_steps = pt.get_scalar_constant_value(n_steps) except NotScalarConstantError: n_fixed_steps = None diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index f2037272e1..2ba282d8d6 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -55,7 +55,6 @@ Alloc, AllocEmpty, get_scalar_constant_value, - get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError @@ -1976,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes): nsteps = node.inputs[0] try: - nsteps = int(get_underlying_scalar_constant_value(nsteps)) + nsteps = int(get_scalar_constant_value(nsteps)) except NotScalarConstantError: pass rep_nsteps = rep_node.inputs[0] try: - rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps)) + rep_nsteps = int(get_scalar_constant_value(rep_nsteps)) except NotScalarConstantError: pass diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index dbe891c902..401642ddb9 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1808,7 +1808,7 @@ def do_constant_folding(self, fgraph, node): @_get_vector_length.register(Alloc) def _get_vector_length_Alloc(var_inst, var): try: - return get_underlying_scalar_constant_value(var.owner.inputs[1]) + return get_scalar_constant_value(var.owner.inputs[1]) except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -2509,7 +2509,7 @@ def make_node(self, axis, *tensors): if not isinstance(axis, int): try: - axis = int(get_underlying_scalar_constant_value(axis)) + axis = int(get_scalar_constant_value(axis)) except NotScalarConstantError: pass @@ -2753,7 +2753,7 @@ def infer_shape(self, fgraph, node, ishapes): def _get_vector_length_Join(op, var): axis, *arrays = var.owner.inputs try: - axis = get_underlying_scalar_constant_value(axis) + axis = get_scalar_constant_value(axis) assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays) return builtins.sum(get_vector_length(a) for a in arrays) except NotScalarConstantError: @@ -4146,7 +4146,7 @@ def make_node(self, a, choices): static_out_shape = () for s in out_shape: try: - s_val = get_underlying_scalar_constant_value(s) + s_val = get_scalar_constant_value(s) except (NotScalarConstantError, AttributeError): s_val = None diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index 0addd2b5f0..d1dfe44b90 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -25,7 +25,7 @@ from pytensor.raise_op import Assert from pytensor.tensor.basic import ( as_tensor_variable, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -497,8 +497,8 @@ def check_dim(given, computed): if given is None or computed is None: return True try: - given = get_underlying_scalar_constant_value(given) - computed = get_underlying_scalar_constant_value(computed) + given = get_scalar_constant_value(given) + computed = get_scalar_constant_value(computed) return int(given) == int(computed) except NotScalarConstantError: # no answer possible, accept for now @@ -534,7 +534,7 @@ def assert_conv_shape(shape): out_shape = [] for i, n in enumerate(shape): try: - const_n = get_underlying_scalar_constant_value(n) + const_n = get_scalar_constant_value(n) if i < 2: if const_n < 0: raise ValueError( @@ -2203,9 +2203,7 @@ def __init__( if imshp_i is not None: # Components of imshp should be constant or ints try: - get_underlying_scalar_constant_value( - imshp_i, only_process_constants=True - ) + get_scalar_constant_value(imshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "imshp should be None or a tuple of constant int values" @@ -2218,9 +2216,7 @@ def __init__( if kshp_i is not None: # Components of kshp should be constant or ints try: - get_underlying_scalar_constant_value( - kshp_i, only_process_constants=True - ) + get_scalar_constant_value(kshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "kshp should be None or a tuple of constant int values" diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 9fc6683200..fedcd32ab9 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -678,7 +678,7 @@ def make_node(self, x, repeats): out_shape = [None] else: try: - const_reps = ptb.get_underlying_scalar_constant_value(repeats) + const_reps = ptb.get_scalar_constant_value(repeats) except NotScalarConstantError: const_reps = None if const_reps == 1: diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 1d5ca138dc..59148fae3b 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -57,7 +57,6 @@ cast, fill, get_scalar_constant_value, - get_underlying_scalar_constant_value, join, ones_like, register_infer_shape, @@ -739,7 +738,7 @@ def local_remove_useless_assert(fgraph, node): n_conds = len(node.inputs[1:]) for c in node.inputs[1:]: try: - const = get_underlying_scalar_constant_value(c) + const = get_scalar_constant_value(c) if 0 != const.ndim or const == 0: # Should we raise an error here? How to be sure it @@ -834,7 +833,7 @@ def local_join_empty(fgraph, node): return new_inputs = [] try: - join_idx = get_underlying_scalar_constant_value( + join_idx = get_scalar_constant_value( node.inputs[0], only_process_constants=True ) except NotScalarConstantError: diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 9001b6e9c3..e0303e935e 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -153,18 +153,16 @@ def local_0_dot_x(fgraph, node): x = node.inputs[0] y = node.inputs[1] - replace = False - try: - if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass - - try: - if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass + replace = ( + get_underlying_scalar_constant_value( + x, only_process_constants=True, raise_not_constant=False + ) + == 0 + or get_underlying_scalar_constant_value( + y, only_process_constants=True, raise_not_constant=False + ) + == 0 + ) if replace: constant_zero = constant(0, dtype=node.outputs[0].type.dtype) @@ -2111,7 +2109,7 @@ def local_add_remove_zeros(fgraph, node): y = get_underlying_scalar_constant_value(inp) except NotScalarConstantError: y = inp - if np.all(y == 0.0): + if y == 0.0: continue new_inputs.append(inp) @@ -2209,7 +2207,7 @@ def local_abs_merge(fgraph, node): ) except NotScalarConstantError: return False - if not (const >= 0).all(): + if not const >= 0: return False inputs.append(i) else: @@ -2861,7 +2859,7 @@ def _is_1(expr): """ try: v = get_underlying_scalar_constant_value(expr) - return np.allclose(v, 1) + return np.isclose(v, 1) except NotScalarConstantError: return False @@ -3029,7 +3027,7 @@ def is_neg(var): for idx, mul_input in enumerate(var_node.inputs): try: constant = get_underlying_scalar_constant_value(mul_input) - is_minus_1 = np.allclose(constant, -1) + is_minus_1 = np.isclose(constant, -1) except NotScalarConstantError: is_minus_1 = False if is_minus_1: diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 6401ecf896..e277772ad4 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -23,7 +23,6 @@ cast, constant, get_scalar_constant_value, - get_underlying_scalar_constant_value, register_infer_shape, stack, ) @@ -213,7 +212,7 @@ def shape_ir(self, i, r): # Do not call make_node for test_value s = Shape_i(i)(r) try: - s = get_underlying_scalar_constant_value(s) + s = get_scalar_constant_value(s) except NotScalarConstantError: pass return s @@ -297,7 +296,7 @@ def unpack(self, s_i, var): assert len(idx) == 1 idx = idx[0] try: - i = get_underlying_scalar_constant_value(idx) + i = get_scalar_constant_value(idx) except NotScalarConstantError: pass else: @@ -452,7 +451,7 @@ def update_shape(self, r, other_r): ) or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals( - get_underlying_scalar_constant_value( + get_scalar_constant_value( merged_shape[i], only_process_constants=True, raise_not_constant=False, @@ -481,9 +480,7 @@ def set_shape_i(self, r, i, s_i): or r.type.shape[idx] != 1 or self.lscalar_one.equals(new_shape[idx]) or self.lscalar_one.equals( - get_underlying_scalar_constant_value( - new_shape[idx], raise_not_constant=False - ) + get_scalar_constant_value(new_shape[idx], raise_not_constant=False) ) for idx in range(r.type.ndim) ) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index abf01921b1..4b824e46cf 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -999,7 +999,7 @@ def local_useless_subtensor(fgraph, node): if isinstance(idx.stop, int | np.integer): length_pos_data = sys.maxsize try: - length_pos_data = get_underlying_scalar_constant_value( + length_pos_data = get_scalar_constant_value( length_pos, only_process_constants=True ) except NotScalarConstantError: @@ -1064,7 +1064,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node): # get length of the indexed tensor along the first axis try: - length = get_underlying_scalar_constant_value( + length = get_scalar_constant_value( shape_of[node.inputs[0]][0], only_process_constants=True ) except NotScalarConstantError: @@ -1736,7 +1736,7 @@ def local_join_subtensors(fgraph, node): axis, tensors = node.inputs[0], node.inputs[1:] try: - axis = get_underlying_scalar_constant_value(axis) + axis = get_scalar_constant_value(axis) except NotScalarConstantError: return @@ -1797,12 +1797,7 @@ def local_join_subtensors(fgraph, node): if step is None: continue try: - if ( - get_underlying_scalar_constant_value( - step, only_process_constants=True - ) - != 1 - ): + if get_scalar_constant_value(step, only_process_constants=True) != 1: return None except NotScalarConstantError: return None diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index d145ef9c42..8913d6fb4d 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -428,7 +428,7 @@ def make_node(self, x, *shape): type_shape[i] = xts elif not isinstance(s.type, NoneTypeT): try: - type_shape[i] = int(ptb.get_underlying_scalar_constant_value(s)) + type_shape[i] = int(ptb.get_scalar_constant_value(s)) except NotScalarConstantError: pass @@ -580,7 +580,7 @@ def specify_shape( @_get_vector_length.register(SpecifyShape) # type: ignore def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int: try: - return int(ptb.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()) + return int(ptb.get_scalar_constant_value(var.owner.inputs[1]).item()) except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -661,7 +661,7 @@ def make_node(self, x, shp): y = shp_list[index] y = ptb.as_tensor_variable(y) try: - s_val = ptb.get_underlying_scalar_constant_value(y).item() + s_val = ptb.get_scalar_constant_value(y).item() if s_val >= 0: out_shape[index] = s_val except NotScalarConstantError: diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index fe4d06f152..a3a81f63bd 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -29,7 +29,7 @@ from pytensor.tensor.basic import ( ScalarFromTensor, alloc, - get_underlying_scalar_constant_value, + get_scalar_constant_value, nonzero, scalar_from_tensor, ) @@ -778,7 +778,7 @@ def conv(val): return slice(conv(val.start), conv(val.stop), conv(val.step)) else: try: - return get_underlying_scalar_constant_value( + return get_scalar_constant_value( val, only_process_constants=only_process_constants, elemwise=elemwise, @@ -855,7 +855,7 @@ def extract_const(value): if value is None: return value, True try: - value = get_underlying_scalar_constant_value(value) + value = get_scalar_constant_value(value) return value, True except NotScalarConstantError: return value, False @@ -3022,17 +3022,17 @@ def _get_vector_length_Subtensor(op, var): start = ( None if indices[0].start is None - else get_underlying_scalar_constant_value(indices[0].start) + else get_scalar_constant_value(indices[0].start) ) stop = ( None if indices[0].stop is None - else get_underlying_scalar_constant_value(indices[0].stop) + else get_scalar_constant_value(indices[0].stop) ) step = ( None if indices[0].step is None - else get_underlying_scalar_constant_value(indices[0].step) + else get_scalar_constant_value(indices[0].step) ) if start == stop: