Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PLxPR implementation for qml.transforms.cancel_inverses #6636

Draft
wants to merge 3 commits into
base: plxpr-transform-base
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion pennylane/transforms/optimization/cancel_inverses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.
"""Transform for cancelling adjacent inverse gates in quantum circuits."""
# pylint: disable=too-many-branches
from functools import partial

import pennylane as qml
from pennylane.ops.op_math import Adjoint
from pennylane.ops.qubit.attributes import (
self_inverses,
Expand Down Expand Up @@ -63,7 +65,69 @@ def _are_inverses(op1, op2):
return False


@transform
def _cancel_inverses_plxpr_transform(
primitive, tracers, params, targs, tkwargs, state
): # pylint: disable=unused-argument,too-many-arguments,too-many-positional-arguments
"""Implementation for applying ``cancel_inverses`` to PLxPR."""
# pylint: disable=import-outside-toplevel
from pennylane.capture import TransformTracer

previous_ops = state[0].get("previous_ops", {})

with qml.QueuingManager.stop_recording():
tracers_in = [t.val if isinstance(t.val, qml.operation.Operator) else t for t in tracers]
cur_op = primitive.impl(*tracers_in, **params)

# if isinstance(cur_op, qml.measurements.MeasurementProcess):
# tracers = [TransformTracer(t._trace, t.val, t.idx + 1) for t in tracers]
# return primitive.bind(*tracers, **params)

prev_op = previous_ops.pop(cur_op.wires[0], None)
if prev_op is None:
# No operator to compare against, so we save the current op and don't bind anything
for w in cur_op.wires:
previous_ops[w] = cur_op
state[0]["cancel_inverses_previous_ops"] = previous_ops
return []

cancel = False
if _are_inverses(cur_op, prev_op):
if cur_op.wires == prev_op.wires:
# Inverse ops, same wires. Cancel
cancel = True
if cur_op in symmetric_over_all_wires and len(
Wires.shared_wires([cur_op.wires, prev_op.wires])
) == len(cur_op.wires):
# Symmetric ops, full overlap in wires. Cancel
cancel = True
if cur_op in symmetric_over_control_wires and (
len(Wires.shared_wires([cur_op.wires[:-1], prev_op.wires[:-1]]))
== len(cur_op.wires) - 1
):
# Is controlled op that is symmetric over control wires, and control
# wires have full overlap with prvious op's control wires
return

if cancel:
# Update dictionary and don't bind any new primitives
for w in prev_op.wires:
previous_ops.pop(w)
state["cancel_inverses_previous_ops"] = previous_ops
return None

# Can't cancel. Add current op to dictionary and
for w in prev_op.wires:
previous_ops.pop(w)
for w in cur_op.wires:
previous_ops[w] = cur_op

# pylint: disable=protected-access
tracers = [TransformTracer(t._trace, t.val, t.idx + 1) for t in tracers]
# FIXME: This needs to bind the previous primitive, not the current one
return primitive.bind(*tracers, **params)


@partial(transform, plxpr_transform=_cancel_inverses_plxpr_transform)
def cancel_inverses(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
"""Quantum function transform to remove any operations that are applied next to their
(self-)inverses or adjoint.
Expand Down