Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement at.eye using existing Ops #1217

Open
wants to merge 26 commits into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Contributor

@jessegrabowski jessegrabowski commented Sep 27, 2022

This PR would close #1177 by implementing at.eye using existing aesara Ops. The test suite related to at.eye is updated to reflect these changes.

I especially need input on this PR because I don't know what to do with the existing Eye Op. It could be removed, but I am unclear what effects this would have on the numba/jax backends, as well as on existing user code.

  • Investigate C and Numba performance under these changes.
    Implement at.eye using existing Ops #1217 (comment) implies that these changes might adversely affect Numba, so let's get an idea of why that could be the case, and consider some alternative formulations.

jessegrabowski and others added 20 commits June 29, 2022 22:30
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
…yapunov` to correctly show a TensorVariable is returned
…yapunov` to correctly show a TensorVariable is returned
…nov`. Add a complex check in `_direct_solve_discrete_lyapunov` to try to avoid calling `.conj()`, allowing for a symbolic gradient.

Change default `method` parameter in `solve_discrete_lyapunov` from `None` to `direct`

Update docstring of `solve_discrete_lyapunov` to explain that `method=direct` should be preferred for compatibility reasons.

Add tests for `_direct_solve_discrete_lyapunov` in the real and complex case.
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
…yapunov` to correctly show a TensorVariable is returned
…yapunov` to correctly show a TensorVariable is returned
…nov`. Add a complex check in `_direct_solve_discrete_lyapunov` to try to avoid calling `.conj()`, allowing for a symbolic gradient.

Change default `method` parameter in `solve_discrete_lyapunov` from `None` to `direct`

Update docstring of `solve_discrete_lyapunov` to explain that `method=direct` should be preferred for compatibility reasons.

Add tests for `_direct_solve_discrete_lyapunov` in the real and complex case.
Update `test_basic.test_eye` to reflect the changes to `at.eye`.
@ricardoV94
Copy link
Contributor

ricardoV94 commented Sep 27, 2022

I especially need input on this PR because I don't know what to do with the existing Eye Op. It could be removed, but I am unclear what effects this would have on the numba/jax backends, as well as on existing user code.

I think the Op can be removed. Doubt anyone is interaction with it directly. If not, we can add a DeprecationWarning to the class (in __new__, __init__, make_node?)

Should be fine to remove the specialized backends.

I did notice that wile the C version now runs slightly faster, the numba backend became about 2x slower. I guess they have some clever tricks for transpiling np.eye.

JAX is pretty useless except for fixed inputs due to their whole constant shape constraints. But in that case Aesara will just constant fold the output. So no need to worry about it.

@jessegrabowski
Copy link
Contributor Author

jessegrabowski commented Sep 27, 2022

I did notice that wile the C version now runs slightly faster, the numba backend became about 2x slower. I guess they have some clever tricks for transpiling np.eye. JAX is pretty useless except for fixed inputs due to their whole constant shape constraints. But in that case Aesara will just constant fold the output. So no need to worry about it.

I noticed that aesara.tensor.subtensor.set_subtensor is not among the functions implemented in aesara.link.jax.dispatch.subtensor'. Since this function uses set_subtensor, will this prevent transpilation?

@ricardoV94
Copy link
Contributor

ricardoV94 commented Sep 27, 2022

I think it is

def jax_funcify_IncSubtensor(op, **kwargs):

It's just contained in the same Op as IncSubtensor which is a bit confusing (they share a lot of logic)

Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like these commits need to be squashed.

@jessegrabowski
Copy link
Contributor Author

Yes I apologize, I made a separate branch on my aesara fork then wrote this, but it seems it didn't work how I thought? How do you keep clean git branches for working on several small separate projects?

Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I especially need input on this PR because I don't know what to do with the existing Eye Op. It could be removed, but I am unclear what effects this would have on the numba/jax backends, as well as on existing user code.

If we find a suitable all-around replacement for Eye, then we can definitely remove it; otherwise, if there is a persistent performance discrepancy between backends, we can always keep Eye and add a backend-dependent rewrite that replaces it.

Also, I just remembered that we have a non-COp named AllocDiag. This work will probably make that redundant as well.

aesara/tensor/basic.py Show resolved Hide resolved
aesara/tensor/basic.py Outdated Show resolved Hide resolved
Comment on lines 1320 to 1322
eye = zeros(n * m, dtype=dtype)
eye = aesara.tensor.subtensor.set_subtensor(eye[i :: m + 1], 1).reshape((n, m))
eye = aesara.tensor.subtensor.set_subtensor(eye[m - k :, :], 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the last set_subtensor (i.e. that appears to fill with 0s if we've already used zeros? I think I might know why, but it would be nice to have only one set_subtensor.

Copy link
Contributor Author

@jessegrabowski jessegrabowski Sep 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second set_subtenor is to prevent the diagonal from "wrapping around" the matrix in cases when m > n or when k != 0. This one-liner almost works:

at.subtensor.set_subtensor(eye[i:(m - k)*m:m + 1], 1).reshape((n, m))

But it fails when k < -n or k > m. In these cases I guess an error could be raised (why is the user asking for an offset greater than the width of the matrix?), but numpy returns the zero matrix for this case and I wanted to be consistent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One way to match numpy and avoid the second set_subtensor would be to check the conditions where the zero matrix should be returned and use multiplication:

    eye = at.subtensor.set_subtensor(eye[i:(m - k)*m:m + 1], 1).reshape((n, m))
    eye *= (1 - at.gt(k, m)) * (1 - at.lt(k, -n))

But this was about 5µs slower than the two set_subtensor method when I ran %timeit. Perhaps it's too cute?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's an implementation that uses scalars and Composite:

import numpy as np
import aesara
import aesara.tensor as at
from aesara.scalar import int64, switch, Composite


n = int64("n")
m = int64("m")
k = int64("k")

i = switch(k >= 0, k, -k * m)

eye_sc = at.zeros(n * m, dtype="float64")
mkm = (m - k) * m
eye_sc = at.subtensor.set_subtensor(eye_sc[i : mkm : m + 1], 1).reshape((n, m))

eye_sc_fn = aesara.function([n, m, k], [eye_sc])

aesara.dprint(eye_sc_fn)
# Reshape{2} [id A] 16
#  |IncSubtensor{InplaceSet;int64:int64:int64} [id B] 15
#  | |Reshape{2} [id C] 14
#  | | |IncSubtensor{InplaceSet;int64:int64:int64} [id D] 13
#  | | | |Alloc [id E] 11
#  | | | | |TensorConstant{0.0} [id F]
#  | | | | |TensorFromScalar [id G] 8
#  | | | |   |mul [id H] 2
#  | | | |     |n [id I]
#  | | | |     |m [id J]
#  | | | |TensorConstant{1} [id K]
#  | | | |Switch [id L] 12
#  | | | | |GE [id M] 6
#  | | | | | |k [id N]
#  | | | | | |ScalarConstant{0} [id O]
#  | | | | |k [id N]
#  | | | | |mul [id P] 10
#  | | | |   |neg [id Q] 5
#  | | | |   | |k [id N]
#  | | | |   |m [id J]
#  | | | |mul [id R] 9
#  | | | | |sub [id S] 4
#  | | | | | |m [id J]
#  | | | | | |k [id N]
#  | | | | |m [id J]
#  | | | |add [id T] 3
#  | | |   |m [id J]
#  | | |   |ScalarConstant{1} [id U]
#  | | |MakeVector{dtype='int64'} [id V] 7
#  | |   |TensorFromScalar [id W] 1
#  | |   | |n [id I]
#  | |   |TensorFromScalar [id X] 0
#  | |     |m [id J]
#  | |TensorConstant{1} [id K]
#  | |Switch [id L] 12
#  | |mul [id R] 9
#  | |add [id T] 3
#  |MakeVector{dtype='int64'} [id V] 7

%timeit eye_sc_fn(10, 10, 1)
# 66.5 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# Manually construct composite `Op`s for the scalar operations
i_comp_op = Composite([n, m, k], [i])
i_comp = i_comp_op(n, m, k)

mkm_comp_op = Composite([m, k], [mkm])
mkm_comp = mkm_comp_op(m, k)

eye_sc_comp = at.zeros(n * m, dtype="float64")
eye_sc_comp = at.subtensor.set_subtensor(eye_sc_comp[i_comp : mkm_comp : m + 1], 1).reshape((n, m))

eye_sc_comp_fn = aesara.function([n, m, k], [eye_sc_comp])

aesara.dprint(eye_sc_comp_fn)
# Reshape{2} [id A] 10
#  |IncSubtensor{InplaceSet;int64:int64:int64} [id B] 9
#  | |Alloc [id C] 8
#  | | |TensorConstant{0.0} [id D]
#  | | |TensorFromScalar [id E] 7
#  | |   |mul [id F] 6
#  | |     |n [id G]
#  | |     |m [id H]
#  | |TensorConstant{1} [id I]
#  | |Composite{Switch(GE(i2, 0), i2, ((-i2) * i1))} [id J] 5
#  | | |n [id G]
#  | | |m [id H]
#  | | |k [id K]
#  | |Composite{((i0 - i1) * i0)} [id L] 4
#  | | |m [id H]
#  | | |k [id K]
#  | |add [id M] 3
#  |   |m [id H]
#  |   |ScalarConstant{1} [id N]
#  |MakeVector{dtype='int64'} [id O] 2
#    |TensorFromScalar [id P] 1
#    | |n [id G]
#    |TensorFromScalar [id Q] 0
#      |m [id H]

%timeit eye_sc_comp_fn(10, 10, 1)
# 60.9 µs ± 2.53 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Copy link
Member

@brandonwillard brandonwillard Sep 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before I forget, using inc_subtensor with set_instead_of_inc=True and ignore_duplicates=True might be faster. Only applies when AdvancedIncSubtensor is used, and it isn't in these cases.

Copy link
Member

@brandonwillard brandonwillard Sep 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need an issue for a rewrite that replaces scalar-only sub-graphs with Composites.

Done: #1225

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, Switch is quite literally the scalar version of IfElse—at least in terms of its Op.c_code implementation. If anything, we should make invocations of ifelse return Switches when the arguments are ScalarTypeed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the issue for ifelse: #1226.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, there's no scalar Op for max and min?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disregard, obviously this can be done using scalar switch.

@brandonwillard
Copy link
Member

Yes I apologize, I made a separate branch on my aesara fork then wrote this, but it seems it didn't work how I thought? How do you keep clean git branches for working on several small separate projects?

I always create new branches from upstream/main (N.B. your remote for this repo might be named something other than upstream) for every PR: e.g. git checkout -b <new-branch-name> upstream/main. That should also set up remote tracking correctly (e.g. one generally wants to see differences between their branch and upstream/main, because that tells one when they need to rebase and such).

@brandonwillard
Copy link
Member

FYI: You'll need to rebase very soon in order to get the fixes in #1219 that address the current CI failures.

@brandonwillard
Copy link
Member

Also, if you're ever having blocking Git difficulties, don't hesitate to ask us. We can always perform the changes ourselves and/or walk you through them. In the former case, you'll need to be comfortable rebasing your local branch, but that's about it.

@brandonwillard
Copy link
Member

brandonwillard commented Sep 29, 2022

FYI: You'll need to rebase very soon in order to get the fixes in #1219 that address the current CI failures.

I just merged these changes, so here's a good opportunity to try a clean rebase onto <upstream>/main. If it's done correctly, there should be no changes in your commits, unless you need to resolve a conflict, in which case those changes should be included as part of an existing commit. In general, there should not be any new commits added to your branch by such a rebase.

jessegrabowski and others added 3 commits October 1, 2022 15:46
Add typing information to return value
Re-write function using scalar functions for optimization
Fix bug where identity matrix incorrectly included 1's in the lower triangle when k > m
@jessegrabowski
Copy link
Contributor Author

jessegrabowski commented Oct 1, 2022

Ok I hope I got my git situation squared away, although I initally messed up so a bunch of commits from my solve_lyapunov branch are polluting the change history for this branch. Bear with me.

I re-wrote the function using scalars and Composite Ops for the small speed boost it offers. Even with the double set_subtensor, my speed tests are with 1µs of the current implementation. I can confirm @ricardoV94 's tests that the numba backend is somewhat slower in this version than the current implementation, but both are orders of magnitude slower than:

from numba import njit

@njit
def numba_eye(n, m=None, k=0):
    return np.eye(n, m, k)

So there's something more fundamental going on there. The numba eye overload uses looping, and I imagine they lean on their optimized loop lifting to get the huge speedup vs the slice-based methods used in base numpy and in this PR.

If I had to critique my own code I'd say it's repetitive and quite ugly, it might be better to pause this PR, work on #1225, then come back and re-write this function to look more natural.

@ricardoV94
Copy link
Contributor

There's a lot of overhead to a Aesara function so it's not easy to compare directly to a numba-only function.

One thing you can try is to set function.trust_input = True after compiling it. Then make sure you pass correctly typed arrays as inputs. That usually cuts the difference a bit.

@jessegrabowski
Copy link
Contributor Author

There's a lot of overhead to a Aesara function so it's not easy to compare directly to a numba-only function.

One thing you can try is to set function.trust_input = True after compiling it. Then make sure you pass correctly typed arrays as inputs. That usually cuts the difference a bit.

Ah understood. This speeds up the aesara eye function by about 7x, so I see what you mean by overhead. Pure numba is still about 20x faster, though.

Who knew putting a bunch of 1's into a matrix was so complex.

@brandonwillard
Copy link
Member

Who knew putting a bunch of 1's into a matrix was so complex.

Oh, things like that can absolutely end up being complex, especially under the light of performance considerations.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Convert Eye to a COp or implement it in terms of existing COps
3 participants