Skip to content

Commit

Permalink
Added eq_assoccomm to enforce associativity
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 committed Jan 26, 2022
1 parent 8204a35 commit 31da77d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
25 changes: 17 additions & 8 deletions aesara/tensor/math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3507,7 +3507,9 @@ def local_reciprocal_1_plus_exp(fgraph, node):
# x_lv, if any, is the logic variable version of some term x,
# while x_at, if any, is the Aesara tensor version for the same.
def distribute_over_add(in_lv, out_lv):
from kanren import conso, eq, fact, heado, tailo
from cons import cons
from kanren import eq, fact, heado, tailo
from kanren.assoccomm import eq_assoccomm

# This does the optimization
# 1. (x + y + z) * A = x * A + y * A + z * A
Expand All @@ -3519,7 +3521,7 @@ def distribute_over_add(in_lv, out_lv):
fact(associative, at.add)
return lall(
        # Make sure the input is an `at.mul` or `at.true_div`
eq(in_lv, etuple(op_lv, A_lv, add_term_lv)),
eq_assoccomm(etuple(op_lv, A_lv, add_term_lv), in_lv),
lany(eq(op_lv, at.mul), eq(op_lv, at.true_div)),
        # Make sure the outer term is an `add`
        heado(at.add, add_term_lv),
Expand All @@ -3528,8 +3530,12 @@ def distribute_over_add(in_lv, out_lv):
# Get the flattened `add` arguments
tailo(add_cdr_lv, add_flat_lv),
# Add all the arguments and set the output
conso(at.add, mul_cdr_lv, out_lv),
mapo(lambda x, y: conso(op_lv, etuple(x, A_lv), y), add_cdr_lv, mul_cdr_lv),
eq_assoccomm(cons(at.add, mul_cdr_lv), out_lv),
mapo(
(lambda x, y: eq_assoccomm(cons(op_lv, etuple(x, A_lv)), y)),
add_cdr_lv,
mul_cdr_lv,
),
)


Expand All @@ -3539,23 +3545,26 @@ def distribute_over_add(in_lv, out_lv):

def distribute_over_sub(in_lv, out_lv):
from kanren import eq, fact
from kanren.assoccomm import eq_assoccomm

fact(commutative, at.mul)
fact(associative, at.mul)
a_lv, x_lv, y_lv, op_lv = lvars(4)
a_lv, x_lv, y_lv, op_lv, t1_lv, t2_lv = lvars(6)
return lall(
# lhs == x / a - y / a or x * a - y * a
eq(
etuple(
at.sub,
etuple(op_lv, x_lv, a_lv),
etuple(op_lv, y_lv, a_lv),
t1_lv,
t2_lv,
),
in_lv,
),
eq_assoccomm(etuple(op_lv, x_lv, a_lv), t1_lv),
eq_assoccomm(etuple(op_lv, y_lv, a_lv), t2_lv),
lany(eq(op_lv, at.mul), eq(op_lv, at.true_div)),
# rhs == (x - y) / a or (x - y) * a
eq(
eq_assoccomm(
etuple(op_lv, etuple(at.sub, x_lv, y_lv), a_lv),
out_lv,
),
Expand Down
6 changes: 5 additions & 1 deletion tests/tensor/test_math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4582,11 +4582,15 @@ class TestDistributiveOpts:
@pytest.mark.parametrize(
"orig_operation, optimized_operation",
[
(x_at * a_at + a_at * y_at, a_at * (x_at + y_at)),
(a_at * x_at + y_at * a_at, a_at * (x_at + y_at)),
(a_at * x_at + a_at * y_at, a_at * (x_at + y_at)),
(x_at * a_at + y_at * x_at, (x_at + y_at) * a_at),
(x_at * a_at + y_at * a_at, (x_at + y_at) * a_at),
(x_at / a_at + y_at / a_at, (x_at + y_at) / a_at),
(a_at * x_at - a_at * y_at, a_at * (x_at - y_at)),
(x_at * a_at - y_at * a_at, (x_at - y_at) * a_at),
(a_at * x_at - y_at * a_at, a_at * (x_at - y_at)),
(a_at * x_at - a_at * y_at, (x_at - y_at) * a_at),
(x_at / a_at - y_at / a_at, (x_at - y_at) / a_at),
],
)
Expand Down

0 comments on commit 31da77d

Please sign in to comment.