Skip to content

Commit

Permalink
Merge pull request #52 from rochisha0/enable-jit-eig
Browse files Browse the repository at this point in the history
fix _eigs_jaxarray to be compatible with jit
  • Loading branch information
Ericgig authored Jun 17, 2024
2 parents db0fb39 + 9485b94 commit bd4f704
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 44 deletions.
76 changes: 42 additions & 34 deletions src/qutip_jax/linalg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jax import jit
from jax import jit, lax
import jax.numpy as jnp
import numpy as np
from functools import partial
Expand All @@ -12,43 +12,31 @@
"eigs_jaxarray", "svd_jaxarray", "solve_jaxarray",
]

def herm_with_vecs(data):
evals, evecs = jnp.linalg.eigh(data)
evals, evecs = evals.astype(data.dtype), evecs.astype(data.dtype)
return evals, evecs

@partial(jit, static_argnums=[1, 2, 3, 4])
def _eigs_jaxarray(data, isherm, vecs, eigvals, low_first):
"""
Internal function to dispatch the eigenvalue solver to `eigh`, `eig`,
`eigvalsh` or `eigvals` based on the parameters.
"""
if isherm and vecs:
evals, evecs = jnp.linalg.eigh(data)
elif vecs:
evals, evecs = jnp.linalg.eig(data)
elif isherm:
evals = jnp.linalg.eigvalsh(data)
evecs = None
else:
evals = jnp.linalg.eigvals(data)
evecs = None

perm = jnp.argsort(evals.real)
evals = evals[perm]
if not low_first:
evals = evals[::-1]
evals = evals[:eigvals]
def nonherm_with_vecs(data):
evals, evecs = jnp.linalg.eig(data)
evals, evecs = evals.astype(data.dtype), evecs.astype(data.dtype)
return evals, evecs

if vecs:
evecs = evecs[:, perm]
if not low_first:
evecs = evecs[:, ::-1]
evecs = evecs[:, :eigvals]
def herm_no_vecs(data):
evals = jnp.linalg.eigvalsh(data)
evals = evals.astype(data.dtype)
return evals, None

return evals, evecs
def nonherm_no_vecs(data):
evals = jnp.linalg.eigvals(data)
evals = evals.astype(data.dtype)
return evals, None


# Can't jit it if we accept isherm=None
@partial(jit, static_argnums=[1, 2, 3, 4])
def eigs_jaxarray(data, isherm=None, vecs=True, sort='low', eigvals=0):
"""
Return eigenvalues and eigenvectors for a `Data` of type `"jax"`. Takes no
Return eigenvalues and eigenvectors for a `Data` of type `"jax"`. Takes no
special keyword arguments; see the primary documentation in :func:`.eigs`.
"""
N = data.shape[0]
Expand All @@ -59,11 +47,31 @@ def eigs_jaxarray(data, isherm=None, vecs=True, sort='low', eigvals=0):
if eigvals > N:
raise ValueError("Number of requested eigen vals/vecs must be <= N.")
eigvals = eigvals or N
# Let dict raise keyerror of
low_first = {"low": True, "high": False}[sort]
isherm = isherm if isherm is not None else bool(isherm_jaxarray(data))
isherm = isherm if isherm is not None else jnp.bool_(isherm_jaxarray(data))

evals, evecs = _eigs_jaxarray(data._jxa, isherm, vecs, eigvals, low_first)
if vecs:
evals, evecs = lax.cond(
isherm, herm_with_vecs,
nonherm_with_vecs, data._jxa
)
else:
evals, evecs = lax.cond(
isherm, herm_no_vecs,
nonherm_no_vecs, data._jxa
)

perm = jnp.argsort(evals.real)
evals = evals[perm]
if not low_first:
evals = evals[::-1]
evals = evals[:eigvals]

if vecs:
evecs = evecs[:, perm]
if not low_first:
evecs = evecs[:, ::-1]
evecs = evecs[:, :eigvals]

return (evals, JaxArray(evecs, copy=False)) if vecs else evals

Expand Down
18 changes: 8 additions & 10 deletions tests/test_eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@ def test_eigen_known_oper():


@pytest.mark.parametrize(
["rand"],
["rand", "isherm"],
[
pytest.param(qutip.rand_herm, id="hermitian"),
pytest.param(qutip.rand_unitary, id="non-hermitian"),
pytest.param(qutip.rand_herm, True, id="hermitian"),
pytest.param(qutip.rand_unitary, None, id="non-hermitian"),
],
)
@pytest.mark.parametrize("order", ["low", "high"])
def test_eigen_rand_oper(rand, order):
def test_eigen_rand_oper(rand, isherm, order):
mat = rand(10, dtype="jax").data
isherm = rand is qutip.rand_herm
kw = {"isherm": isherm, "sort": order}
spvals, spvecs = qutip_jax.eigs_jaxarray(mat, vecs=True, **kw)
sp_energies = qutip_jax.eigs_jaxarray(mat, vecs=False, **kw)
Expand All @@ -42,17 +41,16 @@ def test_eigen_rand_oper(rand, order):


@pytest.mark.parametrize(
"rand",
["rand", "isherm"],
[
pytest.param(qutip.rand_herm, id="hermitian"),
pytest.param(qutip.rand_unitary, id="non-hermitian"),
pytest.param(qutip.rand_herm, True, id="hermitian"),
pytest.param(qutip.rand_unitary, None, id="non-hermitian"),
],
)
@pytest.mark.parametrize("order", ["low", "high"])
@pytest.mark.parametrize("N", [1, 5, 8, 9])
def test_eigvals_parameter(rand, order, N):
def test_eigvals_parameter(rand, isherm, order, N):
mat = rand(10, dtype="jax").data
isherm = rand is qutip.rand_herm
kw = {"isherm": isherm, "sort": order}
spvals, spvecs = qutip_jax.eigs_jaxarray(mat, vecs=True, eigvals=N, **kw)
sp_energies = qutip_jax.eigs_jaxarray(mat, vecs=False, eigvals=N, **kw)
Expand Down

0 comments on commit bd4f704

Please sign in to comment.