Skip to content

Commit

Permalink
Generalized distributive rewrites.
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 committed Jan 26, 2022
1 parent 249d0c3 commit 8204a35
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 103 deletions.
130 changes: 28 additions & 102 deletions aesara/tensor/math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from etuples import etuple
from kanren.assoccomm import assoc_flatten, associative, commutative
from kanren.core import lall
from kanren.core import lall, lany
from kanren.graph import mapo
from unification import vars as lvars

Expand Down Expand Up @@ -3506,121 +3506,64 @@ def local_reciprocal_1_plus_exp(fgraph, node):

# x_lv, if any, is the logic variable version of some term x,
# while x_at, if any, is the Aesara tensor version for the same.
def distribute_over_add(in_lv, out_lv):
    """Relate the distribution of `mul`/`true_div` over `add`.

    Encodes, as a miniKanren relation, the identities

    1. ``(x + y + z) * A = x * A + y * A + z * A``
    2. ``(x + y + z) / A = x / A + y / A + z / A``

    ``in_lv`` unifies with the factored form and ``out_lv`` with the
    distributed sum.  `KanrenRelationSub` applies this relation in reverse
    (see ``distribute_over_add_opt``), so as a graph rewrite it factors a
    distributed sum back into a single `mul`/`true_div`.
    """
    from kanren import conso, eq, fact, heado, tailo

    A_lv, op_lv, add_term_lv, add_cdr_lv, mul_cdr_lv, add_flat_lv = lvars(6)
    fact(commutative, at.mul)
    fact(associative, at.mul)
    fact(commutative, at.add)
    fact(associative, at.add)
    return lall(
        # Make sure the input is an `at.mul` or `at.true_div`.  The `add`
        # term must come first: `true_div` is not commutative, and the
        # distributed arguments below are built as ``op(x, A_lv)``, i.e.
        # the factored form is ``(x + y + ...) op A``, not ``A op (...)``.
        eq(in_lv, etuple(op_lv, add_term_lv, A_lv)),
        lany(eq(op_lv, at.mul), eq(op_lv, at.true_div)),
        # Make sure the outer term is an `add`
        heado(at.add, add_term_lv),
        # Flatten the associative pairings of `add` operations
        assoc_flatten(add_term_lv, add_flat_lv),
        # Get the flattened `add` arguments
        tailo(add_cdr_lv, add_flat_lv),
        # Add all the arguments and set the output
        conso(at.add, mul_cdr_lv, out_lv),
        # Apply the operation to each of the flattened `add` arguments
        mapo(lambda x, y: conso(op_lv, etuple(x, A_lv), y), add_cdr_lv, mul_cdr_lv),
    )


# Apply the relation in reverse (the graph input unifies with `out_lv`), so
# the rewrite factors: ``x * A + y * A -> (x + y) * A`` and
# ``x / A + y / A -> (x + y) / A``.  This single rewrite replaces the former
# separate `distribute_mul_over_add_opt` and `distribute_div_over_add_opt`.
distribute_over_add_opt = KanrenRelationSub(lambda x, y: distribute_over_add(y, x))
distribute_over_add_opt.__name__ = distribute_over_add.__name__


def distribute_over_sub(in_lv, out_lv):
    """Relate ``x op a - y op a`` to ``(x - y) op a`` for `mul`/`true_div`.

    Encodes, as a miniKanren relation, the identities

    1. ``x * a - y * a = (x - y) * a``
    2. ``x / a - y / a = (x - y) / a``

    ``in_lv`` unifies with the distributed difference and ``out_lv`` with
    the factored form.  This generalization replaces the former separate
    `distribute_mul_over_sub` and `distribute_div_over_sub` relations.
    """
    from kanren import eq, fact

    fact(commutative, at.mul)
    fact(associative, at.mul)
    a_lv, x_lv, y_lv, op_lv = lvars(4)
    return lall(
        # lhs == x / a - y / a or x * a - y * a
        eq(
            etuple(
                at.sub,
                etuple(op_lv, x_lv, a_lv),
                etuple(op_lv, y_lv, a_lv),
            ),
            in_lv,
        ),
        # The shared operation must be a `mul` or a `true_div`
        lany(eq(op_lv, at.mul), eq(op_lv, at.true_div)),
        # rhs == (x - y) / a or (x - y) * a
        eq(
            etuple(op_lv, etuple(at.sub, x_lv, y_lv), a_lv),
            out_lv,
        ),
    )


distribute_over_sub_opt = KanrenRelationSub(distribute_over_sub)
distribute_over_sub_opt.__name__ = distribute_over_sub.__name__

# 1 - sigmoid(x) -> sigmoid(-x)
local_1msigmoid = PatternSub(
Expand Down Expand Up @@ -3673,35 +3616,18 @@ def distribute_div_over_sub(in_lv, out_lv):
optdb.register("fastmath", fastmath, 1, "fast_run", "mul")

fastmath.register(
    "dist_over_add_opt",
    in2out(distribute_over_add_opt, ignore_newtrees=True),
    1,
    "distribute_opts",
    "fast_run",
    "mul",
    # Keep the "div" tag from the removed `dist_div_over_add_opt`
    # registration: the consolidated rewrite handles `true_div` too, and
    # this preserves tag-based inclusion/exclusion for existing users.
    "div",
)

fastmath.register(
"dist_div_over_sub_opt",
in2out(distribute_div_over_sub_opt, ignore_newtrees=True),
"dist_over_sub_opt",
in2out(distribute_over_sub_opt, ignore_newtrees=True),
1,
"distribute_opts",
"fast_run",
Expand Down
4 changes: 3 additions & 1 deletion tests/tensor/test_math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4583,8 +4583,10 @@ class TestDistributiveOpts:
"orig_operation, optimized_operation",
[
(a_at * x_at + a_at * y_at, a_at * (x_at + y_at)),
(x_at * a_at + y_at * a_at, (x_at + y_at) * a_at),
(x_at / a_at + y_at / a_at, (x_at + y_at) / a_at),
((a_at * x_at - a_at * y_at, a_at * (x_at - y_at))),
(a_at * x_at - a_at * y_at, a_at * (x_at - y_at)),
(x_at * a_at - y_at * a_at, (x_at - y_at) * a_at),
(x_at / a_at - y_at / a_at, (x_at - y_at) / a_at),
],
)
Expand Down

0 comments on commit 8204a35

Please sign in to comment.