Use more strict get_scalar_constant_value when the input must be a scalar
ricardoV94 committed Jan 13, 2025
1 parent 21523e7 commit 5b05713
Showing 12 changed files with 50 additions and 66 deletions.
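For context: get_underlying_scalar_constant_value digs through operations such as Alloc or DimShuffle to find a scalar constant hidden inside a non-scalar variable, while get_scalar_constant_value only accepts inputs that are themselves scalar (0-d), raising NotScalarConstantError otherwise. The call sites changed below all pass values that must already be scalars (axes, steps, shape entries, repeat counts), which is why the strict variant fits. A minimal sketch of the difference, assuming the pytensor API as of this commit (the snippet is illustrative, not part of the diff):

import pytensor.tensor as pt
from pytensor.tensor.basic import (
    get_scalar_constant_value,
    get_underlying_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError

# A genuine 0-d constant: both helpers recover the value.
x = pt.constant(5)
assert get_underlying_scalar_constant_value(x) == 5
assert get_scalar_constant_value(x) == 5

# A vector built by broadcasting a scalar: only the "underlying"
# helper digs through the Alloc node to find the hidden 5.
v = pt.alloc(pt.constant(5), 3)
assert get_underlying_scalar_constant_value(v) == 5
try:
    get_scalar_constant_value(v)  # 1-d input: the strict helper refuses
except NotScalarConstantError:
    print("non-scalar input rejected")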
6 changes: 3 additions & 3 deletions pytensor/link/jax/dispatch/tensor_basic.py
@@ -18,7 +18,7 @@
     Split,
     TensorFromScalar,
     Tri,
-    get_underlying_scalar_constant_value,
+    get_scalar_constant_value,
 )
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import Shape_i
@@ -103,7 +103,7 @@ def join(axis, *tensors):
 def jax_funcify_Split(op: Split, node, **kwargs):
     _, axis, splits = node.inputs
     try:
-        constant_axis = get_underlying_scalar_constant_value(axis)
+        constant_axis = get_scalar_constant_value(axis)
     except NotScalarConstantError:
         constant_axis = None
         warnings.warn(
@@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
     try:
         constant_splits = np.array(
             [
-                get_underlying_scalar_constant_value(splits[i])
+                get_scalar_constant_value(splits[i])
                 for i in range(get_vector_length(splits))
             ]
         )
2 changes: 1 addition & 1 deletion pytensor/scan/basic.py
@@ -484,7 +484,7 @@ def wrap_into_list(x):
         n_fixed_steps = int(n_steps)
     else:
         try:
-            n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps)
+            n_fixed_steps = pt.get_scalar_constant_value(n_steps)
         except NotScalarConstantError:
             n_fixed_steps = None

5 changes: 2 additions & 3 deletions pytensor/scan/rewriting.py
@@ -55,7 +55,6 @@
     Alloc,
     AllocEmpty,
     get_scalar_constant_value,
-    get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
@@ -1976,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes):

         nsteps = node.inputs[0]
         try:
-            nsteps = int(get_underlying_scalar_constant_value(nsteps))
+            nsteps = int(get_scalar_constant_value(nsteps))
         except NotScalarConstantError:
             pass

         rep_nsteps = rep_node.inputs[0]
         try:
-            rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
+            rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
         except NotScalarConstantError:
             pass

8 changes: 4 additions & 4 deletions pytensor/tensor/basic.py
@@ -1808,7 +1808,7 @@ def do_constant_folding(self, fgraph, node):
 @_get_vector_length.register(Alloc)
 def _get_vector_length_Alloc(var_inst, var):
     try:
-        return get_underlying_scalar_constant_value(var.owner.inputs[1])
+        return get_scalar_constant_value(var.owner.inputs[1])
     except NotScalarConstantError:
         raise ValueError(f"Length of {var} cannot be determined")

@@ -2509,7 +2509,7 @@ def make_node(self, axis, *tensors):

         if not isinstance(axis, int):
             try:
-                axis = int(get_underlying_scalar_constant_value(axis))
+                axis = int(get_scalar_constant_value(axis))
             except NotScalarConstantError:
                 pass

@@ -2753,7 +2753,7 @@ def infer_shape(self, fgraph, node, ishapes):
 def _get_vector_length_Join(op, var):
     axis, *arrays = var.owner.inputs
     try:
-        axis = get_underlying_scalar_constant_value(axis)
+        axis = get_scalar_constant_value(axis)
         assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
         return builtins.sum(get_vector_length(a) for a in arrays)
     except NotScalarConstantError:
@@ -4146,7 +4146,7 @@ def make_node(self, a, choices):
         static_out_shape = ()
         for s in out_shape:
             try:
-                s_val = get_underlying_scalar_constant_value(s)
+                s_val = get_scalar_constant_value(s)
             except (NotScalarConstantError, AttributeError):
                 s_val = None

16 changes: 6 additions & 10 deletions pytensor/tensor/conv/abstract_conv.py
@@ -25,7 +25,7 @@
 from pytensor.raise_op import Assert
 from pytensor.tensor.basic import (
     as_tensor_variable,
-    get_underlying_scalar_constant_value,
+    get_scalar_constant_value,
 )
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -497,8 +497,8 @@ def check_dim(given, computed):
     if given is None or computed is None:
         return True
     try:
-        given = get_underlying_scalar_constant_value(given)
-        computed = get_underlying_scalar_constant_value(computed)
+        given = get_scalar_constant_value(given)
+        computed = get_scalar_constant_value(computed)
         return int(given) == int(computed)
     except NotScalarConstantError:
         # no answer possible, accept for now
@@ -534,7 +534,7 @@ def assert_conv_shape(shape):
     out_shape = []
     for i, n in enumerate(shape):
         try:
-            const_n = get_underlying_scalar_constant_value(n)
+            const_n = get_scalar_constant_value(n)
             if i < 2:
                 if const_n < 0:
                     raise ValueError(
@@ -2203,9 +2203,7 @@ def __init__(
             if imshp_i is not None:
                 # Components of imshp should be constant or ints
                 try:
-                    get_underlying_scalar_constant_value(
-                        imshp_i, only_process_constants=True
-                    )
+                    get_scalar_constant_value(imshp_i, only_process_constants=True)
                 except NotScalarConstantError:
                     raise ValueError(
                         "imshp should be None or a tuple of constant int values"
@@ -2218,9 +2216,7 @@
             if kshp_i is not None:
                 # Components of kshp should be constant or ints
                 try:
-                    get_underlying_scalar_constant_value(
-                        kshp_i, only_process_constants=True
-                    )
+                    get_scalar_constant_value(kshp_i, only_process_constants=True)
                 except NotScalarConstantError:
                     raise ValueError(
                         "kshp should be None or a tuple of constant int values"
2 changes: 1 addition & 1 deletion pytensor/tensor/extra_ops.py
@@ -678,7 +678,7 @@ def make_node(self, x, repeats):
             out_shape = [None]
         else:
             try:
-                const_reps = ptb.get_underlying_scalar_constant_value(repeats)
+                const_reps = ptb.get_scalar_constant_value(repeats)
             except NotScalarConstantError:
                 const_reps = None
             if const_reps == 1:
5 changes: 2 additions & 3 deletions pytensor/tensor/rewriting/basic.py
@@ -57,7 +57,6 @@
     cast,
     fill,
     get_scalar_constant_value,
-    get_underlying_scalar_constant_value,
     join,
     ones_like,
     register_infer_shape,
@@ -739,7 +738,7 @@ def local_remove_useless_assert(fgraph, node):
     n_conds = len(node.inputs[1:])
     for c in node.inputs[1:]:
         try:
-            const = get_underlying_scalar_constant_value(c)
+            const = get_scalar_constant_value(c)

             if 0 != const.ndim or const == 0:
                 # Should we raise an error here? How to be sure it
@@ -834,7 +833,7 @@ def local_join_empty(fgraph, node):
         return
     new_inputs = []
     try:
-        join_idx = get_underlying_scalar_constant_value(
+        join_idx = get_scalar_constant_value(
             node.inputs[0], only_process_constants=True
         )
     except NotScalarConstantError:
30 changes: 14 additions & 16 deletions pytensor/tensor/rewriting/math.py
@@ -153,18 +153,16 @@ def local_0_dot_x(fgraph, node):

     x = node.inputs[0]
     y = node.inputs[1]
-    replace = False
-    try:
-        if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
-            replace = True
-    except NotScalarConstantError:
-        pass
-
-    try:
-        if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
-            replace = True
-    except NotScalarConstantError:
-        pass
+    replace = (
+        get_underlying_scalar_constant_value(
+            x, only_process_constants=True, raise_not_constant=False
+        )
+        == 0
+        or get_underlying_scalar_constant_value(
+            y, only_process_constants=True, raise_not_constant=False
+        )
+        == 0
+    )

     if replace:
         constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
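A note on the rewrite above: with raise_not_constant=False the helper is assumed to return None instead of raising NotScalarConstantError when no constant can be extracted, and None == 0 evaluates to False, which is what lets the two try/except blocks collapse into a single boolean expression. A minimal sketch under that assumption:

import pytensor.tensor as pt
from pytensor.tensor.basic import get_underlying_scalar_constant_value

zero = pt.constant(0)   # a constant zero
y = pt.vector("y")      # symbolic, not a constant

# No try/except needed: a non-constant input yields None,
# and None == 0 is simply False (assumed semantics of the flag).
assert get_underlying_scalar_constant_value(zero, raise_not_constant=False) == 0
assert get_underlying_scalar_constant_value(y, raise_not_constant=False) is None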
@@ -2111,7 +2109,7 @@ def local_add_remove_zeros(fgraph, node):
             y = get_underlying_scalar_constant_value(inp)
         except NotScalarConstantError:
             y = inp
-        if np.all(y == 0.0):
+        if y == 0.0:
             continue
         new_inputs.append(inp)

@@ -2209,7 +2207,7 @@ def local_abs_merge(fgraph, node):
             )
         except NotScalarConstantError:
             return False
-        if not (const >= 0).all():
+        if not const >= 0:
             return False
         inputs.append(i)
     else:
@@ -2861,7 +2859,7 @@ def _is_1(expr):
     """
     try:
         v = get_underlying_scalar_constant_value(expr)
-        return np.allclose(v, 1)
+        return np.isclose(v, 1)
     except NotScalarConstantError:
         return False

@@ -3029,7 +3027,7 @@ def is_neg(var):
     for idx, mul_input in enumerate(var_node.inputs):
         try:
             constant = get_underlying_scalar_constant_value(mul_input)
-            is_minus_1 = np.allclose(constant, -1)
+            is_minus_1 = np.isclose(constant, -1)
         except NotScalarConstantError:
             is_minus_1 = False
         if is_minus_1:
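The np.allclose to np.isclose swaps above (and the earlier np.all(y == 0.0) to y == 0.0 change) lean on the fact that these helpers return a single scalar value, never an array, so the array-reduction wrappers were redundant. A pure-NumPy illustration, independent of this commit:

import numpy as np

v = np.float64(1.0)       # the kind of value the helper returns: a scalar
assert np.isclose(v, 1)   # scalar in, single boolean out
assert np.allclose(v, 1)  # also worked, but implied array semantics
assert (v == 0.0) == np.all(v == 0.0)  # np.all is redundant for scalars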
11 changes: 4 additions & 7 deletions pytensor/tensor/rewriting/shape.py
@@ -23,7 +23,6 @@
     cast,
     constant,
     get_scalar_constant_value,
-    get_underlying_scalar_constant_value,
     register_infer_shape,
     stack,
 )
@@ -213,7 +212,7 @@ def shape_ir(self, i, r):
         # Do not call make_node for test_value
         s = Shape_i(i)(r)
         try:
-            s = get_underlying_scalar_constant_value(s)
+            s = get_scalar_constant_value(s)
         except NotScalarConstantError:
             pass
         return s
@@ -297,7 +296,7 @@ def unpack(self, s_i, var):
             assert len(idx) == 1
             idx = idx[0]
             try:
-                i = get_underlying_scalar_constant_value(idx)
+                i = get_scalar_constant_value(idx)
             except NotScalarConstantError:
                 pass
             else:
@@ -452,7 +451,7 @@ def update_shape(self, r, other_r):
                 )
                 or self.lscalar_one.equals(merged_shape[i])
                 or self.lscalar_one.equals(
-                    get_underlying_scalar_constant_value(
+                    get_scalar_constant_value(
                         merged_shape[i],
                         only_process_constants=True,
                         raise_not_constant=False,
@@ -481,9 +480,7 @@ def set_shape_i(self, r, i, s_i):
                 or r.type.shape[idx] != 1
                 or self.lscalar_one.equals(new_shape[idx])
                 or self.lscalar_one.equals(
-                    get_underlying_scalar_constant_value(
-                        new_shape[idx], raise_not_constant=False
-                    )
+                    get_scalar_constant_value(new_shape[idx], raise_not_constant=False)
                 )
                 for idx in range(r.type.ndim)
             )
13 changes: 4 additions & 9 deletions pytensor/tensor/rewriting/subtensor.py
@@ -999,7 +999,7 @@ def local_useless_subtensor(fgraph, node):
             if isinstance(idx.stop, int | np.integer):
                 length_pos_data = sys.maxsize
                 try:
-                    length_pos_data = get_underlying_scalar_constant_value(
+                    length_pos_data = get_scalar_constant_value(
                         length_pos, only_process_constants=True
                     )
                 except NotScalarConstantError:
@@ -1064,7 +1064,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):

     # get length of the indexed tensor along the first axis
     try:
-        length = get_underlying_scalar_constant_value(
+        length = get_scalar_constant_value(
             shape_of[node.inputs[0]][0], only_process_constants=True
         )
     except NotScalarConstantError:
@@ -1736,7 +1736,7 @@ def local_join_subtensors(fgraph, node):
     axis, tensors = node.inputs[0], node.inputs[1:]

     try:
-        axis = get_underlying_scalar_constant_value(axis)
+        axis = get_scalar_constant_value(axis)
     except NotScalarConstantError:
         return

@@ -1797,12 +1797,7 @@ def local_join_subtensors(fgraph, node):
             if step is None:
                 continue
             try:
-                if (
-                    get_underlying_scalar_constant_value(
-                        step, only_process_constants=True
-                    )
-                    != 1
-                ):
+                if get_scalar_constant_value(step, only_process_constants=True) != 1:
                     return None
             except NotScalarConstantError:
                 return None
6 changes: 3 additions & 3 deletions pytensor/tensor/shape.py
@@ -428,7 +428,7 @@ def make_node(self, x, *shape):
                 type_shape[i] = xts
             elif not isinstance(s.type, NoneTypeT):
                 try:
-                    type_shape[i] = int(ptb.get_underlying_scalar_constant_value(s))
+                    type_shape[i] = int(ptb.get_scalar_constant_value(s))
                 except NotScalarConstantError:
                     pass

@@ -580,7 +580,7 @@ def specify_shape(
 @_get_vector_length.register(SpecifyShape)  # type: ignore
 def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
     try:
-        return int(ptb.get_underlying_scalar_constant_value(var.owner.inputs[1]).item())
+        return int(ptb.get_scalar_constant_value(var.owner.inputs[1]).item())
     except NotScalarConstantError:
         raise ValueError(f"Length of {var} cannot be determined")

@@ -661,7 +661,7 @@ def make_node(self, x, shp):
             y = shp_list[index]
             y = ptb.as_tensor_variable(y)
             try:
-                s_val = ptb.get_underlying_scalar_constant_value(y).item()
+                s_val = ptb.get_scalar_constant_value(y).item()
                 if s_val >= 0:
                     out_shape[index] = s_val
             except NotScalarConstantError: