Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jax.scipy.linalg.toeplitz. #13251

Merged
merged 1 commit into from
Dec 8, 2022
Merged

Add jax.scipy.linalg.toeplitz. #13251

merged 1 commit into from
Dec 8, 2022

Conversation

yotarok
Copy link
Contributor

@yotarok yotarok commented Nov 15, 2022

A counterpart for "scipy.linalg.toeplitz" is added in "jax.scipy.linalg".

We recently have a case that requires a Jax counterpart of tf.linalg.LinearOperatorToeplitz, and this functionality is in fact a part of SciPy, implemented as scipy.linalg.toeplitz.
This also relates to the request in #10144.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Just a few comments below and we can get this in.

Also, please add the new function to the list here so it will show up in the HTML docs: https://github.com/google/jax/blob/main/docs/jax.scipy.rst#jaxscipylinalg

tests/linalg_test.py Outdated Show resolved Hide resolved
jax/_src/scipy/linalg.py Show resolved Hide resolved
jax/_src/scipy/linalg.py Show resolved Hide resolved
jax/_src/scipy/linalg.py Outdated Show resolved Hide resolved
tests/linalg_test.py Outdated Show resolved Hide resolved
@yotarok
Copy link
Contributor Author

yotarok commented Nov 17, 2022

Thanks for reviewing.
All the issues are resolved and the new version is force-pushed to the branch.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Nov 17, 2022
@yotarok
Copy link
Contributor Author

yotarok commented Nov 17, 2022

Looks like the numerical error is severe when TPUs/ GPUs are used because lax.conv_general_dilated_patches is used.
So, I switched to use vmap and np.roll that are slightly slower but should be accurate.

@hawkinsp
Copy link
Collaborator

@yotarok You can get higher precision on TPU/GPU with jax.default_matmul_precision("float32") which can be used as, e.g., a decorator, or the precision option to the convolution.

@jakevdp
Copy link
Collaborator

jakevdp commented Nov 17, 2022

I suspect vmap of roll is going to be much less efficient than the original approach, due to how it's implemented. Would you like to change back to the convolution-based approach along with adjusting the matmul precision?

@yotarok
Copy link
Contributor Author

yotarok commented Nov 18, 2022

The profiling results with TPU were:

  • lax.conv_general_dilated_patches: ~9ms
  • roll + vmap: ~15ms
  • lax.conv_general_dilated_patches + Precision.HIGHEST: ~12ms
    (with configuaration of r == c, r.shape = (1024, 256) and toeplitz is called with vmap).

Confirmed that with Precision.HIGHEST, assertAllClose in the test passes with the default tolerance parameters with TPU/ GPU.

@yotarok
Copy link
Contributor Author

yotarok commented Nov 18, 2022

I noticed that integers are not tested. So, I added integer tests and some skip conditions because CuDNN doesn't support integer convolution except for int8.
PTAL again.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last couple minor comments - thanks!

tests/linalg_test.py Outdated Show resolved Hide resolved
rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)],
rdtype=float_types + complex_types + int_types)
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype):
if ((rdtype in [np.float64, np.complex128]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and throughout, please use two-space indentations.,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.
What's the style recommendation for hanging indents?
For _src/scipy/linalg.py, I followed Google-style, i.e. 2 spaces for normal indents, and 4 spaces for hanging indents.
However, tests/linalg_test.py uses 2 spaces also for hanging indents, so I followed that file-local convention only for the test.

@yotarok
Copy link
Contributor Author

yotarok commented Nov 30, 2022

Thank you for the comments, I updated the PR accordingly.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@yotarok
Copy link
Contributor Author

yotarok commented Dec 8, 2022

Ping, since it's been while after the PR was approved, the branch is rebased on top of the current main.

@copybara-service copybara-service bot merged commit 02ba16e into jax-ml:main Dec 8, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants