Add a fusion rewrite for CAReduces with Elemwise inputs #1285

Merged


brandonwillard
Member

@brandonwillard brandonwillard commented Nov 4, 2022

This PR adds fusion rewrites for CAReduce nodes with Elemwise-derived inputs.

  • Make the Python backend work for the Composite Ops generated by this fusion
  • Do something about CAReduceDtype
    It's a fairly redundant subclass that probably should be merged with CAReduce anyway.
  • Add more/better tests
    • E.g. test the axis parameter
  • Consider only performing the rewrite when not using the Python backend (for performance reasons)
  • [ ] Support multiple inputs (optional)
    This will require some refactoring of CAReduce or a new subclass and should be split off into its own issue/PR. See Fuse CAReduces with multi-input Elemwises #1307.

@brandonwillard brandonwillard marked this pull request as draft November 4, 2022 22:38
@brandonwillard brandonwillard linked an issue Nov 4, 2022 that may be closed by this pull request
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch 2 times, most recently from cbf33e4 to b681459 Compare November 4, 2022 22:56
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch 5 times, most recently from 914f7f6 to c371651 Compare November 6, 2022 05:18
@ricardoV94
Contributor

Should we only fuse when the unreduced output has a single client, and therefore is definitely never needed?

@codecov

codecov bot commented Nov 6, 2022

Codecov Report

Merging #1285 (91f3438) into main (3ad936f) will increase coverage by 0.03%.
The diff coverage is 94.53%.


@@            Coverage Diff             @@
##             main    #1285      +/-   ##
==========================================
+ Coverage   74.12%   74.15%   +0.03%     
==========================================
  Files         174      174              
  Lines       48652    48706      +54     
  Branches    10366    10372       +6     
==========================================
+ Hits        36064    36119      +55     
- Misses      10299    10301       +2     
+ Partials     2289     2286       -3     

Impacted Files                       Coverage Δ
aesara/compile/function/pfunc.py     84.18% <ø> (-0.24%) ⬇️
aesara/compile/function/types.py     79.16% <75.00%> (+0.16%) ⬆️
aesara/tensor/elemwise.py            88.07% <90.54%> (-0.52%) ⬇️
aesara/tensor/rewriting/elemwise.py  86.40% <94.44%> (+0.65%) ⬆️
aesara/scalar/basic.py               79.02% <95.16%> (+0.10%) ⬆️
aesara/compile/mode.py               84.47% <100.00%> (+1.22%) ⬆️
aesara/tensor/math.py                90.40% <100.00%> (+0.37%) ⬆️

@brandonwillard
Member Author

> Should we only fuse when the unreduced output has a single client, and therefore is definitely never needed?

Yeah, that and a few other things need to be done before this stops being a draft. I just added that check, though, along with another fix.
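
For readers following along, the single-client condition discussed above can be sketched roughly as follows. This is a toy illustration in plain Python, not the code in this PR: `Node`, `build_clients`, and `can_fuse` are hypothetical names, and the toy ignores details such as graph outputs also counting as clients.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """A toy graph node (hypothetical): an op name plus the nodes it reads."""

    op: str
    inputs: list = field(default_factory=list)


def build_clients(outputs):
    """Map each reachable node (keyed by id) to the nodes that consume it."""
    clients = {}
    seen = set()
    stack = list(outputs)
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        clients.setdefault(id(node), [])
        for inp in node.inputs:
            clients.setdefault(id(inp), []).append(node)
            stack.append(inp)
    return clients


def can_fuse(reduce_node, clients):
    """Allow fusion only if the Elemwise input has no consumer but the reduce."""
    (elemwise_node,) = reduce_node.inputs
    if elemwise_node.op != "Elemwise":
        return False
    return len(clients[id(elemwise_node)]) == 1


x = Node("input")
exp_x = Node("Elemwise", [x])
total = Node("CAReduce", [exp_x])
other = Node("Elemwise", [exp_x])  # a second consumer of exp(x)

print(can_fuse(total, build_clients([total])))         # True
print(can_fuse(total, build_clients([total, other])))  # False
```

When a second consumer exists, the unreduced `exp(x)` has to be computed anyway, so fusing it into the reduction would duplicate work rather than save it.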

@brandonwillard
Member Author

brandonwillard commented Nov 6, 2022

Some current results:

import numpy as np

import aesara
import aesara.tensor as at

from aesara.compile.mode import get_mode


fusion_mode = get_mode("FAST_RUN").including("local_careduce_fusion")
no_fusion_mode = get_mode("FAST_RUN").excluding("local_careduce_fusion")


x = at.matrix("x")
y = at.exp(x).sum(axis=1)

y_fn = aesara.function([x], y, mode=no_fusion_mode)

aesara.dprint(y_fn)
# Sum{axis=[1], acc_dtype=float64} [id A] 1
#  |Elemwise{exp,no_inplace} [id B] 0
#    |x [id C]

y_fusion_fn = aesara.function([x], y, mode=fusion_mode)

aesara.dprint(y_fusion_fn)
# CAReduce{Composite{(i0 + exp(i1))}}{axis=[1], acc_dtype=float64} [id A] 0
#  |x [id B]

rng = np.random.default_rng(23920)

x_small_val = rng.random((10, 10))
x_large_val = rng.random((5000, 2000))

%timeit y_fn(x_small_val)
# 6.58 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%timeit y_fn(x_large_val)
# 198 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

res = y_fn(x_large_val)
exp_res = np.exp(x_large_val).sum(axis=1)
assert res.shape == exp_res.shape
assert np.allclose(res, exp_res)

%timeit y_fusion_fn(x_small_val)
# 6.25 µs ± 558 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%timeit y_fusion_fn(x_large_val)
# 55.3 ms ± 826 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

res = y_fusion_fn(x_large_val)
assert res.shape == exp_res.shape
assert np.allclose(res, exp_res)
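
In the timings above, the fused function is roughly 3.6x faster on the large input (198 ms vs. 55.3 ms), while the small input shows no meaningful difference. One plausible reading of `Composite{(i0 + exp(i1))}` in the `dprint` output is a binary scalar op in which `i0` is the running accumulator and `i1` is the next element, so the full-size `exp(x)` temporary is never materialized. The sketch below illustrates that idea in plain Python; the function names are hypothetical and this is not how Aesara generates its loops.

```python
import math
from functools import reduce


def unfused_row_sums(rows):
    """Two passes: materialize exp(x) for every element, then sum each row."""
    exp_rows = [[math.exp(v) for v in row] for row in rows]  # full-size temporary
    return [sum(row) for row in exp_rows]


def fused_row_sums(rows):
    """One pass: fold each row with the binary op (acc, v) -> acc + exp(v)."""
    return [reduce(lambda acc, v: acc + math.exp(v), row, 0.0) for row in rows]


rows = [[0.0, 1.0, 2.0], [0.5, 0.25, 0.125]]
pairs = zip(unfused_row_sums(rows), fused_row_sums(rows))
# The two variants may differ in the last bits because of summation order,
# so compare with a tolerance rather than exact equality.
print(all(math.isclose(a, b) for a, b in pairs))
```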

@brandonwillard brandonwillard self-assigned this Nov 20, 2022
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch 2 times, most recently from 34ca8c3 to d977ee4 Compare November 21, 2022 01:53
@brandonwillard brandonwillard marked this pull request as ready for review November 21, 2022 01:58
- Lazily create and cache `FunctionGraph`s, the `Composite.perform`
  implementation, C code, and name values
- Use `fgraph_to_python` for `Composite.perform`
- Use the `HasInnerGraph` interface
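
The lazy create-and-cache pattern mentioned in the commit message above can be sketched as follows. This is only an illustration under stated assumptions: `LazyComposite`, `py_perform_fn`, and `build_count` are hypothetical, and `compile`/`eval` stand in for the real conversion step (the commit names `fgraph_to_python`, which this sketch does not call).

```python
from functools import cached_property


class LazyComposite:
    """Hypothetical stand-in for an Op that wraps an inner scalar graph."""

    def __init__(self, expression):
        self.expression = expression  # e.g. "i0 + i1 * i1"
        self.build_count = 0  # counts how often the expensive step runs

    @cached_property
    def py_perform_fn(self):
        """Build the Python callable on first access; later accesses reuse it."""
        self.build_count += 1
        code = compile(f"lambda i0, i1: {self.expression}", "<composite>", "eval")
        return eval(code)

    def perform(self, i0, i1):
        return self.py_perform_fn(i0, i1)


op = LazyComposite("i0 + i1 * i1")
print(op.build_count)    # 0: nothing is built at construction time
print(op.perform(1, 3))  # 10
print(op.perform(2, 2))  # 6
print(op.build_count)    # 1: the callable was built once and then cached
```

Deferring the work this way means constructing the Op stays cheap, and backends that never call the Python implementation never pay for building it.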
Successfully merging this pull request may close these issues.

Fuse CAReduces and Elemwises
2 participants