Skip to content

Commit

Permalink
Remove internal get_constant helper
Browse files Browse the repository at this point in the history
Fixes bug in `local_add_neg_to_sub` reported in pymc-devs#584
  • Loading branch information
ricardoV94 committed Nov 28, 2024
1 parent 0da6758 commit 24716ac
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 46 deletions.
96 changes: 53 additions & 43 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
return consts, origconsts, nonconsts


def get_constant(v):
"""
Returns
-------
object
A numeric constant if v is a Constant or, well, a
numeric constant. If v is a plain Variable, returns None.
"""
if isinstance(v, TensorConstant):
return v.unique_value
elif isinstance(v, Variable):
return None
else:
return v


@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
Expand Down Expand Up @@ -994,8 +976,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
"""
Find all constants and put them together into a single constant.
Finds all constants in orig_num and orig_denum (using
get_constant) and puts them together into a single
Finds all constants in orig_num and orig_denum
and puts them together into a single
constant. The constant is inserted as the first element of the
numerator. If the constant is the neutral element, it is
removed from the numerator.
Expand All @@ -1016,17 +998,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
numct, denumct = [], []

for v in orig_num:
ct = get_constant(v)
if ct is not None:
if isinstance(v, TensorConstant) and v.unique_value is not None:
# We found a constant in the numerator!
# We add it to numct
numct.append(ct)
numct.append(v.unique_value)
else:
num.append(v)
for v in orig_denum:
ct = get_constant(v)
if ct is not None:
denumct.append(ct)
if isinstance(v, TensorConstant) and v.unique_value is not None:
denumct.append(v.unique_value)
else:
denum.append(v)

Expand All @@ -1050,10 +1030,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):

if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
# In that case we should only have one constant in `ct`.
assert len(ct) == 1
first_num_ct = get_constant(orig_num[0])
if first_num_ct is not None and ct[0].type.values_eq(
ct[0].data, first_num_ct
[var_ct] = ct
first_num_var = orig_num[0]
first_num_ct = (
first_num_var.unique_value
if isinstance(first_num_var, TensorConstant)
else None
)
if first_num_ct is not None and var_ct.type.values_eq(
var_ct.data, first_num_ct
):
# This is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on
Expand Down Expand Up @@ -1840,9 +1825,12 @@ def local_add_neg_to_sub(fgraph, node):
return [new_out]

# Check if it is a negative constant
const = get_constant(second)
if const is not None and const < 0:
new_out = sub(first, np.abs(const))
if (
isinstance(second, TensorConstant)
and second.unique_value is not None
and second.unique_value < 0
):
new_out = sub(first, np.abs(second.data))
return [new_out]


Expand Down Expand Up @@ -1871,7 +1859,12 @@ def local_mul_zero(fgraph, node):
@register_specialize
@node_rewriter([true_div])
def local_div_to_reciprocal(fgraph, node):
if np.all(get_constant(node.inputs[0]) == 1.0):
if (
get_underlying_scalar_constant_value(
node.inputs[0], only_process_constants=True, raise_not_constant=False
)
== 1.0
):
out = node.outputs[0]
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting
Expand All @@ -1892,7 +1885,9 @@ def local_reciprocal_canon(fgraph, node):
@register_canonicalize
@node_rewriter([pt_pow])
def local_pow_canonicalize(fgraph, node):
cst = get_constant(node.inputs[1])
cst = get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
if cst == 0:
return [alloc_like(1, node.outputs[0], fgraph)]
if cst == 1:
Expand Down Expand Up @@ -1923,7 +1918,12 @@ def local_intdiv_by_one(fgraph, node):
@node_rewriter([int_div, true_div])
def local_zero_div(fgraph, node):
"""0 / x -> 0"""
if get_constant(node.inputs[0]) == 0:
if (
get_underlying_scalar_constant_value(
node.inputs[0], only_process_constants=True, raise_not_constant=False
)
== 0
):
ret = alloc_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan
return [ret]
Expand All @@ -1936,8 +1936,12 @@ def local_pow_specialize(fgraph, node):
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)
if (y is not None) and not broadcasted_by(xsym, ysym):
try:
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
except NotScalarConstantError:
return

if not broadcasted_by(xsym, ysym):
rval = None

if np.all(y == 2):
Expand Down Expand Up @@ -1971,10 +1975,14 @@ def local_pow_to_nested_squaring(fgraph, node):
"""

# the idea here is that we have pow(x, y)
xsym, ysym = node.inputs

try:
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
except NotScalarConstantError:
return

odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)

# the next line is needed to fix a strange case that I don't
# know how to make a separate test.
Expand All @@ -1990,7 +1998,7 @@ def local_pow_to_nested_squaring(fgraph, node):
y = y[0]
except IndexError:
pass
if (y is not None) and not broadcasted_by(xsym, ysym):
if not broadcasted_by(xsym, ysym):
rval = None
# 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512:
Expand Down Expand Up @@ -2057,7 +2065,9 @@ def local_mul_specialize(fgraph, node):
nb_neg_node += 1

# remove special case arguments of 1, -1 or 0
y = get_constant(inp)
y = get_underlying_scalar_constant_value(
inp, only_process_constants=True, raise_not_constant=False
)
if y == 1.0:
nb_cst += 1
elif y == -1.0:
Expand Down
8 changes: 5 additions & 3 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4440,11 +4440,13 @@ def test_local_add_neg_to_sub(first_negative):
assert np.allclose(f(x_test, y_test), exp)


def test_local_add_neg_to_sub_const():
@pytest.mark.parametrize("const_left", (True, False))
def test_local_add_neg_to_sub_const(const_left):
x = vector("x")
const = 5.0
const = np.full((3, 2), 5.0)
out = -const + x if const_left else x + (-const)

f = function([x], x + (-const), mode=Mode("py"))
f = function([x], out, mode=Mode("py"))

nodes = [
node.op
Expand Down

0 comments on commit 24716ac

Please sign in to comment.