From 24716ac6a97720d365df7150ae206304b5bef7ca Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 21 Oct 2024 17:53:10 +0200 Subject: [PATCH] Remove internal `get_constant helper` Fixes bug in `local_add_neg_to_sub` reported in https://github.com/pymc-devs/pytensor/issues/584 --- pytensor/tensor/rewriting/math.py | 96 ++++++++++++++++------------- tests/tensor/rewriting/test_math.py | 8 ++- 2 files changed, 58 insertions(+), 46 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index e0303e935e..f36a58fcc3 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): return consts, origconsts, nonconsts -def get_constant(v): - """ - - Returns - ------- - object - A numeric constant if v is a Constant or, well, a - numeric constant. If v is a plain Variable, returns None. - - """ - if isinstance(v, TensorConstant): - return v.unique_value - elif isinstance(v, Variable): - return None - else: - return v - - @register_canonicalize @register_stabilize @node_rewriter([Dot]) @@ -994,8 +976,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): """ Find all constants and put them together into a single constant. - Finds all constants in orig_num and orig_denum (using - get_constant) and puts them together into a single + Finds all constants in orig_num and orig_denum + and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator. @@ -1016,17 +998,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): numct, denumct = [], [] for v in orig_num: - ct = get_constant(v) - if ct is not None: + if isinstance(v, TensorConstant) and v.unique_value is not None: # We found a constant in the numerator! # We add it to numct - numct.append(ct) + numct.append(v.unique_value) else: num.append(v) for v in orig_denum: - ct = get_constant(v) - if ct is not None: - denumct.append(ct) + if isinstance(v, TensorConstant) and v.unique_value is not None: + denumct.append(v.unique_value) else: denum.append(v) @@ -1050,10 +1030,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: # In that case we should only have one constant in `ct`. - assert len(ct) == 1 - first_num_ct = get_constant(orig_num[0]) - if first_num_ct is not None and ct[0].type.values_eq( - ct[0].data, first_num_ct + [var_ct] = ct + first_num_var = orig_num[0] + first_num_ct = ( + first_num_var.unique_value + if isinstance(first_num_var, TensorConstant) + else None + ) + if first_num_ct is not None and var_ct.type.values_eq( + var_ct.data, first_num_ct ): # This is an important trick :( if it so happens that: # * there's exactly one constant on the numerator and none on @@ -1840,9 +1825,12 @@ def local_add_neg_to_sub(fgraph, node): return [new_out] # Check if it is a negative constant - const = get_constant(second) - if const is not None and const < 0: - new_out = sub(first, np.abs(const)) + if ( + isinstance(second, TensorConstant) + and second.unique_value is not None + and second.unique_value < 0 + ): + new_out = sub(first, np.abs(second.data)) return [new_out] @@ -1871,7 +1859,12 @@ def local_mul_zero(fgraph, node): @register_specialize @node_rewriter([true_div]) def local_div_to_reciprocal(fgraph, node): - if np.all(get_constant(node.inputs[0]) == 1.0): + if ( + get_underlying_scalar_constant_value( + node.inputs[0], only_process_constants=True, raise_not_constant=False + ) + == 1.0 + ): out = node.outputs[0] new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) # The ones could have forced upcasting @@ -1892,7 +1885,9 @@ def local_reciprocal_canon(fgraph, node): @register_canonicalize @node_rewriter([pt_pow]) def local_pow_canonicalize(fgraph, node): - cst = get_constant(node.inputs[1]) + cst = get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) if cst == 0: return [alloc_like(1, node.outputs[0], fgraph)] if cst == 1: @@ -1923,7 +1918,12 @@ def local_intdiv_by_one(fgraph, node): @node_rewriter([int_div, true_div]) def local_zero_div(fgraph, node): """0 / x -> 0""" - if get_constant(node.inputs[0]) == 0: + if ( + get_underlying_scalar_constant_value( + node.inputs[0], only_process_constants=True, raise_not_constant=False + ) + == 0 + ): ret = alloc_like(0, node.outputs[0], fgraph) ret.tag.values_eq_approx = values_eq_approx_remove_nan return [ret] @@ -1936,8 +1936,12 @@ def local_pow_specialize(fgraph, node): odtype = node.outputs[0].dtype xsym = node.inputs[0] ysym = node.inputs[1] - y = get_constant(ysym) - if (y is not None) and not broadcasted_by(xsym, ysym): + try: + y = get_underlying_scalar_constant_value(ysym, only_process_constants=True) + except NotScalarConstantError: + return + + if not broadcasted_by(xsym, ysym): rval = None if np.all(y == 2): @@ -1971,10 +1975,14 @@ def local_pow_to_nested_squaring(fgraph, node): """ # the idea here is that we have pow(x, y) + xsym, ysym = node.inputs + + try: + y = get_underlying_scalar_constant_value(ysym, only_process_constants=True) + except NotScalarConstantError: + return + odtype = node.outputs[0].dtype - xsym = node.inputs[0] - ysym = node.inputs[1] - y = get_constant(ysym) # the next line is needed to fix a strange case that I don't # know how to make a separate test. @@ -1990,7 +1998,7 @@ def local_pow_to_nested_squaring(fgraph, node): y = y[0] except IndexError: pass - if (y is not None) and not broadcasted_by(xsym, ysym): + if not broadcasted_by(xsym, ysym): rval = None # 512 is too small for the cpu and too big for some gpu! if abs(y) == int(abs(y)) and abs(y) <= 512: @@ -2057,7 +2065,9 @@ def local_mul_specialize(fgraph, node): nb_neg_node += 1 # remove special case arguments of 1, -1 or 0 - y = get_constant(inp) + y = get_underlying_scalar_constant_value( + inp, only_process_constants=True, raise_not_constant=False + ) if y == 1.0: nb_cst += 1 elif y == -1.0: diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 95cbacfefd..be85984995 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4440,11 +4440,13 @@ def test_local_add_neg_to_sub(first_negative): assert np.allclose(f(x_test, y_test), exp) -def test_local_add_neg_to_sub_const(): +@pytest.mark.parametrize("const_left", (True, False)) +def test_local_add_neg_to_sub_const(const_left): x = vector("x") - const = 5.0 + const = np.full((3, 2), 5.0) + out = -const + x if const_left else x + (-const) - f = function([x], x + (-const), mode=Mode("py")) + f = function([x], out, mode=Mode("py")) nodes = [ node.op