-
-
Notifications
You must be signed in to change notification settings - Fork 151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement at.eye
using existing Ops
#1217
base: main
Are you sure you want to change the base?
Implement at.eye
using existing Ops
#1217
Conversation
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
…yapunov` to correctly show a TensorVariable is returned
…yapunov` to correctly show a TensorVariable is returned
…nov`. Add a complex check in `_direct_solve_discrete_lyapunov` to try to avoid calling `.conj()`, allowing for a symbolic gradient. Change default `method` parameter in `solve_discrete_lyapunov` from `None` to `direct` Update docstring of `solve_discrete_lyapunov` to explain that `method=direct` should be preferred for compatibility reasons. Add tests for `_direct_solve_discrete_lyapunov` in the real and complex case.
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
…yapunov` to correctly show a TensorVariable is returned
…yapunov` to correctly show a TensorVariable is returned
…nov`. Add a complex check in `_direct_solve_discrete_lyapunov` to try to avoid calling `.conj()`, allowing for a symbolic gradient. Change default `method` parameter in `solve_discrete_lyapunov` from `None` to `direct` Update docstring of `solve_discrete_lyapunov` to explain that `method=direct` should be preferred for compatibility reasons. Add tests for `_direct_solve_discrete_lyapunov` in the real and complex case.
Update `test_basic.test_eye` to reflect the changes to `at.eye`.
I think the Should be fine to remove the specialized backends. I did notice that wile the C version now runs slightly faster, the numba backend became about 2x slower. I guess they have some clever tricks for transpiling JAX is pretty useless except for fixed inputs due to their whole constant shape constraints. But in that case Aesara will just constant fold the output. So no need to worry about it. |
I noticed that |
I think it is
It's just contained in the same Op as IncSubtensor which is a bit confusing (they share a lot of logic) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like these commits need to be squashed.
Yes I apologize, I made a separate branch on my aesara fork then wrote this, but it seems it didn't work how I thought? How do you keep clean git branches for working on several small separate projects? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I especially need input on this PR because I don't know what to do with the existing
Eye
Op
. It could be removed, but I am unclear what effects this would have on the numba/jax backends, as well as on existing user code.
If we find a suitable all-around replacement for Eye
, then we can definitely remove it; otherwise, if there is a persistent performance discrepancy between backends, we can always keep Eye
and add a backend-dependent rewrite that replaces it.
Also, I just remembered that we have a non-COp
named AllocDiag
. This work will probably make that redundant as well.
aesara/tensor/basic.py
Outdated
eye = zeros(n * m, dtype=dtype) | ||
eye = aesara.tensor.subtensor.set_subtensor(eye[i :: m + 1], 1).reshape((n, m)) | ||
eye = aesara.tensor.subtensor.set_subtensor(eye[m - k :, :], 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need the last set_subtensor
(i.e. that appears to fill with 0
s if we've already used zeros
? I think I might know why, but it would be nice to have only one set_subtensor
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The second set_subtenor
is to prevent the diagonal from "wrapping around" the matrix in cases when m > n
or when k != 0
. This one-liner almost works:
at.subtensor.set_subtensor(eye[i:(m - k)*m:m + 1], 1).reshape((n, m))
But it fails when k < -n
or k > m
. In these cases I guess an error could be raised (why is the user asking for an offset greater than the width of the matrix?), but numpy returns the zero matrix for this case and I wanted to be consistent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One way to match numpy
and avoid the second set_subtensor
would be to check the conditions where the zero matrix should be returned and use multiplication:
eye = at.subtensor.set_subtensor(eye[i:(m - k)*m:m + 1], 1).reshape((n, m))
eye *= (1 - at.gt(k, m)) * (1 - at.lt(k, -n))
But this was about 5µs slower than the two set_subtensor
method when I ran %timeit
. Perhaps it's too cute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's an implementation that uses scalars and Composite
:
import numpy as np
import aesara
import aesara.tensor as at
from aesara.scalar import int64, switch, Composite
n = int64("n")
m = int64("m")
k = int64("k")
i = switch(k >= 0, k, -k * m)
eye_sc = at.zeros(n * m, dtype="float64")
mkm = (m - k) * m
eye_sc = at.subtensor.set_subtensor(eye_sc[i : mkm : m + 1], 1).reshape((n, m))
eye_sc_fn = aesara.function([n, m, k], [eye_sc])
aesara.dprint(eye_sc_fn)
# Reshape{2} [id A] 16
# |IncSubtensor{InplaceSet;int64:int64:int64} [id B] 15
# | |Reshape{2} [id C] 14
# | | |IncSubtensor{InplaceSet;int64:int64:int64} [id D] 13
# | | | |Alloc [id E] 11
# | | | | |TensorConstant{0.0} [id F]
# | | | | |TensorFromScalar [id G] 8
# | | | | |mul [id H] 2
# | | | | |n [id I]
# | | | | |m [id J]
# | | | |TensorConstant{1} [id K]
# | | | |Switch [id L] 12
# | | | | |GE [id M] 6
# | | | | | |k [id N]
# | | | | | |ScalarConstant{0} [id O]
# | | | | |k [id N]
# | | | | |mul [id P] 10
# | | | | |neg [id Q] 5
# | | | | | |k [id N]
# | | | | |m [id J]
# | | | |mul [id R] 9
# | | | | |sub [id S] 4
# | | | | | |m [id J]
# | | | | | |k [id N]
# | | | | |m [id J]
# | | | |add [id T] 3
# | | | |m [id J]
# | | | |ScalarConstant{1} [id U]
# | | |MakeVector{dtype='int64'} [id V] 7
# | | |TensorFromScalar [id W] 1
# | | | |n [id I]
# | | |TensorFromScalar [id X] 0
# | | |m [id J]
# | |TensorConstant{1} [id K]
# | |Switch [id L] 12
# | |mul [id R] 9
# | |add [id T] 3
# |MakeVector{dtype='int64'} [id V] 7
%timeit eye_sc_fn(10, 10, 1)
# 66.5 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# Manually construct composite `Op`s for the scalar operations
i_comp_op = Composite([n, m, k], [i])
i_comp = i_comp_op(n, m, k)
mkm_comp_op = Composite([m, k], [mkm])
mkm_comp = mkm_comp_op(m, k)
eye_sc_comp = at.zeros(n * m, dtype="float64")
eye_sc_comp = at.subtensor.set_subtensor(eye_sc_comp[i_comp : mkm_comp : m + 1], 1).reshape((n, m))
eye_sc_comp_fn = aesara.function([n, m, k], [eye_sc_comp])
aesara.dprint(eye_sc_comp_fn)
# Reshape{2} [id A] 10
# |IncSubtensor{InplaceSet;int64:int64:int64} [id B] 9
# | |Alloc [id C] 8
# | | |TensorConstant{0.0} [id D]
# | | |TensorFromScalar [id E] 7
# | | |mul [id F] 6
# | | |n [id G]
# | | |m [id H]
# | |TensorConstant{1} [id I]
# | |Composite{Switch(GE(i2, 0), i2, ((-i2) * i1))} [id J] 5
# | | |n [id G]
# | | |m [id H]
# | | |k [id K]
# | |Composite{((i0 - i1) * i0)} [id L] 4
# | | |m [id H]
# | | |k [id K]
# | |add [id M] 3
# | |m [id H]
# | |ScalarConstant{1} [id N]
# |MakeVector{dtype='int64'} [id O] 2
# |TensorFromScalar [id P] 1
# | |n [id G]
# |TensorFromScalar [id Q] 0
# |m [id H]
%timeit eye_sc_comp_fn(10, 10, 1)
# 60.9 µs ± 2.53 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before I forget, using Only applies when inc_subtensor
with set_instead_of_inc=True
and ignore_duplicates=True
might be faster.AdvancedIncSubtensor
is used, and it isn't in these cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also need an issue for a rewrite that replaces scalar-only sub-graphs with Composite
s.
Done: #1225
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, Switch
is quite literally the scalar version of IfElse
—at least in terms of its Op.c_code
implementation. If anything, we should make invocations of ifelse
return Switch
es when the arguments are ScalarType
ed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's the issue for ifelse
: #1226.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, there's no scalar Op
for max
and min
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Disregard, obviously this can be done using scalar switch
.
I always create new branches from |
FYI: You'll need to rebase very soon in order to get the fixes in #1219 that address the current CI failures. |
Also, if you're ever having blocking Git difficulties, don't hesitate to ask us. We can always perform the changes ourselves and/or walk you through them. In the former case, you'll need to be comfortable rebasing your local branch, but that's about it. |
I just merged these changes, so here's a good opportunity to try a clean rebase onto |
Add typing information to return value Re-write function using scalar functions for optimization Fix bug where identity matrix incorrectly included 1's in the lower triangle when k > m
…into aesara_native_eye
Ok I hope I got my git situation squared away, although I initally messed up so a bunch of commits from my I re-wrote the function using scalars and from numba import njit
@njit
def numba_eye(n, m=None, k=0):
return np.eye(n, m, k) So there's something more fundamental going on there. The numba If I had to critique my own code I'd say it's repetitive and quite ugly, it might be better to pause this PR, work on #1225, then come back and re-write this function to look more natural. |
There's a lot of overhead to a Aesara function so it's not easy to compare directly to a numba-only function. One thing you can try is to set |
Ah understood. This speeds up the aesara eye function by about 7x, so I see what you mean by overhead. Pure numba is still about 20x faster, though. Who knew putting a bunch of 1's into a matrix was so complex. |
Oh, things like that can absolutely end up being complex, especially under the light of performance considerations. |
This PR would close #1177 by implementing
at.eye
using existing aesaraOps
. The test suite related toat.eye
is updated to reflect these changes.I especially need input on this PR because I don't know what to do with the existing
Eye
Op
. It could be removed, but I am unclear what effects this would have on the numba/jax backends, as well as on existing user code.Implement
at.eye
using existingOps
#1217 (comment) implies that these changes might adversely affect Numba, so let's get an idea of why that could be the case, and consider some alternative formulations.