From 10bcb41e8541e62e0f9805fd37bd7a8d9e6cce0e Mon Sep 17 00:00:00 2001 From: Luca Citi Date: Fri, 6 Jun 2025 21:27:44 +0000 Subject: [PATCH 1/2] Add rewrite for `softplus(log(x)) -> log1p(x)` --- pytensor/tensor/rewriting/math.py | 9 ++++++++- tests/tensor/rewriting/test_math.py | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 9694a022e3..7fd02ca406 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -400,7 +400,7 @@ def local_exp_log(fgraph, node): @register_specialize -@node_rewriter([exp, expm1]) +@node_rewriter([exp, expm1, softplus]) def local_exp_log_nan_switch(fgraph, node): # Rewrites of the kind exp(log...(x)) that require a `nan` switch x = node.inputs[0] @@ -453,6 +453,13 @@ def local_exp_log_nan_switch(fgraph, node): new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype)) return [new_out] + # Case for softplus(log(x)) -> log1p(x) + if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Softplus): + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = switch(ge(x, 0), log1p(x), np.asarray(np.nan, old_out.dtype)) + return [new_out] + @register_canonicalize @register_specialize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 9a092663a9..cfd1265bad 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -2010,6 +2010,27 @@ def test_exp_softplus(self, exp_op): decimal=6, ) + def test_softplus_log(self): + # softplus(log(x)) -> log1p(x) + data_valid = np.random.random((4, 3)).astype("float32") * 2 + data_valid[0, 0] = 0 # edge case + data_invalid = data_valid - 2 + + x = fmatrix() + f = function([x], softplus(log(x)), mode=self.mode) + graph = f.maker.fgraph.toposort() + ops_graph = [ + node + for node in graph + if isinstance(node.op, Elemwise) + and isinstance(node.op.scalar_op, ps.Log | ps.Exp | ps.Softplus) + ] + assert len(ops_graph) == 0 + + expected = np.log1p(data_valid) + np.testing.assert_almost_equal(f(data_valid), expected) + assert np.all(np.isnan(f(data_invalid))) + @pytest.mark.parametrize( ["nested_expression", "expected_switches"], [ From 67d432d5bf31483aeeae00449154f982f657ea4a Mon Sep 17 00:00:00 2001 From: Luca Citi Date: Mon, 23 Jun 2025 21:26:22 +0000 Subject: [PATCH 2/2] Added log1mexp(log(x)) -> log1p(-x) and its test Also implemented tests as suggested by ricardoV94 --- pytensor/tensor/rewriting/math.py | 12 +++++-- tests/tensor/rewriting/test_math.py | 49 ++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 7fd02ca406..c796c155f5 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -64,6 +64,7 @@ log, log1mexp, log1p, + log1pexp, makeKeepDims, maximum, mul, @@ -400,7 +401,7 @@ def local_exp_log(fgraph, node): @register_specialize -@node_rewriter([exp, expm1, softplus]) +@node_rewriter([exp, expm1, log1pexp, log1mexp]) def local_exp_log_nan_switch(fgraph, node): # Rewrites of the kind exp(log...(x)) that require a `nan` switch x = node.inputs[0] @@ -453,13 +454,20 @@ def local_exp_log_nan_switch(fgraph, node): new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype)) return [new_out] - # Case for softplus(log(x)) -> log1p(x) + # Case for log1pexp(log(x)) -> log1p(x) (log1pexp aka softplus) if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Softplus): x = x.owner.inputs[0] old_out = node.outputs[0] new_out = switch(ge(x, 0), log1p(x), np.asarray(np.nan, old_out.dtype)) return [new_out] + # Case for log1mexp(log(x)) -> log1p(-x) + if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Log1mexp): + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = switch(ge(x, 0), log1p(-x), np.asarray(np.nan, old_out.dtype)) + return [new_out] + @register_canonicalize @register_specialize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index cfd1265bad..4080b979c9 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -67,6 +67,7 @@ log, log1mexp, log1p, + log1pexp, lt, maximum, minimum, @@ -2010,27 +2011,53 @@ def test_exp_softplus(self, exp_op): decimal=6, ) - def test_softplus_log(self): - # softplus(log(x)) -> log1p(x) + def test_log1pexp_log(self): + # log1pexp(log(x)) -> log1p(x) data_valid = np.random.random((4, 3)).astype("float32") * 2 data_valid[0, 0] = 0 # edge case data_invalid = data_valid - 2 x = fmatrix() - f = function([x], softplus(log(x)), mode=self.mode) - graph = f.maker.fgraph.toposort() - ops_graph = [ - node - for node in graph - if isinstance(node.op, Elemwise) - and isinstance(node.op.scalar_op, ps.Log | ps.Exp | ps.Softplus) - ] - assert len(ops_graph) == 0 + f = function([x], log1pexp(log(x)), mode=self.mode.excluding("inplace")) + assert equal_computations( + f.maker.fgraph.outputs, + [ + pt.switch( + x >= np.array([[0]], dtype=np.int8), + pt.log1p(x), + np.array([[np.nan]], dtype=np.float32), + ) + ], + ) expected = np.log1p(data_valid) np.testing.assert_almost_equal(f(data_valid), expected) assert np.all(np.isnan(f(data_invalid))) + def test_log1mexp_log(self): + # log1mexp(log(x)) -> log1p(-x) + data_valid = np.random.random((4, 3)).astype("float32") + data_valid[0, 0] = 0 # edge case + data_valid[0, 1] = 1 # another edge case + data_invalid = np.concatenate([data_valid + 1.1, data_valid - 1.1]) + + x = fmatrix() + f = function([x], log1mexp(log(x)), mode=self.mode.excluding("inplace")) + assert equal_computations( + f.maker.fgraph.outputs, + [ + pt.switch( + x >= np.array([[0]], dtype=np.int8), + pt.log1p(-x), + np.array([[np.nan]], dtype=np.float32), + ) + ], + ) + + expected = np.log1p(-data_valid) + np.testing.assert_almost_equal(f(data_valid), expected) + assert np.all(np.isnan(f(data_invalid))) + @pytest.mark.parametrize( ["nested_expression", "expected_switches"], [