From 7e76a5e90fb7fc967611be47b4719bd03142a3dd Mon Sep 17 00:00:00 2001 From: Smit-create Date: Fri, 20 Jan 2023 17:22:28 +0530 Subject: [PATCH] Fix styling issues --- aesara/tensor/rewriting/math.py | 7 +++++-- tests/tensor/test_math.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/aesara/tensor/rewriting/math.py b/aesara/tensor/rewriting/math.py index 907772d5ba..db295c45b0 100644 --- a/aesara/tensor/rewriting/math.py +++ b/aesara/tensor/rewriting/math.py @@ -324,19 +324,22 @@ def log_diff_exp(fgraph, node): prev_op = x.owner.op.scalar_op node_op = node.op.scalar_op - + print("here", node_op, prev_op) if isinstance(prev_op, aes.Sub) and isinstance(node_op, aes.Log): a, b = x.owner.inputs if not a.owner or not b.owner: return - a_op, b_op = a.owner.op.scalar_op, b.owner.op.scalar_op + a_op, b_op = a.owner.op.scalar_op, b.owner.op.scalar_op + print("here2", a_op, b_op) if isinstance(a_op, aes.Exp) and isinstance(b_op, aes.Exp): a = a.owner.inputs[0] b = b.owner.inputs[0] + print("here3", a, b) new_out = add(a, log1mexp(sub(b, a))) old_out = node.outputs[0] if new_out.dtype != old_out.dtype: new_out = cast(new_out, old_out.dtype) + print("here4", new_out) return [new_out] diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 656a032eec..15f471d432 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -3522,6 +3522,7 @@ def test_infer_shape(self): self.op_class, ) + def test_logdiffexp(): x = fmatrix() y = fmatrix() @@ -3530,7 +3531,6 @@ def test_logdiffexp(): ops_graph = [ node for node in graph - if isinstance(node.op, Elemwise) - and isinstance(node.op.scalar_op, aes.Exp) + if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Exp) ] assert len(ops_graph) == 0