Convert `Eye` to a `COp` or implement it in terms of existing `COp`s #1177
We could also try using an `OpFromGraph` instead; it's just `zeros` + `set_subtensor`, no?
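The "zeros + set_subtensor" idea is easy to sanity-check in NumPy before building an Aesara graph. The sketch below is a NumPy analogue, not Aesara code: `eye_by_indexing` is a hypothetical helper that allocates a zero matrix and writes ones at the valid `(r, r + k)` positions with advanced indexing. In Aesara this would roughly correspond to `set_subtensor` on an advanced-indexed view of `zeros`.

```python
import numpy as np


def eye_by_indexing(n, m=None, k=0, dtype=float):
    """NumPy analogue of "zeros + set_subtensor" (hypothetical helper)."""
    if m is None:
        m = n
    out = np.zeros((n, m), dtype=dtype)
    # Row r holds a diagonal element only if its column r + k lies in [0, m)
    rows = np.arange(max(0, -k), min(n, m - k))
    out[rows, rows + k] = 1
    return out


# Compare against NumPy's own eye over a range of shapes and offsets
for n in range(1, 6):
    for m in range(1, 6):
        for k in range(-6, 7):
            assert np.array_equal(eye_by_indexing(n, m, k), np.eye(n, m, k))
```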
That might not be as performant as a …
I believe the following is an implementation using …:

```python
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

n = at.iscalar("n")
m = at.iscalar("m")
k = at.iscalar("k")

# Flat index of the first diagonal element: (0, k) for k >= 0, (-k, 0) for k < 0
i = at.switch(k >= 0, k, -k * m)
eye = at.zeros(n * m)
# Consecutive diagonal elements are m + 1 apart in the flattened array
eye = at.set_subtensor(eye[i :: m + 1], 1).reshape((n, m))
# The strided write wraps around once the diagonal leaves the last column;
# clear those rows (the maximum() guard handles k > m)
eye = at.set_subtensor(eye[at.maximum(m - k, 0) :, :], 0)

Eye = OpFromGraph([n, m, k], [eye])
```

The existing …:

```python
import aesara


def eye(n, m=None, k=0, dtype=None):
    """Return a 2-D array with ones on the diagonal and zeros elsewhere.

    Parameters
    ----------
    n : int
        Number of rows in the output.
    m : int, optional
        Number of columns in the output. If None, defaults to `n`.
    k : int, optional
        Index of the diagonal: 0 (the default) refers to the main diagonal,
        a positive value refers to an upper diagonal, and a negative value
        to a lower diagonal.
    dtype : data-type, optional
        Data-type of the returned array.

    Returns
    -------
    TensorVariable of shape (n, m)
        An array where all elements are equal to zero, except for the `k`-th
        diagonal, whose values are equal to one.
    """
    if dtype is None:
        dtype = aesara.config.floatX
    if m is None:
        m = n
    return Eye(n, m, k).astype(dtype)
```

I can add some tests to check corner cases and submit this as a pull request if it looks like I'm barking up the right tree. Where would I add the code to create the …
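The corner-case tests mentioned above can be prototyped by mirroring the flat-stride graph in NumPy and comparing it with `np.eye`. This is a sketch under that assumption: `eye_flat` is a hypothetical helper, not part of Aesara. The interesting case is `k > m`: there `m - k` is negative, and a negative slice start counts from the end, so an unguarded `eye[m - k:, :]` clears only the last few rows and can leave a stray one behind (for example `n = m = 3, k = 4`). Clamping the start with `max(m - k, 0)` clears every row, which is correct because no diagonal element exists in that case.

```python
import numpy as np


def eye_flat(n, m=None, k=0, dtype=float):
    """NumPy mirror of the flat-stride graph (hypothetical helper)."""
    if m is None:
        m = n
    # Flat index of the first diagonal element: (0, k) if k >= 0, else (-k, 0)
    i = k if k >= 0 else -k * m
    flat = np.zeros(n * m, dtype=dtype)
    # Element (r, c) sits at r * m + c, so (r, r + k) entries are m + 1 apart
    flat[i :: m + 1] = 1
    out = flat.reshape(n, m)
    # Rows at or past m - k have no diagonal element; the stride wrapped into
    # them, so clear them. max(..., 0) covers k > m, where m - k is negative.
    out[max(m - k, 0) :, :] = 0
    return out


for n in range(1, 6):
    for m in range(1, 6):
        for k in range(-6, 7):
            assert np.array_equal(eye_flat(n, m, k), np.eye(n, m, k)), (n, m, k)
```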
Looks about right. As for where to put it... good question. Floating sounds about right; maybe add an underscore prefix to those intermediate variables? Another thing worth checking is whether any rewrites currently target the …
How would I check for rewrites that target …
Sounds about right. There might not be any, in which case you're in luck ;) Edit: I didn't find anything either.
You can also get rid of the Numba and JAX dispatches (I assume a dispatch for `OpFromGraph` has already been implemented). Edit: It seems to exist only for Numba...
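Actually, we might not even need an …:

```python
import aesara
import aesara.tensor as at


def eye_new(n, m=None, k=0, dtype=None):
    if m is None:
        m = n
    if dtype is None:
        dtype = aesara.config.floatX
    # Python ints can be autocast to a narrow integer dtype (e.g. int8);
    # cast to int64 so that n * m cannot overflow
    n = at.as_tensor_variable(n).astype("int64")
    m = at.as_tensor_variable(m).astype("int64")
    k = at.as_tensor_variable(k).astype("int64")
    # Flat index of the first diagonal element
    i = at.switch(k >= 0, k, -k * m)
    eye = at.zeros(n * m, dtype=dtype)
    # Diagonal elements are m + 1 apart in the flattened array
    eye = at.set_subtensor(eye[i :: m + 1], 1).reshape((n, m))
    # Clear the rows the strided write wrapped into (guarded for k > m)
    eye = at.set_subtensor(eye[at.maximum(m - k, 0) :, :], 0)
    return eye
```

Seems to do alright.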
I'll make a pull request for this in a minute; I'm just fumbling around with git at the moment.
`Eye` doesn't have a C implementation, but adding one should be straightforward.
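For a sense of how small a dedicated kernel would be: it only needs to zero-fill the output buffer and then loop over at most `min(n, m)` diagonal positions. The sketch below is plain Python standing in for that loop. It does not use Aesara's `COp` or C-code API, and `eye_kernel` is a hypothetical name.

```python
def eye_kernel(n, m, k):
    """Plain-Python sketch of the loop a C kernel for Eye might run (hypothetical)."""
    # Zero-fill the n x m output in row-major order, as a C buffer would be
    out = [0.0] * (n * m)
    # Only rows r with 0 <= r < n and 0 <= r + k < m hold a diagonal element
    for r in range(max(0, -k), min(n, m - k)):
        out[r * m + r + k] = 1.0
    # Split the flat buffer into rows for readability
    return [out[row * m:(row + 1) * m] for row in range(n)]
```

Unlike the strided `set_subtensor` approach, the loop bounds exclude out-of-range diagonal positions up front, so no wrapped writes ever need to be cleared.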