Skip to content

Commit

Permalink
Fix styling issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create committed Jan 20, 2023
1 parent a05de4a commit 7e76a5e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 5 additions & 2 deletions aesara/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,19 +324,22 @@ def log_diff_exp(fgraph, node):

prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op

print("here", node_op, prev_op)
if isinstance(prev_op, aes.Sub) and isinstance(node_op, aes.Log):
a, b = x.owner.inputs
if not a.owner or not b.owner:
return
a_op, b_op = a.owner.op.scalar_op, b.owner.op.scalar_op
a_op, b_op = a.owner.op.scalar_op, b.owner.op.scalar_op
print("here2", a_op, b_op)
if isinstance(a_op, aes.Exp) and isinstance(b_op, aes.Exp):
a = a.owner.inputs[0]
b = b.owner.inputs[0]
print("here3", a, b)
new_out = add(a, log1mexp(sub(b, a)))
old_out = node.outputs[0]
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, old_out.dtype)
print("here4", new_out)
return [new_out]


Expand Down
4 changes: 2 additions & 2 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3522,6 +3522,7 @@ def test_infer_shape(self):
self.op_class,
)


def test_logdiffexp():
x = fmatrix()
y = fmatrix()
Expand All @@ -3530,7 +3531,6 @@ def test_logdiffexp():
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Exp)
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Exp)
]
assert len(ops_graph) == 0

0 comments on commit 7e76a5e

Please sign in to comment.