Skip to content

Commit

Permalink
Speedup the CommutationChecker and the CommutativeCancellation (#…
Browse files Browse the repository at this point in the history
…12859)

* Faster commutation checker and analysis

- list of pre-approved gates we know we support commutation on
- less redirections in function calls
- commutation analysis to only trigger search on gates that are actually cancelled

* cleanup comments

* add reno

* review comments

* revert accidentially changed tests

-- these need updating only on the version with parameter support

* revert changes in test_comm_inv_canc

---------

Co-authored-by: MarcDrudis <MarcSanzDrudis@outlook.com>
  • Loading branch information
Cryoris and MarcDrudis authored Aug 1, 2024
1 parent 0afb06e commit 441925e
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 51 deletions.
98 changes: 57 additions & 41 deletions qiskit/circuit/commutation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""Code from commutative_analysis pass that checks commutation relations between DAG nodes."""

from functools import lru_cache
from typing import List, Union
from typing import List, Union, Set, Optional
import numpy as np

from qiskit import QiskitError
Expand All @@ -25,6 +25,27 @@
_skipped_op_names = {"measure", "reset", "delay", "initialize"}
_no_cache_op_names = {"annotated"}

_supported_ops = {
"h",
"x",
"y",
"z",
"sx",
"sxdg",
"t",
"tdg",
"s",
"sdg",
"cx",
"cy",
"cz",
"swap",
"iswap",
"ecr",
"ccx",
"cswap",
}


@lru_cache(maxsize=None)
def _identity_op(num_qubits):
Expand All @@ -42,7 +63,13 @@ class CommutationChecker:
evicting from the cache less useful entries, etc.
"""

def __init__(self, standard_gate_commutations: dict = None, cache_max_entries: int = 10**6):
def __init__(
self,
standard_gate_commutations: dict = None,
cache_max_entries: int = 10**6,
*,
gates: Optional[Set[str]] = None,
):
super().__init__()
if standard_gate_commutations is None:
self._standard_commutations = {}
Expand All @@ -56,6 +83,7 @@ def __init__(self, standard_gate_commutations: dict = None, cache_max_entries: i
self._current_cache_entries = 0
self._cache_miss = 0
self._cache_hit = 0
self._gate_names = gates

def commute_nodes(
self,
Expand Down Expand Up @@ -103,6 +131,11 @@ def commute(
Returns:
bool: whether two operations commute.
"""
# Skip gates that are not specified.
if self._gate_names is not None:
if op1.name not in self._gate_names or op2.name not in self._gate_names:
return False

structural_commutation = _commutation_precheck(
op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits
)
Expand Down Expand Up @@ -231,59 +264,38 @@ def _hashable_parameters(params):
return ("fallback", str(params))


def is_commutation_supported(op):
def is_commutation_supported(op, qargs, max_num_qubits):
"""
Filter operations whose commutation is not supported due to bugs in transpiler passes invoking
commutation analysis.
Args:
op (Operation): operation to be checked for commutation relation
op (Operation): operation to be checked for commutation relation.
qargs (list[Qubit]): qubits the operation acts on.
max_num_qubits (int): The maximum number of qubits to check commutativity for.
Return:
True if determining the commutation of op is currently supported
"""
# Bug in CommutativeCancellation, e.g. see gh-8553
if getattr(op, "condition", False):
# If the number of qubits is beyond what we check, stop here and do not even check in the
# pre-defined supported operations
if len(qargs) > max_num_qubits:
return False

# Check if the operation is pre-approved, otherwise go through the checks
if op.name in _supported_ops:
return True

# Commutation of ControlFlow gates also not supported yet. This may be pending a control flow graph.
if op.name in CONTROL_FLOW_OP_NAMES:
return False

return True


def is_commutation_skipped(op, qargs, max_num_qubits):
"""
Filter operations whose commutation will not be determined.
Args:
op (Operation): operation to be checked for commutation relation
qargs (List): operation qubits
max_num_qubits (int): the maximum number of qubits to consider, the check may be skipped if
the number of qubits for either operation exceeds this amount.
Return:
True if determining the commutation of op is currently not supported
"""
if (
len(qargs) > max_num_qubits
or getattr(op, "_directive", False)
or op.name in _skipped_op_names
):
return True
if getattr(op, "_directive", False) or op.name in _skipped_op_names:
return False

if getattr(op, "is_parameterized", False) and op.is_parameterized():
return True

from qiskit.dagcircuit.dagnode import DAGOpNode

# we can proceed if op has defined: to_operator, to_matrix and __array__, or if its definition can be
# recursively resolved by operations that have a matrix. We check this by constructing an Operator.
if (
isinstance(op, DAGOpNode)
or (hasattr(op, "to_matrix") and hasattr(op, "__array__"))
or hasattr(op, "to_operator")
):
return False

return False
return True


def _commutation_precheck(
Expand All @@ -295,13 +307,14 @@ def _commutation_precheck(
cargs2: List,
max_num_qubits,
):
if not is_commutation_supported(op1) or not is_commutation_supported(op2):
# Bug in CommutativeCancellation, e.g. see gh-8553
if getattr(op1, "condition", False) or getattr(op2, "condition", False):
return False

if set(qargs1).isdisjoint(qargs2) and set(cargs1).isdisjoint(cargs2):
return True

if is_commutation_skipped(op1, qargs1, max_num_qubits) or is_commutation_skipped(
if not is_commutation_supported(op1, qargs1, max_num_qubits) or not is_commutation_supported(
op2, qargs2, max_num_qubits
):
return False
Expand Down Expand Up @@ -409,7 +422,10 @@ def _query_commutation(
first_params = getattr(first_op, "params", [])
second_params = getattr(second_op, "params", [])
return commutation_after_placement.get(
(_hashable_parameters(first_params), _hashable_parameters(second_params)),
(
_hashable_parameters(first_params),
_hashable_parameters(second_params),
),
None,
)
else:
Expand Down
2 changes: 1 addition & 1 deletion qiskit/circuit/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def is_parameterized(self):
"""Return whether the :class:`Instruction` contains :ref:`compile-time parameters
<circuit-compile-time-parameters>`."""
return any(
isinstance(param, ParameterExpression) and param.parameters for param in self.params
isinstance(param, ParameterExpression) and param.parameters for param in self._params
)

@property
Expand Down
8 changes: 6 additions & 2 deletions qiskit/transpiler/passes/optimization/commutation_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ class CommutationAnalysis(AnalysisPass):
are grouped into a set of gates that commute.
"""

def __init__(self):
def __init__(self, *, _commutation_checker=None):
super().__init__()
self.comm_checker = scc
# allow setting a private commutation checker, this allows better performance if we
# do not care about commutations of all gates, but just a subset
if _commutation_checker is None:
_commutation_checker = scc
self.comm_checker = _commutation_checker

def run(self, dag):
"""Run the CommutationAnalysis pass on `dag`.
Expand Down
26 changes: 19 additions & 7 deletions qiskit/transpiler/passes/optimization/commutative_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import numpy as np

from qiskit.circuit.quantumregister import QuantumRegister
from qiskit.circuit.parameterexpression import ParameterExpression
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.passmanager import PassManager
from qiskit.transpiler.passes.optimization.commutation_analysis import CommutationAnalysis
from qiskit.dagcircuit import DAGCircuit, DAGInNode, DAGOutNode
from qiskit.circuit.commutation_library import CommutationChecker, StandardGateCommutations
from qiskit.circuit.library.standard_gates.u1 import U1Gate
from qiskit.circuit.library.standard_gates.rx import RXGate
from qiskit.circuit.library.standard_gates.p import PhaseGate
Expand Down Expand Up @@ -61,7 +63,18 @@ def __init__(self, basis_gates=None, target=None):
self.basis = set(target.operation_names)

self._var_z_map = {"rz": RZGate, "p": PhaseGate, "u1": U1Gate}
self.requires.append(CommutationAnalysis())

self._z_rotations = {"p", "z", "u1", "rz", "t", "s"}
self._x_rotations = {"x", "rx"}
self._gates = {"cx", "cy", "cz", "h", "y"} # Now the gates supported are hard-coded

# build a commutation checker restricted to the gates we cancel -- the others we
# do not have to investigate, which allows to save time
commutation_checker = CommutationChecker(
StandardGateCommutations, gates=self._gates | self._z_rotations | self._x_rotations
)

self.requires.append(CommutationAnalysis(_commutation_checker=commutation_checker))

def run(self, dag):
"""Run the CommutativeCancellation pass on `dag`.
Expand All @@ -82,9 +95,6 @@ def run(self, dag):
if z_var_gates:
var_z_gate = self._var_z_map[next(iter(z_var_gates))]

# Now the gates supported are hard-coded
q_gate_list = ["cx", "cy", "cz", "h", "y"]

# Gate sets to be cancelled
cancellation_sets = defaultdict(lambda: [])

Expand All @@ -103,9 +113,11 @@ def run(self, dag):
continue
for node in com_set:
num_qargs = len(node.qargs)
if num_qargs == 1 and node.name in q_gate_list:
if any(isinstance(p, ParameterExpression) for p in node.params):
continue # no support for cancellation of parameterized gates
if num_qargs == 1 and node.name in self._gates:
cancellation_sets[(node.name, wire, com_set_idx)].append(node)
if num_qargs == 1 and node.name in ["p", "z", "u1", "rz", "t", "s"]:
if num_qargs == 1 and node.name in self._z_rotations:
cancellation_sets[("z_rotation", wire, com_set_idx)].append(node)
if num_qargs == 1 and node.name in ["rx", "x"]:
cancellation_sets[("x_rotation", wire, com_set_idx)].append(node)
Expand All @@ -126,7 +138,7 @@ def run(self, dag):
if cancel_set_key[0] == "z_rotation" and var_z_gate is None:
continue
set_len = len(cancellation_sets[cancel_set_key])
if set_len > 1 and cancel_set_key[0] in q_gate_list:
if set_len > 1 and cancel_set_key[0] in self._gates:
gates_to_cancel = cancellation_sets[cancel_set_key]
for c_node in gates_to_cancel[: (set_len // 2) * 2]:
dag.remove_op_node(c_node)
Expand Down

0 comments on commit 441925e

Please sign in to comment.