diff --git a/aesara/tensor/rewriting/math.py b/aesara/tensor/rewriting/math.py
index 2aeb701cf6..48df59e461 100644
--- a/aesara/tensor/rewriting/math.py
+++ b/aesara/tensor/rewriting/math.py
@@ -311,6 +311,39 @@ def local_exp_log(fgraph, node):
         return [exp(x)]
 
 
+@register_specialize
+@node_rewriter([log])
+def log_diff_exp(fgraph, node):
+    r"""
+    Rewrite that changes ``log(exp(a) - exp(b))`` to ``a + log1mexp(b - a)``.
+
+    The rewritten form is numerically stable for large ``a``/``b``, where
+    ``exp`` would overflow but ``log1mexp`` stays finite.
+    """
+    x = node.inputs[0]
+
+    if not x.owner or not isinstance(x.owner.op, Elemwise):
+        return
+
+    prev_op = x.owner.op.scalar_op
+    node_op = node.op.scalar_op
+
+    if isinstance(prev_op, aes.Sub) and isinstance(node_op, aes.Log):
+        a, b = x.owner.inputs
+        # Both operands must be Elemwise nodes: non-Elemwise ops (Dot, Alloc, ...) have no ``scalar_op``.
+        if not a.owner or not isinstance(a.owner.op, Elemwise):
+            return
+        if not b.owner or not isinstance(b.owner.op, Elemwise):
+            return
+        a_op, b_op = a.owner.op.scalar_op, b.owner.op.scalar_op
+        if isinstance(a_op, aes.Exp) and isinstance(b_op, aes.Exp):
+            a = a.owner.inputs[0]
+            b = b.owner.inputs[0]
+            new_out = add(a, log1mexp(sub(b, a)))
+            old_out = node.outputs[0]
+            # The rewritten expression may upcast; preserve the caller-visible dtype.
+            if new_out.dtype != old_out.dtype:
+                new_out = cast(new_out, old_out.dtype)
+            return [new_out]
+
+
 @register_specialize
 @node_rewriter([Elemwise])
 def local_exp_log_nan_switch(fgraph, node):
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 2d1d12ddf2..3b7d4c2447 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -3521,3 +3521,17 @@ def test_infer_shape(self):
         [x1, x2],
         self.op_class,
     )
+
+
+def test_logdiffexp():
+    mode = get_default_mode().including("log_diff_exp")
+    x = fmatrix()
+    y = fmatrix()
+    f = function([x, y], log(exp(x) - exp(y)), mode=mode)
+    graph = f.maker.fgraph.toposort()
+    ops_graph = [
+        node
+        for node in graph
+        if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Exp)
+    ]
+    # The rewrite replaces both exp nodes with a single log1mexp, so no Exp remains.
+    assert len(ops_graph) == 0