diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 7a6e35e7c8e..be3550dae79 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -13,7 +13,9 @@ # limitations under the License. """Transform for cancelling adjacent inverse gates in quantum circuits.""" # pylint: disable=too-many-branches +from functools import partial +import pennylane as qml from pennylane.ops.op_math import Adjoint from pennylane.ops.qubit.attributes import ( self_inverses, @@ -63,7 +65,69 @@ def _are_inverses(op1, op2): return False -@transform +def _cancel_inverses_plxpr_transform( + primitive, tracers, params, targs, tkwargs, state +): # pylint: disable=unused-argument,too-many-arguments,too-many-positional-arguments + """Implementation for applying ``cancel_inverses`` to PLxPR.""" + # pylint: disable=import-outside-toplevel + from pennylane.capture import TransformTracer + + previous_ops = state[0].get("previous_ops", {}) + + with qml.QueuingManager.stop_recording(): + tracers_in = [t.val if isinstance(t.val, qml.operation.Operator) else t for t in tracers] + cur_op = primitive.impl(*tracers_in, **params) + + # if isinstance(cur_op, qml.measurements.MeasurementProcess): + # tracers = [TransformTracer(t._trace, t.val, t.idx + 1) for t in tracers] + # return primitive.bind(*tracers, **params) + + prev_op = previous_ops.pop(cur_op.wires[0], None) + if prev_op is None: + # No operator to compare against, so we save the current op and don't bind anything + for w in cur_op.wires: + previous_ops[w] = cur_op + state[0]["cancel_inverses_previous_ops"] = previous_ops + return [] + + cancel = False + if _are_inverses(cur_op, prev_op): + if cur_op.wires == prev_op.wires: + # Inverse ops, same wires. Cancel + cancel = True + if cur_op in symmetric_over_all_wires and len( + Wires.shared_wires([cur_op.wires, prev_op.wires]) + ) == len(cur_op.wires): + # Symmetric ops, full overlap in wires. Cancel + cancel = True + if cur_op in symmetric_over_control_wires and ( + len(Wires.shared_wires([cur_op.wires[:-1], prev_op.wires[:-1]])) + == len(cur_op.wires) - 1 + ): + # Is controlled op that is symmetric over control wires, and control + # wires have full overlap with prvious op's control wires + return + + if cancel: + # Update dictionary and don't bind any new primitives + for w in prev_op.wires: + previous_ops.pop(w) + state["cancel_inverses_previous_ops"] = previous_ops + return None + + # Can't cancel. Add current op to dictionary and + for w in prev_op.wires: + previous_ops.pop(w) + for w in cur_op.wires: + previous_ops[w] = cur_op + + # pylint: disable=protected-access + tracers = [TransformTracer(t._trace, t.val, t.idx + 1) for t in tracers] + # FIXME: This needs to bind the previous primitive, not the current one + return primitive.bind(*tracers, **params) + + +@partial(transform, plxpr_transform=_cancel_inverses_plxpr_transform) def cancel_inverses(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]: """Quantum function transform to remove any operations that are applied next to their (self-)inverses or adjoint.