
Added KanrenRelationSub for distributive rewrites #634

Open: wants to merge 3 commits into main from new_opts
Conversation


@kc611 kc611 commented Oct 28, 2021

This PR adds the following optimizations:

  • x*a + y*a + z*a -> (x + y + z)*a
  • x/a + y/a + z/a -> (x + y + z)/a
  • x*a - y*a -> (x - y)*a
  • x/a - y/a -> (x - y)/a

Resolves #606

The graph now compiles to:

import aesara
import aesara.tensor as at

eta_at = at.scalar("eta")
kappa_at = at.scalar("kappa")

graph_at = eta_at / kappa_at + (1 - eta_at) / kappa_at
graph_fn = aesara.function([eta_at, kappa_at], graph_at)

aesara.dprint(graph_fn.maker.fgraph)
# Elemwise{reciprocal,no_inplace} [id A] ''   0
#  |kappa [id B]

Gist elaborating the implementation:
https://gist.github.com/kc611/b33e45ed2086597ed9c9df4f387c84b0
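For readers unfamiliar with these rewrites, the collection step can be illustrated with a small stand-alone sketch over tuple-encoded terms. This is illustrative only, not the PR's actual implementation (which handles commutativity and n-ary terms via kanren); `collect_common_factor` and the tuple encoding are made up for this sketch:

```python
# Toy term representation: ('add'|'sub'|'mul'|'div', left, right).

def collect_common_factor(term):
    """x*a + y*a -> (x + y)*a and x/a + y/a -> (x + y)/a (plus the '-' analogues)."""
    if not (isinstance(term, tuple) and term[0] in ("add", "sub")):
        return term
    outer, l, r = term
    for inner in ("mul", "div"):
        if (
            isinstance(l, tuple) and isinstance(r, tuple)
            and l[0] == inner and r[0] == inner
            and l[2] == r[2]  # common right-hand factor/divisor
        ):
            return (inner, (outer, l[1], r[1]), l[2])
    return term


print(collect_common_factor(("add", ("mul", "x", "a"), ("mul", "y", "a"))))
# ('mul', ('add', 'x', 'y'), 'a')
```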

@kc611 kc611 marked this pull request as draft October 28, 2021 16:15

brandonwillard commented Oct 28, 2021

See the comments at the end of this reply. They refer to the fact that the distributive identity (i.e. what's being implemented here) comes with some numeric concerns.

By adding general distributive identities, we simplify the rewrite process, but we might also introduce issues (e.g. like the kinds mitigated by #275 and Kahan summation).

One way to approach this is to perform the distributive rewrite only as an intermediate rewrite within a larger sequence of rewrites that guarantee cancellation of terms.

This doesn't necessarily prevent such rewrites, but we need to consider the trade-offs and how we can navigate and possibly even avoid them.
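The numeric concern is easy to reproduce with plain Python floats: floating-point addition is not associative, which is exactly what reassociating rewrites assume. `math.fsum` plays the role of a compensated (Kahan-style) summation here:

```python
import math

# Reassociating a sum changes the floating-point result:
left = (0.1 + 0.2) + 0.3   # 0.6000000000000001
right = 0.1 + (0.2 + 0.3)  # 0.6
print(left == right)       # False

# A compensated summation (Shewchuk's algorithm, in the spirit of Kahan)
# yields the correctly rounded result regardless of grouping:
print(math.fsum([0.1, 0.2, 0.3]))  # 0.6
```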

@@ -744,6 +744,15 @@ def add_compile_configvars():
in_c_key=False,
)

config.add(
"fastmath_opts",
@kc611 (author):

Added this configuration which will make the rewrite optional.

Member:

Instead, don't make them canonicalizations (i.e. don't use register_canonicalize) and register them manually to a different optimization DB. As it currently stands, local_add_sub_collector is being called far more often than necessary.

In general, optimization selection and filtering is accomplished using the optimizations DBs and tags.
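The tag-based selection mentioned above can be sketched with a stand-alone toy registry. `ToyOptDB` is a made-up class, not an Aesara API; Aesara's real mechanism is its optimization DBs (e.g. SequenceDB) queried via OptimizationQuery:

```python
# Minimal, illustrative sketch of tag-based rewrite selection.

class ToyOptDB:
    def __init__(self):
        self._rewrites = {}  # name -> (rewrite_fn, tags)

    def register(self, name, rewrite_fn, *tags):
        self._rewrites[name] = (rewrite_fn, set(tags))

    def query(self, include):
        # Select only the rewrites carrying at least one requested tag
        include = set(include)
        return [
            rewrite_fn
            for rewrite_fn, tags in self._rewrites.values()
            if tags & include
        ]


db = ToyOptDB()
db.register("local_add_sub_collector", lambda g: g, "fastmath")
db.register("some_canonicalization", lambda g: g, "canonicalize", "basic")

# Only the "fastmath"-tagged rewrite is selected, so the collector is not
# run as part of every canonicalization pass:
print(len(db.query(include=["fastmath"])))  # 1
```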

codecov bot commented Oct 30, 2021

Codecov Report

Merging #634 (9d5537b) into main (240827c) will decrease coverage by 0.18%.
The diff coverage is 100.00%.

❗ Current head 9d5537b differs from pull request most recent head 5a864a2. Consider uploading reports for the commit 5a864a2 to get more accurate results


@@            Coverage Diff             @@
##             main     #634      +/-   ##
==========================================
- Coverage   78.35%   78.18%   -0.18%     
==========================================
  Files         152      152              
  Lines       47685    47682       -3     
  Branches    10881    10882       +1     
==========================================
- Hits        37364    37280      -84     
- Misses       7773     7844      +71     
- Partials     2548     2558      +10     
Impacted Files Coverage Δ
aesara/tensor/math_opt.py 86.64% <100.00%> (+0.41%) ⬆️
aesara/graph/type.py 75.92% <0.00%> (-3.23%) ⬇️
aesara/compile/debugmode.py 57.42% <0.00%> (-3.07%) ⬇️
aesara/tensor/type_other.py 80.76% <0.00%> (-2.38%) ⬇️
aesara/tensor/sharedvar.py 82.22% <0.00%> (-1.46%) ⬇️
aesara/tensor/basic.py 86.25% <0.00%> (-1.21%) ⬇️
aesara/compile/function/pfunc.py 82.25% <0.00%> (-1.08%) ⬇️
aesara/tensor/type.py 91.25% <0.00%> (-1.00%) ⬇️
aesara/tensor/basic_opt.py 84.37% <0.00%> (-0.78%) ⬇️
aesara/sparse/type.py 70.66% <0.00%> (-0.60%) ⬇️
... and 35 more



kc611 commented Nov 3, 2021

Alright, so I registered the rewrites in a SequenceDB, but when I do:

import aesara
import aesara.tensor as at
from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.math_opt import fastmath_db
from aesara.graph.fg import FunctionGraph

eta_at = at.scalar("eta")
kappa_at = at.scalar("kappa")

graph_at = eta_at / kappa_at + (1 - eta_at) / kappa_at
graph_fn = FunctionGraph(
    inputs=[eta_at, kappa_at],
    outputs=[graph_at],
)

fastmath_db.query(OptimizationQuery(include=["basic"])).optimize(graph_fn)
# AttributeError: 'FromFunctionLocalOptimizer' object has no attribute 'optimize'

aesara.dprint(graph_fn)

Is it not supposed to be used like that?


kc611 commented Nov 3, 2021

x*a + y*a + z*a -> (x + y + z)*a
x/a + y/a + z/a -> (x + y + z)/a
x*a - y*a -> (x - y)*a
x/a - y/a -> (x - y)/a

Is this particular collection of rewrites subject to #649? Should we be separating them out?

@brandonwillard

Is this particular collection of rewrites subject to #649, should we be separating them out?

Do you mean the +/- or arity difference? The +/- difference is a case of #648. Otherwise, these rewrites should be able to handle all arities.

@brandonwillard

Now that we have kanren support, this would be a good exercise for that.

Using kanren would allow us to more succinctly perform all the intermediate computations without the need for potentially (numerically) destabilizing canonicalizations.

We would need to devise a kanren relation (well, a goal constructor) that uses the properties implemented here (i.e. distributive properties) to "search" for the reductions described in #606. We would need the resulting goals to succeed only when at least one reduction has been made. That could be a little tricky to do entirely in miniKanren, though.

Anyway, we can start discussing it here (or somewhere else).
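The tricky requirement above — goals that succeed only when at least one reduction has been made — can be prototyped outside miniKanren as a driver that reports whether a rule fired anywhere in a term. This is a pure-Python sketch; `rewrite_once`, `double_neg`, and the tuple encoding are illustrative names, not part of the PR:

```python
def rewrite_once(term, rule):
    """Apply `rule` top-down, at most once; return (new_term, changed)."""
    new = rule(term)
    if new != term:
        return new, True
    if isinstance(term, tuple):
        head, *args = term
        for i, arg in enumerate(args):
            new_arg, changed = rewrite_once(arg, rule)
            if changed:
                return (head, *args[:i], new_arg, *args[i + 1:]), True
    return term, False


# Example rule: rewrite ('neg', ('neg', x)) -> x
def double_neg(t):
    if (
        isinstance(t, tuple) and t[0] == "neg"
        and isinstance(t[1], tuple) and t[1][0] == "neg"
    ):
        return t[1][1]
    return t


print(rewrite_once(("add", ("neg", ("neg", "x")), "y"), double_neg))
# (('add', 'x', 'y'), True)  -- the "changed" flag is the success condition
print(rewrite_once(("add", "x", "y"), double_neg))
# (('add', 'x', 'y'), False) -- no reduction, so the "goal" should fail
```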


kc611 commented Jan 1, 2022

I made a kanren version of those distributive optimizations, but it kind of seems like it isn't able to generalize them to an arbitrary number of distributive optimizations. Have a look at https://gist.github.com/kc611/b33e45ed2086597ed9c9df4f387c84b0

@brandonwillard

I made a kanren version of those distributive optimizations, but it kind of seems like it isn't able to generalize them to an arbitrary number of distributive optimizations. Have a look at https://gist.github.com/kc611/b33e45ed2086597ed9c9df4f387c84b0

Just added a comment in that Gist.

@kc611 kc611 changed the title Added local_add_sub_collector optimization Added KanrenRelationSub for distributive rewrites Jan 4, 2022
@kc611 kc611 marked this pull request as ready for review January 10, 2022 06:08
@brandonwillard brandonwillard left a comment

You can make the kanren goals more flexible with respect to the supported Ops, instead of creating goals for each one separately. For instance, make the Op a logic variable op_lv, add a goal that confirms the value of op_lv is one of the accepted Ops, and use op_lv in place of the Op.

To perform the check in the second step, conde can be used (e.g. conde([eq(op_lv, at.mul)], [eq(op_lv, at.true_div)], ...)), as can type constraints (e.g. isinstanceo). Just note that if you use something like the examples I just gave, they each have their own limitations (e.g. conde testing for at.mul will only work for those exact at.mul Op instances).

"orig_operation, optimized_operation",
[
(a_at * x_at + a_at * y_at, a_at * (x_at + y_at)),
(x_at * a_at + y_at * x_at, (x_at + y_at) * a_at),
@kc611 kc611 commented Jan 15, 2022

So I managed to generalize them for div and mul, but it seems that these rewrites aren't taking the facts() into account. For instance, this particular case passes but the one above it doesn't. Isn't this supposed to be handled by fact(commutative, at.mul) in the rewrite? (Or does that stand for something else?)

Member:

This might have to do with the eq goal that's being used at specific steps (even behind the scenes). There are special eq_* goals that do (and don't) take into account the associativity/commutativity (AC) information set by facts, and you may need to use those explicitly at certain points. Just try not to use them when they're not needed; otherwise, the streams resulting from the goals will become very long due to all the permutations induced by AC relations.

@kc611 kc611 force-pushed the new_opts branch 2 times, most recently from c9ac672 to 31da77d Compare January 26, 2022 14:27

twiecki commented Jan 26, 2022

This might be the wrong place for this discussion, but there are also shape-related rewrites that can be optimized. For example, b * M * a, where a and b are scalars and M is a matrix, is more efficient to rewrite as b * a * M so that there is only a single scalar-matrix multiplication instead of two.


kc611 commented Jan 26, 2022

If a and b are constants, I guess that case would be handled by constant folding. But yeah, we would need a rewrite for when they're symbolic variables if we are to implement something like that.

@brandonwillard

This might be the wrong place for this discussion, but there are also shape-related rewrites that can be optimized. For example, b * M * a, where a and b are scalars and M is a matrix, is more efficient to rewrite as b * a * M so that there is only a single scalar-matrix multiplication instead of two.

Are you talking about transforming ((b * M) * a) to ((b * a) * M) so that there's only one matrix/scalar product? We have some things like that for sums and products in aesara.tensor.math_opt, and the AlgebraicCanonizer does some similar things, but it looks like we might not have a rewrite that covers that exact case.

Here's a way you can check these kinds of things:

import aesara

import aesara.tensor as at
from aesara.graph.opt_utils import optimize_graph


a, b = at.scalars("ab")
M = at.matrix("M")

z = (a * M) * b

aesara.dprint(z)
# Elemwise{mul,no_inplace} [id A] ''
#  |Elemwise{mul,no_inplace} [id B] ''
#  | |InplaceDimShuffle{x,x} [id C] ''
#  | | |a [id D]
#  | |M [id E]
#  |InplaceDimShuffle{x,x} [id F] ''
#    |b [id G]

# This will only perform canonicalizations, but others can be added via the
# `include` keyword
z_opt = optimize_graph(z)

aesara.dprint(z_opt)
# Elemwise{mul,no_inplace} [id A] ''
#  |InplaceDimShuffle{x,x} [id B] ''
#  | |a [id C]
#  |M [id D]
#  |InplaceDimShuffle{x,x} [id E] ''
#    |b [id F]


# This will perform all the default optimizations
z_fn = aesara.function([a, b, M], z)

aesara.dprint(z_fn)
# Elemwise{mul,no_inplace} [id A] ''   2
#  |InplaceDimShuffle{x,x} [id B] ''   1
#  | |a [id C]
#  |M [id D]
#  |InplaceDimShuffle{x,x} [id E] ''   0
#    |b [id F]

@ricardoV94

Are you guys talking about #287?


brandonwillard commented Jan 26, 2022

Are you guys talking about #287?

Essentially, yes, but for multiplication, which is probably something we could consider implementing quickly.


twiecki commented Jan 27, 2022

Are you talking about transforming ((b * M) * a) to ((b * a) * M) so that there's only one matrix/scalar product?

Yes, exactly.


# This does the optimization
# 1. (x + y + z) * A = x * A + y * A + z * A
# 2. (x + y + z) / A = x / A + y / A + z / A
Contributor:

Could this lead to numerical issues?

@kc611 (author):

Ah no, the actual optimization being done here is the reverse of that:

# Line 3545
distribute_over_add_opt = KanrenRelationSub(lambda x, y: distribute_over_add(y, x))

So it combines the additive terms rather than distributing them. It's just that it's easier to represent it this way in kanren, and it works just as well if we implement it the other way around purely in kanren.


kc611 commented Feb 26, 2022

Alright, so it seems that the optimizations work with arbitrary orderings of the common terms. All of these test cases work:

x_at * a_at + a_at * y_at = a_at * (x_at + y_at)
a_at * x_at + y_at * a_at = a_at * (x_at + y_at)
a_at * x_at + a_at * y_at = a_at * (x_at + y_at)
x_at * a_at + y_at * a_at = (x_at + y_at) * a_at
x_at / a_at + y_at / a_at = (x_at + y_at) / a_at
a_at * x_at - a_at * y_at = a_at * (x_at - y_at)
x_at * a_at - y_at * a_at = (x_at - y_at) * a_at
a_at * x_at - y_at * a_at = a_at * (x_at - y_at)
a_at * x_at - a_at * y_at = (x_at - y_at) * a_at
x_at / a_at - y_at / a_at = (x_at - y_at) / a_at

The current issue I'm trying to work on is expanding bare terms into 1 * term:
For instance, an Elemwise with a bare common term in it:

x * A + A = x * A + 1 * A = (x + 1) * A

@brandonwillard Any ideas on how to do this using goals? Or is it even possible to add the extra terms inside the kanren relation's filters?
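One way to prototype the x * A + A case is to normalize each addend into a (coefficient, factor) view before collecting. This is a pure-Python sketch of the idea, not a kanren goal; `as_coeff_factor`, `collect_with_unit`, and the tuple encoding are all illustrative:

```python
def as_coeff_factor(term, factor):
    """View `term` as coeff * factor: A -> '1'; ('mul', x, A) -> x; else None."""
    if term == factor:
        return "1"
    if isinstance(term, tuple) and term[0] == "mul" and term[2] == factor:
        return term[1]
    return None


def collect_with_unit(term, factor):
    """x*A + A -> (x + 1)*A, using the normalized view above."""
    if isinstance(term, tuple) and term[0] == "add":
        cl = as_coeff_factor(term[1], factor)
        cr = as_coeff_factor(term[2], factor)
        if cl is not None and cr is not None:
            return ("mul", ("add", cl, cr), factor)
    return term


print(collect_with_unit(("add", ("mul", "x", "A"), "A"), "A"))
# ('mul', ('add', 'x', '1'), 'A')
```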

@brandonwillard brandonwillard left a comment

This PR is turning out to be an amazing example of genuine relational programming (and kanren)!

If it's possible to abstract/parameterize the implementation of distribute_over_add so that it also covers subtraction, the footprint of these features would be impressively small.

# x_lv, if any, is the logic variable version of some term x,
# while x_at, if any, is the Aesara tensor version for the same.
def distributive_collect(in_lv, out_lv):
from kanren import eq
Contributor:
Why a local import?


Contributor:

I would add a docstring.


twiecki commented Mar 21, 2022

This PR is very powerful in its conciseness and the future it demonstrates. It really brings @brandonwillard's vision to light. So I think we should turn this into a blog post that describes what's happening, how powerful it is, how unique it is, etc.

@kc611 kc611 force-pushed the new_opts branch 2 times, most recently from 0ae8664 to e55dcf3 Compare March 23, 2022 17:38
@brandonwillard

Looks like the tests are taking too long to finish. I can't tell if it's GitHub Actions or not, yet.


kc611 commented Apr 2, 2022

Alright, so it seems like there are two separate issues in the failing test here:

  1. Things like a + a being rewritten into a * (1 + 1). This is technically correct, but it isn't an optimization, so it should be fixed. That would need a condition in kanren corresponding to "at least one": iterating over the contents of the logic variable cdr_lv in the distributive_collect optimization and checking that at least one of the values is not equal to A_lv.
  2. The bad-view map error. This seems to happen simply because the optimization has no way of knowing whether there are aliases between outputs in that particular case. Last time this happened, I remember solving the test case by simply changing the position of the optimization. I'm not sure what the true fix should be, though.


rlouf commented Sep 7, 2022

I think (1) can be "solved" when what is being discussed here is implemented. Constraints like the one you suggested should be kept out of miniKanren goals, IMHO.

Ideally we would apply all the possible rewrites and then use a scoring function to choose the "optimal" graph. We could also adopt a greedy approach and choose the "optimal" rewrite out of all possible rewrites at a given step.
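A node-count score is the simplest instance of such a scoring function, and it also rejects non-reductions like a + a -> a * (1 + 1). A toy sketch over tuple-encoded terms (all names are illustrative, not from this PR):

```python
def node_count(term):
    """Count the nodes in a tuple-encoded expression tree."""
    if isinstance(term, tuple):
        return 1 + sum(node_count(t) for t in term[1:])
    return 1


def is_improvement(original, rewritten):
    # a rewrite "counts" only if it actually shrinks the graph
    return node_count(rewritten) < node_count(original)


before = ("add", "a", "a")               # a + a
after = ("mul", "a", ("add", "1", "1"))  # a * (1 + 1)
print(is_improvement(before, after))     # False: not a reduction
```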

@brandonwillard

Ideally we would apply all the possible rewrites and then use a scoring function to choose the "optimal" graph. We could also adopt a greedy approach and choose the "optimal" rewrite out of all possible rewrites at a given step.

In this case, we would like to avoid the cost of re-searching the AC "graph space" on each application of these rewrites.

The e-graph data structures mentioned in #1082 could help with this, albeit not in a direct way perhaps. #1165 is a little closer, though, because it sets us up for rewrite caching and the like. #1165 also helps with #1082 and the use of e-graph-like data structures.

@rlouf rlouf force-pushed the new_opts branch 2 times, most recently from b160ef5 to 67e1262 Compare October 17, 2022 08:39

rlouf commented Oct 17, 2022

Rebased this on main and resolved the merge conflicts that appeared after #1054.

Labels: enhancement (New feature or request), graph rewriting
Successfully merging this pull request may close: Missing rational function simplifications
5 participants