Added KanrenRelationSub for distributive rewrites
kc611 committed Mar 23, 2022
1 parent 2742315 commit 0ae8664
Showing 2 changed files with 149 additions and 0 deletions.
94 changes: 94 additions & 0 deletions aesara/tensor/math_opt.py
@@ -5,10 +5,20 @@
from functools import partial, reduce

import numpy as np
from cons import cons
from etuples import etuple
from kanren import eq, fact, heado, tailo
from kanren.assoccomm import associative, commutative
from kanren.core import lall, lany
from kanren.graph import mapo
from unification import vars as lvars

import aesara.scalar.basic as aes
import aesara.scalar.math as aes_math
import aesara.tensor as at
from aesara.compile import optdb
from aesara.graph.basic import Constant, Variable
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import (
LocalOptGroup,
LocalOptimizer,
@@ -18,6 +28,7 @@
local_optimizer,
)
from aesara.graph.opt_utils import get_clients_at_depth
from aesara.graph.optdb import EquilibriumDB
from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import assert_op
from aesara.tensor.basic import (
@@ -3520,6 +3531,77 @@ def local_reciprocal_1_plus_exp(fgraph, node):
return out


fact(commutative, aes.mul)
fact(associative, aes.mul)
fact(commutative, aes.add)
fact(associative, aes.add)
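
These `fact` declarations tell kanren's `assoccomm` machinery that scalar `mul` and `add` may be matched modulo argument order and grouping. A minimal standalone sketch (editorial illustration, not part of the diff) of what that enables:

    from unification import var
    from kanren import run
    from kanren.assoccomm import eq_assoccomm

    q = var()
    # With commutativity declared, `add(1, q)` should unify with `add(2, 1)`,
    # binding `q` to 2.
    print(run(1, q, eq_assoccomm((aes.add, 1, q), (aes.add, 2, 1))))  # -> (2,)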


def distributive_collect(in_lv, out_lv):
    """Construct a relation for the distributive rewrites below.

    ``x_lv``, if any, is the logic-variable version of some term ``x``,
    while ``x_at``, if any, is the Aesara tensor version of the same.

    This relation covers:
    1. (x + y + z) * A = x * A + y * A + z * A
    2. (x + y + z) / A = x / A + y / A + z / A
    3. (x - y) * A = x * A - y * A
    4. (x - y) / A = x / A - y / A
    The rewrites are applied in reverse (RHS to LHS), reducing the overall
    number of computations.
    """
A_lv, op_lv, all_term_lv, all_cdr_lv, cdr_lv, all_flat_lv = lvars(6)
return lall(
lany(
heado(at.add, all_term_lv),
heado(at.sub, all_term_lv),
),
        # Get the flattened `add`/`sub` arguments
tailo(all_cdr_lv, all_term_lv),
        # Recombine the collected arguments under `add`/`sub` and unify the
        # result with the output
lany(eq(cons(at.add, cdr_lv), out_lv), eq(cons(at.sub, cdr_lv), out_lv)),
lany(
lall(
lany(
eq(etuple(op_lv, A_lv, all_term_lv), in_lv),
eq(etuple(op_lv, all_term_lv, A_lv), in_lv),
),
eq(op_lv, at.mul),
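                # Relate each collected argument `x` to its distributed
                # counterpart `y`: either `y` is `A` itself (so `x` is 1), or
                # `y` is `x * A` with the factors in either order.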
mapo(
lambda x, y: lany(
lall(eq(x, at.as_tensor_variable(1.0)), eq(y, A_lv)),
lany(
eq(etuple(op_lv, x, A_lv), y),
eq(etuple(op_lv, A_lv, x), y),
),
),
all_cdr_lv,
cdr_lv,
),
),
lall(
eq(etuple(op_lv, all_term_lv, A_lv), in_lv),
eq(op_lv, at.true_div),
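                # Division distributes only from the right, so each
                # distributed argument `y` must be `x / A`.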
mapo(
lambda x, y: eq(etuple(op_lv, x, A_lv), y),
all_cdr_lv,
cdr_lv,
),
),
),
)


distributive_collect_opt = KanrenRelationSub(
lambda x, y: distributive_collect(y, x),
node_filter=lambda x: isinstance(x.op, Elemwise),
)
distributive_collect_opt.__name__ = distributive_collect.__name__
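
A rough usage sketch (editorial, not part of the commit) mirroring the new tests: applied in isolation, the rewrite should collect `x * a + y * a` into a single multiplication by `(x + y)`.

    from aesara.graph.fg import FunctionGraph
    from aesara.graph.opt import in2out

    x, y = at.vector("x"), at.vector("y")
    a = at.matrix("a")
    fgraph = FunctionGraph([x, y, a], [x * a + y * a])
    in2out(distributive_collect_opt, ignore_newtrees=True).optimize(fgraph)
    # `fgraph.outputs[0]` should now be an `Elemwise` multiplication
    # of `(x + y)` and `a`.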


# 1 - sigmoid(x) -> sigmoid(-x)
local_1msigmoid = PatternSub(
(sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")),
@@ -3564,3 +3646,15 @@ def local_reciprocal_1_plus_exp(fgraph, node):
)
register_canonicalize(local_sigmoid_logit)
register_specialize(local_sigmoid_logit)


fastmath = EquilibriumDB()

optdb.register("fastmath", fastmath, "fast_run", position=1)

fastmath.register(
"dist_collect_opt",
in2out(distributive_collect_opt, ignore_newtrees=True),
"distribute_opts",
"fast_run"
)
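
Because `fastmath` is registered in `optdb` under the `"fast_run"` tag, the default compilation mode should pick these rewrites up automatically. A hedged end-to-end sketch (editorial, not part of the commit):

    import aesara
    import aesara.tensor as at

    x, y = at.vector("x"), at.vector("y")
    a = at.matrix("a")
    f = aesara.function([x, y, a], x * a + y * a)
    aesara.dprint(f)  # expected: a single Elemwise multiplication of (x + y) and a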
55 changes: 55 additions & 0 deletions tests/tensor/test_math_opt.py
@@ -4623,3 +4623,58 @@ def logit_fn(x):
fg = optimize(FunctionGraph([x], [out]))
assert not list(fg.toposort())
assert fg.inputs[0] is fg.outputs[0]


class TestDistributiveOpts:
x_at = vector("x")
y_at = vector("y")
a_at = matrix("a")

    # The `_at` suffix marks Aesara tensor variables; the `_lv` suffix used in
    # `distributive_collect` marks the corresponding logic variables.
@pytest.mark.parametrize(
"orig_operation, optimized_operation",
[
(x_at * a_at + a_at * y_at, a_at * (x_at + y_at)),
(a_at * x_at + y_at * a_at, a_at * (x_at + y_at)),
(a_at * x_at + a_at * y_at, a_at * (x_at + y_at)),
(x_at * a_at + y_at * a_at, (x_at + y_at) * a_at),
(x_at * a_at + a_at, (x_at + 1) * a_at),
(x_at / a_at + y_at / a_at, (x_at + y_at) / a_at),
(a_at * x_at - a_at * y_at, a_at * (x_at - y_at)),
(x_at * a_at - y_at * a_at, (x_at - y_at) * a_at),
(a_at * x_at - y_at * a_at, a_at * (x_at - y_at)),
(a_at * x_at - a_at * y_at, (x_at - y_at) * a_at),
(x_at / a_at - y_at / a_at, (x_at - y_at) / a_at),
],
)
def test_distributive_opts(self, orig_operation, optimized_operation):
fgraph = FunctionGraph([self.x_at, self.y_at, self.a_at], [orig_operation])
out_orig = fgraph.outputs[0]

fgraph_res = FunctionGraph(
[self.x_at, self.y_at, self.a_at], [optimized_operation]
)
out_res = fgraph_res.outputs[0]

fgraph_opt = optimize(fgraph)
out_opt = fgraph_opt.outputs[0]

assert all(
[
isinstance(out_orig.owner.op, Elemwise),
isinstance(out_res.owner.op, Elemwise),
isinstance(out_opt.owner.op, Elemwise),
]
)

        # The scalar `Op` originally in the output node (the `Op` to be
        # "collected") should differ from the outer scalar `Op` of the
        # optimized graph.
        original_scalar_op = type(out_orig.owner.op.scalar_op)
        # The outer scalar `Op` of the reference graph should equal the outer
        # scalar `Op` of the optimized graph.
        resulting_scalar_op = type(out_res.owner.op.scalar_op)
        optimized_scalar_op = type(out_opt.owner.op.scalar_op)

        assert original_scalar_op != resulting_scalar_op
        assert resulting_scalar_op == optimized_scalar_op
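
To exercise just the new cases, running `pytest tests/tensor/test_math_opt.py::TestDistributiveOpts` should select only the parametrized tests added here.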
