Skip to content

Commit

Permalink
Merge pull request #53 from rochisha0/add-sqrtm
Browse files Browse the repository at this point in the history
Add dispatcher for sqrtm
  • Loading branch information
Ericgig authored Jun 12, 2024
2 parents fe51b69 + 6e20aeb commit db0fb39
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/qutip_jax/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"conj_jaxdia",
"inv_jaxarray",
"expm_jaxarray",
"sqrtm_jaxarray",
"project_jaxarray",
]

Expand Down Expand Up @@ -113,6 +114,12 @@ def inv_jaxarray(matrix):
return JaxArray._fast_constructor(linalg.inv(matrix._jxa), matrix.shape)


def sqrtm_jaxarray(matrix):
"""Matrix square root `sqrt(A)` for a matrix `A`."""
_check_square_shape(matrix)
return JaxArray._fast_constructor(linalg.sqrtm(matrix._jxa), matrix.shape)


@jit
def _project_ket(array):
return array @ array.T.conj()
Expand Down Expand Up @@ -185,6 +192,11 @@ def project_jaxarray(state):
)


qutip.data.sqrtm.add_specialisations(
[(JaxArray, JaxArray, sqrtm_jaxarray),]
)


qutip.data.project.add_specialisations(
[
(JaxArray, JaxArray, project_jaxarray),
Expand Down
6 changes: 6 additions & 0 deletions tests/test_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ class TestInv(testing.TestInv):
specialisations = [pytest.param(_inv_jax, JaxArray, JaxArray)]


class TestSqrtm(testing.TestSqrtm):
specialisations = [
pytest.param(qutip_jax.sqrtm_jaxarray, JaxArray, JaxArray)
]


class TestProject(testing.TestProject):
specialisations = [
pytest.param(qutip_jax.project_jaxarray, JaxArray, JaxArray)
Expand Down

0 comments on commit db0fb39

Please sign in to comment.