Added KanrenRelationSub for distributive rewrites #634
base: main
Conversation
See the comments at the end of this reply. They refer to the fact that the distributive identity (i.e. what's being implemented here) comes with some numeric concerns. By adding general distributive identities, we simplify the rewrite process, but we might also introduce issues (e.g. like the kinds mitigated by #275 and Kahan summation). One way to approach this is to perform the distributive rewrite only as an intermediate rewrite within a larger sequence of rewrites that guarantees cancellation of terms. This doesn't necessarily prevent such rewrites, but we need to consider the trade-offs and how we can navigate and possibly even avoid them.
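To make the concern concrete, here is a tiny plain-Python float64 demonstration (not from this thread) of how distributing a division changes the rounding:

# Algebraically, (1 + 2) / 10 == 1/10 + 2/10, but the distributed form
# accumulates a different rounding error in float64:
print((1.0 + 2.0) / 10.0)       # 0.3
print(1.0 / 10.0 + 2.0 / 10.0)  # 0.30000000000000004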
aesara/configdefaults.py (outdated)

@@ -744,6 +744,15 @@ def add_compile_configvars():
        in_c_key=False,
    )

    config.add(
        "fastmath_opts",
Added this configuration, which will make the rewrite optional.
Instead, don't make them canonicalizations (i.e. don't use register_canonicalize), and register them manually to a different optimization DB. As it currently stands, local_add_sub_collector is being called far too often.
In general, optimization selection and filtering is accomplished using the optimization DBs and tags.
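A rough sketch of that suggestion, assuming an EquilibriumDB and the local_add_sub_collector rewrite from this PR (the exact registration signature varies across Aesara versions, so treat this as illustrative only):

from aesara.graph.optdb import EquilibriumDB

# A dedicated DB for the "fastmath" rewrites; they are only applied when this
# DB is explicitly queried, so they stay opt-in instead of being
# canonicalizations that run on every compilation
fastmath_db = EquilibriumDB()

# Register the rewrite under a tag ("basic") that a later query can include
fastmath_db.register("local_add_sub_collector", local_add_sub_collector, "basic")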
Codecov Report

@@            Coverage Diff             @@
##             main     #634      +/-   ##
==========================================
- Coverage   78.35%   78.18%    -0.18%
==========================================
  Files         152      152
  Lines       47685    47682        -3
  Branches    10881    10882        +1
==========================================
- Hits        37364    37280       -84
- Misses       7773     7844       +71
- Partials     2548     2558       +10
Alright, so I registered them onto a separate optimization DB, fastmath_db, but querying and applying it fails:

import aesara
import aesara.tensor as at
from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.math_opt import fastmath_db
from aesara.graph.fg import FunctionGraph

eta_at = at.scalar("eta")
kappa_at = at.scalar("kappa")
graph_at = eta_at / kappa_at + (1 - eta_at) / kappa_at

graph_fn = FunctionGraph(
    inputs=[eta_at, kappa_at],
    outputs=[graph_at],
)

fastmath_db.query(OptimizationQuery(include=["basic"])).optimize(graph_fn)
# AttributeError: 'FromFunctionLocalOptimizer' object has no attribute 'optimize'

aesara.dprint(graph_fn)

Is it not supposed to be used like that?
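One possible explanation (my assumption, not stated in the thread): the query returned a local optimizer, which has to be wrapped in a global optimizer before it can be applied to a FunctionGraph, e.g. via in2out:

from aesara.graph.opt import in2out

# Wrap the local optimizer so it gains an `optimize` method that walks the
# whole FunctionGraph
local_opt = fastmath_db.query(OptimizationQuery(include=["basic"]))
in2out(local_opt).optimize(graph_fn)
aesara.dprint(graph_fn)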
xa + ya + za -> (x + y + z)a

Is this particular collection of rewrites subject to #649? Should we be separating them out?
Now that we have … Using … We would need to devise a … Anyway, we can start discussing it here (or somewhere else).
I made a kanren version of those distributive optimizations, but it kind of seems like it isn't able to generalize them to an arbitrary number of distributive optimizations. Have a look at https://gist.github.com/kc611/b33e45ed2086597ed9c9df4f387c84b0
Just added a comment in that Gist.
You can make the kanren goals more flexible with respect to the supported Ops, instead of creating goals for each one separately. For instance, make the Op a logic variable op_lv, add a goal that confirms the value of op_lv is one of the accepted Ops, and use op_lv in place of the Op.
To perform the check in the second step, conde can be used (e.g. conde([eq(op_lv, at.mul)], [eq(op_lv, at.true_div)], ...)), as can type constraints (e.g. isinstanceo). Just note that if you use something like the examples I just gave, they each have their own limitations (e.g. a conde testing for at.mul will only work for those exact at.mul Op instances).
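A minimal sketch of that suggestion, assuming the etuple-based graph representation used by the kanren integration; the helper names (distributes_over_add, distribute_over_add) are illustrative, not taken from the PR:

from kanren import conde, eq, var
from etuples import etuple

import aesara.tensor as at


def distributes_over_add(op_lv):
    # Goal stating that `op_lv` is one of the accepted `Op`s; note the
    # exact-instance caveat mentioned above for `at.mul`/`at.true_div`
    return conde([eq(op_lv, at.mul)], [eq(op_lv, at.true_div)])


def distribute_over_add(in_lv, out_lv):
    # Relate `op(x + y, a)` to `op(x, a) + op(y, a)` for any accepted `op`
    op_lv, x_lv, y_lv, a_lv = var(), var(), var(), var()
    return conde(
        [
            eq(in_lv, etuple(op_lv, etuple(at.add, x_lv, y_lv), a_lv)),
            distributes_over_add(op_lv),
            eq(
                out_lv,
                etuple(at.add, etuple(op_lv, x_lv, a_lv), etuple(op_lv, y_lv, a_lv)),
            ),
        ]
    )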
tests/tensor/test_math_opt.py (outdated)

    "orig_operation, optimized_operation",
    [
        (a_at * x_at + a_at * y_at, a_at * (x_at + y_at)),
        (x_at * a_at + y_at * x_at, (x_at + y_at) * a_at),
So I managed to generalize them for div and mul, but it seems that these rewrites aren't taking the facts() into account. For instance, this particular case passes, but the case above it doesn't. Isn't this supposed to be handled by fact(commutative, at.mul) in the rewrite? (Or does that stand for something else?)
This might have to do with the eq goal that's being used at specific steps (even behind the scenes). There are special eq_* goals that do (and don't) take into account the associativity/commutativity (AC) information set by facts, and you may need to use those explicitly at certain points. Just try not to use them when they're not needed; otherwise, the streams resulting from the goals will become very long due to all the permutations induced by AC relations.
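For reference, here is a small, self-contained illustration of the AC-aware goals, adapted from the kanren README and using plain tuples instead of Aesara graphs:

from kanren import fact, run, var
from kanren.assoccomm import associative, commutative
from kanren.assoccomm import eq_assoccomm as eq_ac

# Dummy operation symbols
add = "add"
mul = "mul"

# Declare the AC properties via the facts system
fact(commutative, mul)
fact(commutative, add)
fact(associative, mul)
fact(associative, add)

x, y = var("x"), var("y")

# Match `(1 + x) * y` against `2 * (3 + 1)`, modulo commutativity
pattern = (mul, (add, 1, x), y)
expr = (mul, 2, (add, 3, 1))

print(run(0, (x, y), eq_ac(pattern, expr)))
# ((3, 2),) i.e. x -> 3 and y -> 2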
This might be the wrong place for this discussion, but there are also rewrites with respect to shapes that can be optimized. For example, …
If a and b are constants, I guess that case would be handled by constant folding. But, yeah, we would need a rewrite for when they're variables if we are to implement something like that.
Are you talking about transforming …? Here's a way you can check these kinds of things:

import aesara
import aesara.tensor as at
from aesara.graph.opt_utils import optimize_graph

a, b = at.scalars("ab")
M = at.matrix("M")

z = (a * M) * b

aesara.dprint(z)
# Elemwise{mul,no_inplace} [id A] ''
# |Elemwise{mul,no_inplace} [id B] ''
# | |InplaceDimShuffle{x,x} [id C] ''
# | | |a [id D]
# | |M [id E]
# |InplaceDimShuffle{x,x} [id F] ''
# |b [id G]

# This will only perform canonicalizations, but others can be added via the
# `include` keyword
z_opt = optimize_graph(z)
aesara.dprint(z_opt)
# Elemwise{mul,no_inplace} [id A] ''
# |InplaceDimShuffle{x,x} [id B] ''
# | |a [id C]
# |M [id D]
# |InplaceDimShuffle{x,x} [id E] ''
# |b [id F]

# This will perform all the default optimizations
z_fn = aesara.function([a, b, M], z)
aesara.dprint(z_fn)
# Elemwise{mul,no_inplace} [id A] '' 2
# |InplaceDimShuffle{x,x} [id B] '' 1
# | |a [id C]
# |M [id D]
# |InplaceDimShuffle{x,x} [id E] '' 0
# |b [id F]
Are you guys talking about #287?
Essentially, yes, but for multiplication, which is probably something we could consider implementing quickly.
Yes, exactly.
aesara/tensor/math_opt.py (outdated)

# This does the optimization
# 1. (x + y + z) * A = x * A + y * A + z * A
# 2. (x + y + z) / A = x / A + y / A + z / A
Could this lead to numerical issues?
Ah, no, the actual optimization being done here is the reverse of that:

# Line 3545
distribute_over_add_opt = KanrenRelationSub(lambda x, y: distribute_over_add(y, x))

So it combines the additive terms rather than distributing them. It's just that it's easier to represent the relation this way in kanren, and it works just as well if we implement it the other way around purely in kanren.
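To spell out the trick with a sketch (building on the distribute_over_add goal shown earlier; the KanrenRelationSub import path is an assumption and may differ between versions):

from aesara.graph.kanren import KanrenRelationSub

# Distributing direction: (x + y) * A -> x * A + y * A
distribute_opt = KanrenRelationSub(distribute_over_add)

# Collecting direction: x * A + y * A -> (x + y) * A, obtained by swapping
# the input/output logic variables of the very same relation
collect_opt = KanrenRelationSub(lambda in_lv, out_lv: distribute_over_add(out_lv, in_lv))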
Alright, so it seems that the optimizations work with arbitrary ordering of the common terms; all of these test cases work: …
The current issue I'm trying to work on is to extrapolate singular terms into …
@brandonwillard Any ideas on how to do this using goals? Or is it even possible to add the extra terms inside the kanren relation's filters?
This PR is turning out to be an amazing example of genuine relational programming (and kanren)!
If it's possible to abstract/parameterize the implementation of distribute_over_add so that it also covers subtraction, the footprint of these features would be impressively small.
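One hedged way to get there, building on the op_lv sketch above (all names remain illustrative): make the collecting operation a logic variable as well, constrained to addition or subtraction.

def collects_terms(reduce_lv):
    # Goal stating that `reduce_lv` is an `Op` we can collect terms across
    return conde([eq(reduce_lv, at.add)], [eq(reduce_lv, at.sub)])


def distribute_over_add_or_sub(in_lv, out_lv):
    # Relate `op(reduce(x, y), a)` to `reduce(op(x, a), op(y, a))`, covering
    # both (x + y) * a -> x*a + y*a and (x - y) * a -> x*a - y*a
    op_lv, reduce_lv, x_lv, y_lv, a_lv = var(), var(), var(), var(), var()
    return conde(
        [
            eq(in_lv, etuple(op_lv, etuple(reduce_lv, x_lv, y_lv), a_lv)),
            distributes_over_add(op_lv),
            collects_terms(reduce_lv),
            eq(
                out_lv,
                etuple(reduce_lv, etuple(op_lv, x_lv, a_lv), etuple(op_lv, y_lv, a_lv)),
            ),
        ]
    )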
aesara/tensor/math_opt.py (outdated)

# x_lv, if any, is the logic variable version of some term x,
# while x_at, if any, is the Aesara tensor version for the same.
def distributive_collect(in_lv, out_lv):
    from kanren import eq
Why a local import?
aesara/tensor/math_opt.py (outdated)

# x_lv, if any, is the logic variable version of some term x,
# while x_at, if any, is the Aesara tensor version for the same.
def distributive_collect(in_lv, out_lv):
I would add a doc-string.
This PR is very powerful in its conciseness and the future it demonstrates. It really brings @brandonwillard's vision to light. So I think we should turn this into a blog post that describes what's happening, how powerful it is, how unique it is, etc.
Looks like the tests are taking too long to finish. I can't tell if it's GitHub Actions or not, yet.
Alright, so it seems like there are two separate issues in the failing test over here: …
I think (1) can be "solved" when what is being discussed here is implemented. Constraints like the one you suggested should be kept out of miniKanren goals, IMHO. Ideally, we would apply all the possible rewrites and then use a scoring function to choose the "optimal" graph. We could also adopt a greedy approach and choose the "optimal" rewrite out of all possible rewrites at a given step.
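A toy sketch of the greedy variant (everything here is hypothetical: rewrites is a list of functions mapping a graph to a rewritten graph, or None when they don't apply, and score is any cost function, such as node count):

def greedy_optimize(graph, rewrites, score):
    # Repeatedly apply the single best-scoring rewrite until no rewrite
    # improves the score
    while True:
        candidates = [g for g in (rw(graph) for rw in rewrites) if g is not None]
        best = min(candidates, key=score, default=None)
        if best is None or score(best) >= score(graph):
            return graph
        graph = best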
In this case, we would like to avoid the cost of re-searching the AC "graph space" on each application of these rewrites. The e-graph data structures mentioned in #1082 could help with this, albeit perhaps not in a direct way. #1165 is a little closer, though, because it sets us up for rewrite caching and the like. #1165 also helps with #1082 and the use of e-graph-like data structures.
Rebased this on …
This PR adds the following optimizations: …

Resolves #606

The graph now returns: …

Gist elaborating the implementation: https://gist.github.com/kc611/b33e45ed2086597ed9c9df4f387c84b0