[BUG] Calculating the matrix of a fractional power during jitting #6705

Open
1 task done
albi3ro opened this issue Dec 11, 2024 · 2 comments
Labels
bug 🐛 Something isn't working

Comments

@albi3ro
Contributor

albi3ro commented Dec 11, 2024

Expected behavior

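If we are jitting and raising an operation to a fractional power (including a float exponent such as the `2.0` in the source code), `op.has_matrix` should be `False`: we should say that the operator doesn't have a matrix.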
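A minimal sketch of the expected rule, assuming tracing can be detected by checking whether a parameter is a `jax.core.Tracer` (PennyLane's own check may differ). `pow_has_matrix` is a hypothetical helper, not part of PennyLane. It treats any exponent that is not a non-negative `int` as fractional, which matches the path `2.0` takes in the traceback.

```python
import jax
import jax.numpy as jnp


def is_traced(value):
    # True when `value` is an abstract JAX tracer (e.g. inside jax.jit);
    # such values carry no concrete numbers for NumPy/SciPy to consume.
    return isinstance(value, jax.core.Tracer)


def pow_has_matrix(base_params, exponent):
    # Hypothetical rule, not PennyLane's API.
    # Non-negative int powers only need repeated matmuls, which trace fine.
    if isinstance(exponent, int) and exponent >= 0:
        return True
    # Any other exponent goes through SciPy, which needs concrete values.
    return not any(is_traced(p) for p in base_params)


seen = []


@jax.jit
def traced_fn(x):
    # The appends run once, while jax.jit traces the function.
    seen.append(pow_has_matrix([x], 2.0))
    seen.append(pow_has_matrix([x], 2))
    return x


traced_fn(0.5)
eager_has_matrix = pow_has_matrix([jnp.asarray(0.5)], 2.0)
print(seen, eager_has_matrix)  # [False, True] True
```

Outside of jit the same float exponent reports a matrix, since SciPy can consume the concrete array.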

Actual behavior

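We say that we have a matrix (`op.has_matrix` prints `True`), but `op.matrix()` then raises a `TracerArrayConversionError`: `Pow._matrix` hands the traced JAX array to the SciPy routine `scipy.linalg.fractional_matrix_power`, which calls `np.asarray` on it.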
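The failure can be reproduced without PennyLane. This sketch builds an RX matrix with `jax.numpy` (the `rx` helper is hypothetical) and passes it to SciPy: the eager call works because a concrete `jax.Array` converts to NumPy, while the jitted call hits the same `np.asarray` conversion seen in the traceback.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.linalg import fractional_matrix_power


def rx(theta):
    # RX(theta) = exp(-i * theta * X / 2), built with jax.numpy so it can be traced.
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def rx_squared_scipy(theta):
    # SciPy converts its input to a NumPy array before doing any math.
    return fractional_matrix_power(rx(theta), 2.0)


# Eager call: rx(0.5) is a concrete jax.Array, so the conversion succeeds.
eager = rx_squared_scipy(0.5)

# Jitted call: rx(theta) is a tracer, so the NumPy conversion raises.
try:
    jax.jit(rx_squared_scipy)(0.5)
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True

print(jit_failed)  # True
```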

Additional information

No response

Source code

import jax
import pennylane as qml

@qml.qnode(qml.device('default.qubit', shots=5, wires=3))
def circuit(x):
    op = qml.RX(x, 0) ** 2.0
    print(op.has_matrix)
    print(op.matrix())
    return qml.probs(wires=0)

jax.jit(circuit)(0.5)

Tracebacks

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[5], line 8
      5     print(op.matrix())
      6     return qml.probs(wires=0)
----> 8 jax.jit(circuit)(0.5)

    [... skipping hidden 11 frame]

File ~/Prog/pennylane/pennylane/workflow/qnode.py:953, in QNode.__call__(self, *args, **kwargs)
    951 if qml.capture.enabled():
    952     return capture_qnode(self, *args, **kwargs)
--> 953 return self._impl_call(*args, **kwargs)

File ~/Prog/pennylane/pennylane/workflow/qnode.py:927, in QNode._impl_call(self, *args, **kwargs)
    924 def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
    925 
    926     # construct the tape
--> 927     self.construct(args, kwargs)
    929     old_interface = self.interface
    930     if old_interface == "auto":

File ~/Prog/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Prog/pennylane/pennylane/workflow/qnode.py:859, in QNode.construct(self, args, kwargs)
    857 with pldb_device_manager(self.device):
    858     with qml.queuing.AnnotatedQueue() as q:
--> 859         self._qfunc_output = self.func(*args, **kwargs)
    861 self._tape = QuantumScript.from_queue(q, shots)
    863 params = self._tape.get_parameters(trainable_only=False)

Cell In[5], line 5, in circuit(x)
      3 op = qml.RX(x,0) ** 2.0
      4 print(op.has_matrix)
----> 5 print(op.matrix())
      6 return qml.probs(wires=0)

File ~/Prog/pennylane/pennylane/ops/op_math/composite.py:37, in handle_recursion_error.<locals>.wrapper(*args, **kwargs)
     34 @wraps(func)
     35 def wrapper(*args, **kwargs):
     36     try:
---> 37         return func(*args, **kwargs)
     38     except RecursionError as e:
     39         raise RuntimeError(
     40             "Maximum recursion depth reached! This is likely due to nesting too many levels "
     41             "of composite operators. Try setting lazy=False when calling qml.sum, qml.prod, "
     42             "and qml.s_prod, or use the +, @, and * operators instead. Alternatively, you "
     43             "can periodically call qml.simplify on your operators."
     44         ) from e

File ~/Prog/pennylane/pennylane/ops/op_math/symbolicop.py:289, in ScalarSymbolicOp.matrix(self, wire_order)
    286     mat = qml.math.stack([self._matrix(scalar, ar2) for ar2 in base_matrix])
    287 else:
    288     # none are broadcasted
--> 289     mat = self._matrix(scalar, base_matrix)
    291 return qml.math.expand_matrix(mat, wires=self.wires, wire_order=wire_order)

File ~/Prog/pennylane/pennylane/ops/op_math/pow.py:247, in Pow._matrix(scalar, mat)
    244         out @= mat
    245     return out
--> 247 return fractional_matrix_power(mat, scalar)

File ~/Prog/pl/lib/python3.12/site-packages/scipy/linalg/_matfuncs.py:137, in fractional_matrix_power(A, t)
     97 """
     98 Compute the fractional power of a matrix.
     99 
   (...)
    133 
    134 """
    135 # This fixes some issue with imports;
    136 # this function calls onenormest which is in scipy.sparse.
--> 137 A = _asarray_square(A)
    138 import scipy.linalg._matfuncs_inv_ssq
    139 return scipy.linalg._matfuncs_inv_ssq._fractional_matrix_power(A, t)

File ~/Prog/pl/lib/python3.12/site-packages/scipy/linalg/_matfuncs.py:52, in _asarray_square(A)
     34 def _asarray_square(A):
     35     """
     36     Wraps asarray with the extra requirement that the input be a square matrix.
     37 
   (...)
     50 
     51     """
---> 52     A = np.asarray(A)
     53     if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
     54         raise ValueError('expected square array_like input')

File ~/Prog/pl/lib/python3.12/site-packages/jax/_src/core.py:650, in Tracer.__array__(self, *args, **kw)
    649 def __array__(self, *args, **kw):
--> 650   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex64[2,2].
The error occurred while tracing the function circuit at /var/folders/k1/0v_kvphn55lgf_45kntf1hqm0000gq/T/ipykernel_35325/3603183953.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
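The traceback shows two branches in `Pow._matrix`: a loop of matrix products (pow.py lines 244-245) and a call to `fractional_matrix_power` (line 247), which is where `2.0` ends up. The sketch below mirrors that split with a hypothetical `pow_matrix` helper. The `isinstance(scalar, int)` test that selects the loop branch is an assumption, since the actual condition is not shown. It illustrates that the loop branch survives `jax.jit`, while the SciPy branch only works on concrete inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.linalg import fractional_matrix_power


def pow_matrix(mat, scalar):
    # Hypothetical mirror of the two branches visible in pow.py lines 244-247;
    # the exact condition that selects the first branch is an assumption.
    if isinstance(scalar, int) and scalar >= 0:
        out = jnp.eye(mat.shape[0], dtype=mat.dtype)
        for _ in range(scalar):
            out = out @ mat  # plain matmuls: fine on traced arrays
        return out
    return fractional_matrix_power(mat, scalar)  # NumPy/SciPy only


def rx(theta):
    # RX(theta) = exp(-i * theta * X / 2)
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


# The int branch works under jit ...
int_jitted = jax.jit(lambda x: pow_matrix(rx(x), 2))(0.5)
# ... and agrees with the SciPy branch when the input is concrete.
float_eager = pow_matrix(rx(0.5), 2.0)
print(np.allclose(int_jitted, float_eager, atol=1e-5))  # True
```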

System information

master

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@albi3ro albi3ro added the bug 🐛 Something isn't working label Dec 11, 2024
@tonmoy-b
Contributor

Hi @albi3ro,

Can I work on this task? If so, kindly assign it to me and I will start working on it right away.

Thanks,
Tonmoy

@che-burashco

che-burashco commented Dec 31, 2024

Quick question: would it make sense to do something like this instead? It would give consistent output regardless of whether the function is jitted. Based on this thread in JAX, it's not a good idea to have the results diverge.
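The linked suggestion is not reproduced here. As an independent illustration of the consistency point, a fractional power can be computed entirely with `jax.numpy`, so the same code path runs with and without `jax.jit`. This sketch swaps SciPy's algorithm for an eigendecomposition, U**t = V diag(lambda**t) V^-1 with the principal branch of lambda**t. It assumes the matrix is diagonalizable (true for unitaries), and `jnp.linalg.eig` may not be supported on every JAX backend. `fractional_power_eig` and `rx` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def fractional_power_eig(mat, t):
    # Hypothetical helper. Assumes `mat` is diagonalizable (e.g. unitary):
    # mat**t = V @ diag(vals**t) @ inv(V), using the principal branch of vals**t.
    vals, vecs = jnp.linalg.eig(mat)
    return vecs @ jnp.diag(vals ** t) @ jnp.linalg.inv(vecs)


def rx(theta):
    # RX(theta) = exp(-i * theta * X / 2)
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


eager = fractional_power_eig(rx(0.5), 0.5)
jitted = jax.jit(lambda x: fractional_power_eig(rx(x), 0.5))(0.5)

# Same path, same answer; RX(0.5)**0.5 should equal RX(0.25).
print(np.allclose(eager, jitted, atol=1e-4))
print(np.allclose(jitted, rx(0.25), atol=1e-4))
```

For an RX rotation the eigenvalues are exp(-i*theta/2) and exp(i*theta/2), so the principal-branch square root simply halves the angle, which is what the second check compares against.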

Projects
None yet
Development

No branches or pull requests

3 participants