Introduce scalars in compiled graphs via the FusionRewrite #349
Actually, I think since we refactored the FusionRewrite this is already the case:

```python
import pytensor
import pytensor.tensor as pt
import pytensor.scalar as ps
import numpy as np

x = pt.scalar("x")
out = pt.exp(pt.cos(pt.log(x)))
fn1 = pytensor.function([x], out)
pytensor.dprint(fn1, print_type=True)
```
So the only question left is to remove the useless Elemwise:

```python
x = ps.float64("x")
op = ps.Composite([x], [ps.exp(ps.cos(ps.log(x)))])
x = pt.scalar("x")
out = pt.tensor_from_scalar(op(pt.scalar_from_tensor(x)))
fn2 = pytensor.function([x], out)
pytensor.dprint(fn2, print_type=True)
```
But is it worth it? This example in the C backend is actually faster with fn1:

```python
fn1.trust_input = True
fn2.trust_input = True
x_test = np.array(2.0)
assert fn1(x_test) == fn2(x_test)

%timeit fn1(x_test)  # 6.55 µs ± 97.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
%timeit fn2(x_test)  # 7.43 µs ± 54.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```

For numba I get similar speeds for the two graphs, but surprisingly 3.5x slower (20 µs).
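For context on what fusion buys here, a minimal NumPy sketch (plain Python, not PyTensor's actual machinery) of the same example: the unfused chain materializes one intermediate array per op, while the fused "composite" evaluates the whole expression `exp(cos(log(x)))` in a single function call.

```python
import numpy as np

def unfused(x):
    # Each step materializes a separate (here 0-d) intermediate array.
    a = np.log(x)
    b = np.cos(a)
    return np.exp(b)

def fused(x):
    # One composite expression, analogous to a scalar Composite op:
    # no named intermediates survive between the ops.
    return np.exp(np.cos(np.log(x)))

x_test = np.array(2.0)
assert np.isclose(unfused(x_test), fused(x_test))
```

The per-call gap is tiny for a single 0-d value, which is consistent with the timings above: most of the measured cost is call overhead, not the intermediates themselves.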
During the design meeting it was suggested to remove Elemwises when they are used by Ops that accept scalar inputs and return equivalent results. However, there aren't that many Ops that accept scalar inputs: I can only think of Assert, Subtensor indexes, and OpFromGraph (and scalar ops, of course). Scans could probably benefit from allowing scalar inputs, but that's a tough can of worms due to the complex internal / pre-allocation logic. It would make more sense after a refactoring like #191.
I am going to close this one until someone shows how a graph would be more optimized / faster otherwise.
I think the only thing that's needed is to allow the FusionRewrite to work on 0d tensors (right now I think it requires that ndim > 1). Fusing chains of 0d tensors inside a Composite would have the same effect as using scalars in the graph, with a small overhead from Elemwise (which also takes care of the otherwise needed ScalarFromTensor and TensorFromScalar at the inputs and outputs). Once we have a 0d Elemwise composite, it's also trivial to replace it by the scalar case if that's more efficient.
Originally posted by @ricardoV94 in #345 (comment)