diff --git a/src/qutip_jax/reshape.py b/src/qutip_jax/reshape.py index d723f31..3f4a50e 100644 --- a/src/qutip_jax/reshape.py +++ b/src/qutip_jax/reshape.py @@ -4,7 +4,6 @@ from . import JaxArray import qutip from functools import partial -import numpy as np __all__ = [ "reshape_jaxarray", @@ -64,14 +63,14 @@ def split_columns_jaxarray(matrix, copy=None): def _parse_ptrace_inputs(dims, sel, shape): - dims = np.atleast_1d(dims).ravel() - sel = np.atleast_1d(sel) + dims = jnp.atleast_1d(dims).ravel() + sel = jnp.atleast_1d(sel) sel.sort() if shape[0] != shape[1]: raise ValueError("ptrace is only defined for square density matrices") - if shape[0] != np.prod(dims, dtype=int): + if shape[0] != jnp.prod(dims, dtype=int): raise ValueError( f"the input matrix shape, {shape} and the" f" dimension argument, {dims}, are not compatible." @@ -110,11 +109,11 @@ def ptrace_jaxarray(matrix, dims, sel): nd = dims.shape[0] dims2 = tuple(list(dims) * 2) sel = list(sel) - qtrace = list(set(np.arange(nd)) - set(sel)) + qtrace = list(set(jnp.arange(nd)) - set(sel)) - dkeep = np.prod([dims[x] for x in sel], dtype=int) - dtrace = np.prod([dims[x] for x in qtrace], dtype=int) + dkeep = jnp.prod([dims[x] for x in sel], dtype=int) + dtrace = jnp.prod([dims[x] for x in qtrace], dtype=int) transpose_idx = tuple( qtrace + [nd + q for q in qtrace]