Skip to content

Commit

Permalink
Merge pull request #50 from Sampreet/feature-mcsolve-measure
Browse files Browse the repository at this point in the history
Add mcsolve support
  • Loading branch information
Ericgig authored Jun 6, 2024
2 parents 336f62f + e537e28 commit fe51b69
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 4 deletions.
2 changes: 1 addition & 1 deletion doc/source/basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Jax in QuTiP
Basic usage
===========

In orther to enable qutip-jax, it is just necessary to import the module. Once imported, ``qutip.Qobj``'s data can be represented as a JAX array. Furthermore, diffrax ODE will be available as an option for qutip's solvers (``sesolve``, ``mcsolve``, etc.).
In order to enable qutip-jax, it is just necessary to import the module. Once imported, ``qutip.Qobj``'s data can be represented as a JAX array. Furthermore, diffrax ODE will be available as an option for qutip's solvers (``sesolve``, ``mcsolve``, etc.).
None of the functions in the module are expected to be used directly. Instead, they will be used by qutip, allowing the user to interact only with the already familiar QuTiP interface.

There are many ways to create a QuTiP ``Qobj`` backed by JAX's array class.
Expand Down
77 changes: 77 additions & 0 deletions doc/source/solver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,80 @@ The following code shows an example of how to use JAX:
result = qutip.mesolve(H, rho0, [0, 1], c_ops=c_ops, options={"method": "diffrax"})
Note that while running the above code on the GPU, the ``"normalize_output"`` option should be set to ``False``, as Schur decomposition is only supported in the CPU currently.


.. _mcsolve:

Using Jax in ``mcsolve``
========================

Similar to ``mesolve``, the JAX backend can be used with ``mcsolve`` to simulate Monte Carlo quantum trajectories, by defining the operators and states as ``jax`` or ``jaxdia`` dtypes and using a JAX-based ODE integrator (currently, ``qutip-jax`` supports a ``diffrax``-based integrator, ``DiffraxIntegrator``).

The following code demonstrates the evolution of :math:`10` trajectories with ``mcsolve`` for the two-level system described in `QuTiP's Monte Carlo Solver tutorial <https://qutip.readthedocs.io/en/latest/guide/dynamics/dynamics-monte.html>`_ with a Hilbert space dimension of :math:`N = 10^4` for the cavity mode:

.. code-block::
import jax.numpy as jnp
import qutip
import qutip_jax
N = 10000
tlist = jnp.linspace(0.0, 10.0, 200)
# ``jaxdia`` operators support higher dimensional Hilbert spaces in the GPU
with qutip.CoreOptions(default_dtype="jaxdia"):
a = qutip.tensor(qutip.qeye(2), qutip.destroy(N))
sm = qutip.tensor(qutip.destroy(2), qutip.qeye(N))
H = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm + 2.0 * jnp.pi * 0.25 * (sm * a.dag() + sm.dag() * a)
# using ``jax`` dtype since ``DiffraxIntegrator`` anyway converts the final state to ``jax``
state = qutip.tensor(qutip.fock(2, 0, dtype="jax"), qutip.fock(N, 8, dtype="jax"))
c_ops = [jnp.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]
result = qutip.mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={
"method": "diffrax"
})
The default solver for ``DiffraxIntegrator`` is ``diffrax.Tsit5()`` with an adaptive step-size controller (``diffrax.PIDController()``) using QuTiP's default tolerances (``atol = 1e-8``, ``rtol = 1e-6``).
To use a different solver or step-size controller (see `Diffrax ODE Solvers <https://docs.kidger.site/diffrax/api/solvers/ode_solvers/>`_ and `Diffrax Step Size Controllers <https://docs.kidger.site/diffrax/api/stepsize_controller/>`_ for available options), the following options can be passed alongside ``"method": "diffrax"``:

.. code-block::
from diffrax import Dopri5, ConstantStepSize
options = dict(
method = "diffrax",
solver = Dopri5(),
stepsize_controller = ConstantStepSize(),
dt0 = 0.001
)
Note that the coefficient function of a time-dependent Hamiltonian needs to be jit-wrapped before it is passed to the solver. An example snippet for a coefficient with additional arguments is given below:

.. code-block::
from functools import partial
import jax
@partial(jax.jit, static_argnames=("omega", ))
def H_1_coeff(t, omega):
return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)
H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
H_1_op = sm * a.dag() + sm.dag() * a
H = [H_0, [H_1_op, H_1_coeff]]
result = qutip.mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={
"method": "diffrax"
}, args={
"omega": 1.0 # arguments for the coefficient function are passed here
})
Alternatively, the ``JaxJitCoeff`` class can be utilized as demonstrated by the following snippet:

.. code-block::
from qutip_jax.qobjevo import JaxJitCoeff
H = [H_0, [H_1_op, JaxJitCoeff(lambda t, omega: 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t), args={
"omega": 1.0 # arguments for the coefficient function are passed here
}, static_argnames=("omega", ))]]
result = qutip.mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={
"method": "diffrax"
})
4 changes: 2 additions & 2 deletions src/qutip_jax/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from qutip.solver.integrator import Integrator
import jax
import jax.numpy as jnp
from qutip.solver.mcsolve import MCSolver
from qutip.solver.mesolve import MESolver
from qutip.solver.sesolve import SESolver
from qutip.core import data as _data
import numpy as np
from qutip_jax import JaxArray
from qutip_jax.qobjevo import JaxQobjEvo

Expand Down Expand Up @@ -135,7 +135,7 @@ def options(self):
def options(self, new_options):
Integrator.options.fset(self, new_options)


MCSolver.add_integrator(DiffraxIntegrator, "diffrax")
MESolver.add_integrator(DiffraxIntegrator, "diffrax")
SESolver.add_integrator(DiffraxIntegrator, "diffrax")
jax.tree_util.register_pytree_node(
Expand Down
4 changes: 3 additions & 1 deletion src/qutip_jax/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def column_unstack_jaxarray(matrix, rows):


@jit
def split_columns_jaxarray(matrix):
def split_columns_jaxarray(matrix, copy=None):
# `copy` is passed by some `Qobj` methods
# but JaxArray always creates a new array.
return [
JaxArray(matrix._jxa[:, k]) for k in range(matrix.shape[1])
]
Expand Down

0 comments on commit fe51b69

Please sign in to comment.