From d5a4653bdbefa25419f2cb31552380cda1710223 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Fri, 20 Jan 2023 09:45:26 +0530 Subject: [PATCH 1/4] Add log_diff_exp opt --- aesara/tensor/rewriting/math.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/aesara/tensor/rewriting/math.py b/aesara/tensor/rewriting/math.py index 2aeb701cf6..907772d5ba 100644 --- a/aesara/tensor/rewriting/math.py +++ b/aesara/tensor/rewriting/math.py @@ -311,6 +311,35 @@ def local_exp_log(fgraph, node): return [exp(x)] +@register_specialize +@node_rewriter([Elemwise]) +def log_diff_exp(fgraph, node): + # Case for log(exp(a) - exp(b)) -> a + log1mexp(b - a) + x = node.inputs[0] + + if not isinstance(node.op, Elemwise): + return + if not x.owner or not isinstance(x.owner.op, Elemwise): + return + + prev_op = x.owner.op.scalar_op + node_op = node.op.scalar_op + + if isinstance(prev_op, aes.Sub) and isinstance(node_op, aes.Log): + a, b = x.owner.inputs + if not a.owner or not b.owner: + return + a_op, b_op = a.owner.op.scalar_op, b.owner.op.scalar_op + if isinstance(a_op, aes.Exp) and isinstance(b_op, aes.Exp): + a = a.owner.inputs[0] + b = b.owner.inputs[0] + new_out = add(a, log1mexp(sub(b, a))) + old_out = node.outputs[0] + if new_out.dtype != old_out.dtype: + new_out = cast(new_out, old_out.dtype) + return [new_out] + + @register_specialize @node_rewriter([Elemwise]) def local_exp_log_nan_switch(fgraph, node): From 9d706294de4e7fd9ccaacea357e7616ee962eee4 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Fri, 20 Jan 2023 09:45:39 +0530 Subject: [PATCH 2/4] TST: Add tests for log_diff_exp --- tests/tensor/test_math.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 2d1d12ddf2..656a032eec 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -3521,3 +3521,16 @@ def test_infer_shape(self): [x1, x2], self.op_class, ) + +def test_logdiffexp(): + x = fmatrix() + y = fmatrix() + f = function([x, y], log(exp(x) - exp(y))) + graph = f.maker.fgraph.toposort() + ops_graph = [ + node + for node in graph + if isinstance(node.op, Elemwise) + and isinstance(node.op.scalar_op, aes.Exp) + ] + assert len(ops_graph) == 0 From 6d7f154571af3f78d0489e4b2ef46a842812067f Mon Sep 17 00:00:00 2001 From: Smit-create Date: Fri, 20 Jan 2023 17:24:01 +0530 Subject: [PATCH 3/4] Fix styling issues --- aesara/tensor/rewriting/math.py | 2 +- tests/tensor/test_math.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aesara/tensor/rewriting/math.py b/aesara/tensor/rewriting/math.py index 907772d5ba..e0b00e5ce8 100644 --- a/aesara/tensor/rewriting/math.py +++ b/aesara/tensor/rewriting/math.py @@ -329,7 +329,7 @@ def log_diff_exp(fgraph, node): a, b = x.owner.inputs if not a.owner or not b.owner: return - a_op, b_op = a.owner.op.scalar_op, b.owner.op.scalar_op + a_op, b_op = a.owner.op.scalar_op, b.owner.op.scalar_op if isinstance(a_op, aes.Exp) and isinstance(b_op, aes.Exp): a = a.owner.inputs[0] b = b.owner.inputs[0] diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 656a032eec..15f471d432 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -3522,6 +3522,7 @@ def test_infer_shape(self): self.op_class, ) + def test_logdiffexp(): x = fmatrix() y = fmatrix() @@ -3530,7 +3531,6 @@ def test_logdiffexp(): ops_graph = [ node for node in graph - if isinstance(node.op, Elemwise) - and isinstance(node.op.scalar_op, aes.Exp) + if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Exp) ] assert len(ops_graph) == 0 From 5fbbc8a52b71917940f81a229231d64d7ce1ffcd Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 25 Jan 2023 16:16:32 +0530 Subject: [PATCH 4/4] Use mode --- aesara/tensor/rewriting/math.py | 8 ++++---- tests/tensor/test_math.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/aesara/tensor/rewriting/math.py b/aesara/tensor/rewriting/math.py index e0b00e5ce8..48df59e461 100644 --- a/aesara/tensor/rewriting/math.py +++ b/aesara/tensor/rewriting/math.py @@ -312,13 +312,13 @@ def local_exp_log(fgraph, node): @register_specialize -@node_rewriter([Elemwise]) +@node_rewriter([log]) def log_diff_exp(fgraph, node): - # Case for log(exp(a) - exp(b)) -> a + log1mexp(b - a) + r""" + Rewrite that changes ``log(exp(a) - exp(b))`` to ``a + log1mexp(b - a)``. + """ x = node.inputs[0] - if not isinstance(node.op, Elemwise): - return if not x.owner or not isinstance(x.owner.op, Elemwise): return diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 15f471d432..3b7d4c2447 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -3524,13 +3524,14 @@ def test_infer_shape(self): def test_logdiffexp(): + mode = get_default_mode().including("log_diff_exp") x = fmatrix() y = fmatrix() - f = function([x, y], log(exp(x) - exp(y))) + f = function([x, y], log(exp(x) - exp(y)), mode=mode) graph = f.maker.fgraph.toposort() ops_graph = [ node for node in graph if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Exp) ] - assert len(ops_graph) == 0 + assert len(ops_graph) != 2