Skip to content

Commit

Permalink
Merge pull request #65 from rochisha0/mcd
Browse files Browse the repository at this point in the history
create autodiff doc for mcsolve
  • Loading branch information
Ericgig authored Aug 29, 2024
2 parents 0cc6256 + 1909221 commit b7ff6de
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 7 deletions.
65 changes: 64 additions & 1 deletion doc/source/autodiff.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,67 @@ should work:
result = solver.run(ket, [0, 1], e_ops=qt.num(2).to("jax"), args={"w":w})
return result.e_data[0][1].real
jax.grad(f)(0.5, solver)
jax.grad(f)(0.5, solver)
Auto differentiation in ``mcsolve``
===================================


Here is an example to use jax auto differentiation with `mcsolve`.
The automatic differentiation (`jax.grad`) in `mcsolve` does not support parallel map operations.
To ensure accurate gradient computations, please use the default serial execution instead of
parallel mapping within `mcsolve`.


.. code-block::
import qutip_jax
import qutip
import jax
import jax.numpy as jnp
from functools import partial
from qutip import mcsolve
# Use JAX backend for QuTiP
qutip_jax.set_as_default()
# Define time-dependent functions
@partial(jax.jit, static_argnames=("omega",))
def H_1_coeff(t, omega):
return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)
# Define operators and states
size = 10
a = qutip.tensor(qutip.destroy(size), qutip.qeye(2)).to('jaxdia') # Annihilation operator
sm = qutip.qeye(size).to('jaxdia') & qutip.sigmax().to('jaxdia') # Example spin operator
# Define the Hamiltonian
H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
H_1_op = sm * a.dag() + sm.dag() * a
# Initialize the Hamiltonian with time-dependent coefficients
H = [H_0, [H_1_op, qutip.coefficient(H_1_coeff, args={"omega": 1.0})]]
# Define initial states, mixed states are not supported
state = qutip.basis(size, size - 1).to('jax') & qutip.basis(2, 1).to('jax')
# Define collapse operators and observables
c_ops = [jnp.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]
# Time list
tlist = jnp.linspace(0.0, 10.0, 101)
# Define the function for which we want to compute the gradient
def f(omega):
result = mcsolve(
H, state, tlist, c_ops, e_ops, ntraj=10,
args={"omega": omega},
options={"method": "diffrax"}
)
# Return the expectation value of the number operator at the final time
return result.expect[0][-1].real
# Compute the gradient
gradient = jax.grad(f)(1.0)
12 changes: 6 additions & 6 deletions doc/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ To enable JAX as the backend for QuTiP, you need to set the backend to `jax` usi
import qutip_jax
# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()
Using `jax.jit` with QuTiP
--------------------------
Expand All @@ -35,7 +35,7 @@ Using `jax.jit` with QuTiP
import qutip_jax
# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()
# Define states
psi = basis(2, 0).to("jax")
Expand All @@ -57,7 +57,7 @@ Using `jax.jit` with QuTiP
import qutip_jax
# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()
# Define a density matrix
rho = ket2dm(psi).to("jax")
Expand Down Expand Up @@ -87,7 +87,7 @@ To compute the gradient, you need a function that returns a scalar. Note that `j
import qutip_jax
# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()
# Define bra and ket states
bra_state = basis(2, 0).dag()
Expand Down Expand Up @@ -119,7 +119,7 @@ The `trace_dist` function supports `oper` states for gradient computation.
import qutip_jax
# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()
# Define an operator state
oper_state = rand_dm(2)
Expand Down Expand Up @@ -147,5 +147,5 @@ If you want to switch back to the default backend (NumPy), use the following:

.. code-block:: python
qutip.settings.core["numpy_backend"] = np
qutip_jax.set_as_default(revert = True)

0 comments on commit b7ff6de

Please sign in to comment.