-
Notifications
You must be signed in to change notification settings - Fork 132
Allow transposed
argument in linalg.solve
#1231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
question does this matter for PyTensor? what would we lose if the transpose flag was just used symbolically but not part of the Op? Say if the user calls solve(transposed=True) we transpose the input and call the old op without the transpose flag? transpose should be pretty cheap and if you have 2 in the grad they get rewritten away anyway. Is this not a concern only for eager libraries like scipy? I wouldn't be surprised if Jax also does it just symbolically |
I have no objection to going that way instead. It might be smarter -- It seems like the from scipy import linalg
import numpy as np
rng = np.random.default_rng()
A = rng.normal(size=(10_000, 10_000))
b = rng.normal(size=(10_000,))
def f1(A, b):
return linalg.solve(A.T, b)
def f2(A, b):
return linalg.solve(A, b, transposed=True)
%timeit f1(A, b)
# 1.84 s ± 107 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f2(A, b)
# 2.37 s ± 82.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) Plus we have to transpose the output of the gradients when The bottom line is that I'd like for the argument to be there so we match the scipy API. |
Yeah we should always try to match the external API |
fde3e37
to
1d5c020
Compare
Are all the test changes needed? At least the transpose in JAX seems unnecessary since it's not part of the op |
You're testing 36 combinations with parametrize just for solve in jax, transpiling both python and JAX and jitting/evaluating the latter. Doesn't seem reasonable |
But it's fast |
That's true. Maybe you should add dtype parametrization to validate my concerns :) |
brb adding those tests In other news, why is jax pad test failing? Is it flakey? |
Hmm I saw it failing in my other PR and I thought it was my changes. In that case either something broke in a recent jax release or a commit we merged into main is causing it to fail |
Jax 0.4.36 just got released on conda-forge about 9 hours ago, so maybe something changed with that. You could test if pinning |
I tested locally jax had a bug in 0.4.36 that seems to have disappeared by 0.5.1. I didn't bother checking exactly when did it stop failing. Added a commit in another PR that will skip it: a0d9ecf |
Nope it's still failing: jax-ml/jax#26888 |
you can rebase @jessegrabowski |
c529087
to
eeaf7dd
Compare
transposed
argument in pt.linalg.solve
transposed
argument in linalg.solve
4b6e2ad
to
0f26954
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1231 +/- ##
=======================================
Coverage 81.99% 81.99%
=======================================
Files 188 188
Lines 48600 48608 +8
Branches 8685 8688 +3
=======================================
+ Hits 39849 39857 +8
Misses 6586 6586
Partials 2165 2165
|
Description
Adds a
transposed
argument topt.linalg.solve
. The signature now matchesscipy.linalg.solve
1-to-1. It also fixes gradients forpt.linalg.solve_triangular
whentrans
is either"T"
or"C"
(though we don't really support complex inputs anyway)It also adds a bunch of tests for the two implicated solve Ops.
Related Issue
transposed
argument topt.linalg.solve
#1229Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1231.org.readthedocs.build/en/1231/