-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add jax.scipy.linalg.toeplitz
.
#13251
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Just a few comments below and we can get this in.
Also, please add the new function to the list here so it will show up in the HTML docs: https://github.com/google/jax/blob/main/docs/jax.scipy.rst#jaxscipylinalg
Thanks for reviewing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks!
Looks like the numerical error is severe when TPUs/ GPUs are used because |
@yotarok You can get higher precision on TPU/GPU with |
I suspect |
The profiling results with TPU were:
Confirmed that with |
I noticed that integers are not tested. So, I added integer tests and some skip conditions because CuDNN doesn't support integer convolution except for int8. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Last couple minor comments - thanks!
tests/linalg_test.py
Outdated
rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)], | ||
rdtype=float_types + complex_types + int_types) | ||
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): | ||
if ((rdtype in [np.float64, np.complex128] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here and throughout, please use two-space indentations.,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
What's the style recommendation for hanging indents?
For _src/scipy/linalg.py
, I followed Google-style, i.e. 2 spaces for normal indents, and 4 spaces for hanging indents.
However, tests/linalg_test.py
uses 2 spaces also for hanging indents, so I followed that file-local convention only for the test.
Thank you for the comments, I updated the PR accordingly. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Ping, since it's been while after the PR was approved, the branch is rebased on top of the current main. |
A counterpart for "scipy.linalg.toeplitz" is added in "jax.scipy.linalg".
We recently have a case that requires a Jax counterpart of
tf.linalg.LinearOperatorToeplitz
, and this functionality is in fact a part of SciPy, implemented asscipy.linalg.toeplitz
.This also relates to the request in #10144.