
Allow transposed argument in linalg.solve #1231


Merged
merged 2 commits into pymc-devs:main on Mar 4, 2025

Conversation

jessegrabowski
Member

@jessegrabowski jessegrabowski commented Feb 21, 2025

Description

Adds a transposed argument to pt.linalg.solve. The signature now matches scipy.linalg.solve 1-to-1. It also fixes gradients for pt.linalg.solve_triangular when trans is either "T" or "C" (though we don't really support complex inputs anyway).

It also adds a bunch of tests for the two implicated solve Ops.
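For reference, the scipy semantics being matched can be checked directly: for real inputs, `solve(A, b, transposed=True)` solves `A.T @ x = b`. A quick illustration in eager scipy (not the PyTensor code itself):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
b = rng.normal(size=(5,))

# transposed=True asks scipy to solve A.T @ x = b,
# equivalent to passing the transpose explicitly
x1 = linalg.solve(A, b, transposed=True)
x2 = linalg.solve(A.T, b)
assert np.allclose(x1, x2)
```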

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1231.org.readthedocs.build/en/1231/

@ricardoV94
Member

ricardoV94 commented Feb 22, 2025

Question: does this matter for PyTensor? What would we lose if the transpose flag were just applied symbolically instead of being part of the Op? Say, if the user calls solve(transposed=True), we transpose the input and call the old op without the transpose flag?

Transposes should be pretty cheap, and if two of them stack up in the grad they get rewritten away anyway.
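On the cost side: even before any graph rewrite, an eager transpose in numpy is just a view over the same memory, so a double transpose copies nothing (this snippet only checks the numpy side; the rewrite claim is about PyTensor's optimizer):

```python
import numpy as np

A = np.arange(6.0).reshape(2, 3)

# .T returns a view; no data is copied
assert np.shares_memory(A, A.T)

# a double transpose is again a view over the original memory
assert np.shares_memory(A, A.T.T)
assert (A.T.T == A).all()
```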

Isn't this a concern only for eager libraries like scipy? I wouldn't be surprised if JAX also handles it purely symbolically.
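The proposal above would look roughly like this (a minimal sketch, not the actual PyTensor implementation; `core_solve` is a hypothetical stand-in for the existing Op, backed here by eager scipy):

```python
import numpy as np
from scipy import linalg

def core_solve(A, b):
    # stands in for the existing solve Op, which has no transpose flag
    return linalg.solve(A, b)

def solve(A, b, transposed=False):
    # handle the flag "symbolically": transpose the input and
    # dispatch to the unchanged core op
    if transposed:
        A = A.T
    return core_solve(A, b)

# sanity check against passing the transpose explicitly
rng = np.random.default_rng(1)
A = rng.normal(size=(4, 4))
b = rng.normal(size=4)
x = solve(A, b, transposed=True)
assert np.allclose(x, linalg.solve(A.T, b))
```

Keeping the flag out of the Op would mean the backends never see it, and stacked transposes could be removed by the usual rewrites.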

@jessegrabowski
Member Author

jessegrabowski commented Feb 22, 2025

I have no objection to going that way instead. It might be smarter -- it seems like the transposed flag is actually a bit slower in eager scipy:

from scipy import linalg
import numpy as np
rng = np.random.default_rng()

A = rng.normal(size=(10_000, 10_000))
b = rng.normal(size=(10_000,))

def f1(A, b):
    return linalg.solve(A.T, b)

def f2(A, b):
    return linalg.solve(A, b, transposed=True)

%timeit f1(A, b)
# 1.84 s ± 107 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit f2(A, b)
# 2.37 s ± 82.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Plus we have to transpose the output of the gradients when transposed=True anyway, so there's still a transpose popping up somewhere. I'll make the change so it just quietly transposes the inputs for you.

The bottom line is that I'd like for the argument to be there so we match the scipy API.

@ricardoV94
Member

The bottom line is that I'd like for the argument to be there so we match the scipy API.

Yeah we should always try to match the external API

@jessegrabowski jessegrabowski force-pushed the solve-transpose branch 2 times, most recently from fde3e37 to 1d5c020 Compare February 28, 2025 11:53
@ricardoV94
Member

Are all the test changes needed? At least the transpose in JAX seems unnecessary, since it's not part of the op.

@ricardoV94
Member

ricardoV94 commented Mar 1, 2025

You're testing 36 combinations with parametrize just for solve in JAX, transpiling to both Python and JAX and jitting/evaluating the latter. That doesn't seem reasonable.

@jessegrabowski
Member Author

But it's fast

@ricardoV94
Member

But it's fast

That's true. Maybe you should add dtype parametrization to validate my concerns :)
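A dtype parametrization along these lines could be added (a hypothetical sketch in pytest style, exercising eager scipy rather than the compiled PyTensor graph; names and tolerances are illustrative):

```python
import numpy as np
import pytest
from scipy import linalg

@pytest.mark.parametrize("dtype", ["float32", "float64"])
@pytest.mark.parametrize("transposed", [False, True])
def test_solve_dtype(dtype, transposed):
    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 4)).astype(dtype)
    b = rng.normal(size=(4,)).astype(dtype)

    x = linalg.solve(A, b, transposed=transposed)

    # the residual should be small against the effective system matrix
    A_eff = A.T if transposed else A
    tol = 1e-3 if dtype == "float32" else 1e-8
    np.testing.assert_allclose(A_eff @ x, b, rtol=tol, atol=tol)
```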

@jessegrabowski
Member Author

brb adding those tests

In other news, why is the JAX pad test failing? Is it flaky?

@ricardoV94
Member

ricardoV94 commented Mar 1, 2025

Hmm, I saw it failing in my other PR and thought it was my changes. In that case, either something broke in a recent JAX release or a commit we merged into main is causing it to fail.

@maresb
Contributor

maresb commented Mar 1, 2025

JAX 0.4.36 was just released on conda-forge about 9 hours ago, so maybe something changed there. You could test whether pinning jax <0.4.36 works around the issue. conda-forge/jax-feedstock#169

@ricardoV94
Member

I tested locally: JAX had a bug in 0.4.36 that seems to have disappeared by 0.5.1. I didn't bother checking exactly when it stopped failing. Added a commit in another PR that will skip it: a0d9ecf

@ricardoV94
Member

Nope, it's still failing: jax-ml/jax#26888

@ricardoV94
Member

You can rebase, @jessegrabowski

@ricardoV94 ricardoV94 changed the title Allow transposed argument in pt.linalg.solve Allow transposed argument in linalg.solve Mar 3, 2025

codecov bot commented Mar 4, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.99%. Comparing base (757a10c) to head (0f26954).
Report is 1 commit behind head on main.


@@           Coverage Diff           @@
##             main    #1231   +/-   ##
=======================================
  Coverage   81.99%   81.99%           
=======================================
  Files         188      188           
  Lines       48600    48608    +8     
  Branches     8685     8688    +3     
=======================================
+ Hits        39849    39857    +8     
  Misses       6586     6586           
  Partials     2165     2165           
Files with missing lines                    Coverage Δ
pytensor/link/jax/dispatch/slinalg.py       95.00% <ø> (-0.13%) ⬇️
pytensor/link/numba/dispatch/slinalg.py     44.42% <ø> (-0.11%) ⬇️
pytensor/tensor/slinalg.py                  93.67% <100.00%> (+0.13%) ⬆️

@jessegrabowski jessegrabowski merged commit bf628c9 into pymc-devs:main Mar 4, 2025
73 checks passed
@jessegrabowski jessegrabowski deleted the solve-transpose branch March 4, 2025 10:56

Successfully merging this pull request may close these issues.

Add transposed argument to pt.linalg.solve
3 participants