Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow creating a prod op with a qfunc #4011

Merged
merged 12 commits into from
May 3, 2023
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@
* Update various Operators and templates to ensure their decompositions only return lists of Operators.
[(#3243)](https://github.com/PennyLaneAI/pennylane/pull/3243)

* `qml.prod` now accepts a single qfunc input for creating new `Prod` operators.
[(#4011)](https://github.com/PennyLaneAI/pennylane/pull/4011)

<h3>Breaking changes 💔</h3>

* Both JIT interfaces are not compatible with JAX `>0.4.3`, we raise an error for those versions.
Expand Down
28 changes: 24 additions & 4 deletions pennylane/ops/op_math/prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
"""
import itertools
from copy import copy
from functools import reduce
from functools import reduce, wraps
from itertools import combinations
from typing import List, Tuple, Union

import numpy as np
from scipy.sparse import kron as sparse_kron

import pennylane as qml
Expand All @@ -33,6 +32,7 @@
from pennylane.ops.qubit import Hamiltonian
from pennylane.ops.qubit.non_parametric_ops import PauliX, PauliY, PauliZ
from pennylane.queuing import QueuingManager
from pennylane.typing import TensorLike
from pennylane.wires import Wires

from .composite import CompositeOp
Expand All @@ -51,7 +51,8 @@ def prod(*ops, do_queue=True, id=None, lazy=True):
that the given operators act on.

Args:
ops (tuple[~.operation.Operator]): The operators we would like to multiply
ops (Union[tuple[~.operation.Operator], Callable]): The operators we would like to multiply.
Alternatively, a single qfunc that queues operators can be passed to this function.

Keyword Args:
do_queue (bool): determines if the product operator will be queued. Default is True.
Expand Down Expand Up @@ -84,7 +85,26 @@ def prod(*ops, do_queue=True, id=None, lazy=True):
>>> prod_op.matrix()
array([[ 0, -1],
[ 1, 0]])

You can also create a prod operator by passing a qfunc to prod, like the following:

>>> def qfunc():
... qml.Hadamard(0)
... qml.CNOT([0, 1])
>>> prod_op = prod(qfunc)
>>> prod_op
CNOT(wires=[0, 1]) @ Hadamard(wires=[0])
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""
if len(ops) == 1 and callable(ops[0]):
fn = ops[0]

@wraps(fn)
def wrapper(*args, **kwargs):
qs = qml.tape.make_qscript(fn)(*args, **kwargs)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
return prod(*qs.operations[::-1], do_queue=do_queue, id=id, lazy=lazy)

return wrapper

if lazy:
return Prod(*ops, do_queue=do_queue, id=id)

Expand Down Expand Up @@ -257,7 +277,7 @@ def decomposition(self):
def matrix(self, wire_order=None):
"""Representation of the operator as a matrix in the computational basis."""

mats: List[np.ndarray] = [] # TODO: change type to `tensor_like` when available
mats: List[TensorLike] = []
batched: List[bool] = [] # batched[i] tells if mats[i] is batched or not
for ops in self.overlapping_ops:
gen = (
Expand Down
64 changes: 64 additions & 0 deletions tests/ops/op_math/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,70 @@ def test_has_diagonalizing_gates_false_via_factor(self):
prod_op = prod(MyOp(3.1, 0), qml.PauliX(2), do_queue=True)
assert prod_op.has_diagonalizing_gates is False

def test_qfunc_init(self):
"""Tests prod initialization with a qfunc argument."""

def qfunc():
qml.Hadamard(0)
qml.CNOT([0, 1])
qml.RZ(1.1, 1)

prod_gen = prod(qfunc)
assert callable(prod_gen)
prod_op = prod_gen()
expected = prod(qml.RZ(1.1, 1), qml.CNOT([0, 1]), qml.Hadamard(0))
assert qml.equal(prod_op, expected)
assert prod_op.wires == Wires([1, 0])

def test_qfunc_init_accepts_args_kwargs(self):
"""Tests that prod preserves args when wrapping qfuncs."""

def qfunc(x, run_had=False):
if run_had:
qml.Hadamard(0)
qml.RX(x, 0)
qml.CNOT([0, 1])

prod_gen = prod(qfunc)
assert qml.equal(prod_gen(1.1), prod(qml.CNOT([0, 1]), qml.RX(1.1, 0)))
assert qml.equal(
prod_gen(2.2, run_had=True), prod(qml.CNOT([0, 1]), qml.RX(2.2, 0), qml.Hadamard(0))
)

def test_qfunc_init_propagates_Prod_kwargs(self):
"""Tests that additional kwargs for Prod are propagated using qfunc initialization."""

def qfunc(x):
qml.prod(qml.RX(x, 0), qml.PauliZ(1))
qml.CNOT([0, 1])

prod_gen = prod(qfunc, do_queue=False, id=123987, lazy=False)

with qml.queuing.AnnotatedQueue() as q:
prod_op = prod_gen(1.1)

assert prod_op not in q # do_queue worked
assert prod_op.id == 123987 # id was set
assert qml.equal(prod_op, prod(qml.CNOT([0, 1]), qml.PauliZ(1), qml.RX(1.1, 0))) # eager

def test_qfunc_init_only_works_with_one_qfunc(self):
"""Test that the qfunc init only occurs when one callable is passed to prod."""

def qfunc():
qml.Hadamard(0)
qml.CNOT([0, 1])

prod_op = prod(qfunc)()
assert qml.equal(prod_op, prod(qml.CNOT([0, 1]), qml.Hadamard(0)))

def fn2():
qml.PauliX(0)
qml.PauliY(1)

for args in [(qfunc, fn2), (qfunc, qml.PauliX), (qml.PauliX, qfunc)]:
with pytest.raises(AttributeError, match="has no attribute 'wires'"):
prod(*args)


class TestMatrix:
"""Test matrix-related methods."""
Expand Down