Skip to content

Commit

Permalink
Avoid Python op creation in BasisTranslator (Qiskit#12705)
Browse files Browse the repository at this point in the history
This commit updates the BasisTranslator transpiler pass. It builds off
of Qiskit#12692 and Qiskit#12701 to adjust access patterns in the python transpiler
path to avoid eagerly creating a Python space operation object. The goal
of this PR is to mitigate the performance regression introduced by the
extra conversion cost of Qiskit#12459 on the BasisTranslator.
  • Loading branch information
mtreinish authored and Procatv committed Aug 1, 2024
1 parent f37e8ee commit a276e0d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 26 deletions.
5 changes: 5 additions & 0 deletions crates/circuit/src/dag_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ impl DAGOpNode {
self.instruction.params.to_object(py)
}

#[setter]
fn set_params(&mut self, val: smallvec::SmallVec<[crate::operations::Param; 3]>) {
self.instruction.params = val;
}

pub fn is_parameterized(&self) -> bool {
self.instruction.is_parameterized()
}
Expand Down
61 changes: 35 additions & 26 deletions qiskit/transpiler/passes/basis/basis_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
QuantumCircuit,
ParameterExpression,
)
from qiskit.dagcircuit import DAGCircuit
from qiskit.dagcircuit import DAGCircuit, DAGOpNode
from qiskit.converters import circuit_to_dag, dag_to_circuit
from qiskit.circuit.equivalence import Key, NodeData
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.exceptions import TranspilerError
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit._accelerate.circuit import StandardGate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -253,7 +255,7 @@ def apply_translation(dag, wire_map):
node_qargs = tuple(wire_map[bit] for bit in node.qargs)
qubit_set = frozenset(node_qargs)
if node.name in target_basis or len(node.qargs) < self._min_qubits:
if isinstance(node.op, ControlFlowOp):
if node.name in CONTROL_FLOW_OP_NAMES:
flow_blocks = []
for block in node.op.blocks:
dag_block = circuit_to_dag(block)
Expand Down Expand Up @@ -281,7 +283,7 @@ def apply_translation(dag, wire_map):
continue
if qubit_set in extra_instr_map:
self._replace_node(dag, node, extra_instr_map[qubit_set])
elif (node.op.name, node.op.num_qubits) in instr_map:
elif (node.name, node.num_qubits) in instr_map:
self._replace_node(dag, node, instr_map)
else:
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
Expand All @@ -298,20 +300,29 @@ def apply_translation(dag, wire_map):
return dag

def _replace_node(self, dag, node, instr_map):
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
if len(node.op.params) != len(target_params):
target_params, target_dag = instr_map[node.name, node.num_qubits]
if len(node.params) != len(target_params):
raise TranspilerError(
"Translation num_params not equal to op num_params."
f"Op: {node.op.params} {node.op.name} Translation: {target_params}\n{target_dag}"
f"Op: {node.params} {node.name} Translation: {target_params}\n{target_dag}"
)
if node.op.params:
parameter_map = dict(zip(target_params, node.op.params))
if node.params:
parameter_map = dict(zip(target_params, node.params))
bound_target_dag = target_dag.copy_empty_like()
for inner_node in target_dag.topological_op_nodes():
if any(isinstance(x, ParameterExpression) for x in inner_node.op.params):
new_op = inner_node._raw_op
if not isinstance(inner_node._raw_op, StandardGate):
new_op = inner_node.op.copy()
new_node = DAGOpNode(
new_op,
qargs=inner_node.qargs,
cargs=inner_node.cargs,
params=inner_node.params,
dag=bound_target_dag,
)
if any(isinstance(x, ParameterExpression) for x in inner_node.params):
new_params = []
for param in new_op.params:
for param in new_node.params:
if not isinstance(param, ParameterExpression):
new_params.append(param)
else:
Expand All @@ -325,10 +336,10 @@ def _replace_node(self, dag, node, instr_map):
if not new_value.parameters:
new_value = new_value.numeric()
new_params.append(new_value)
new_op.params = new_params
else:
new_op = inner_node.op
bound_target_dag.apply_operation_back(new_op, inner_node.qargs, inner_node.cargs)
new_node.params = new_params
if not isinstance(new_op, StandardGate):
new_op.params = new_params
bound_target_dag._apply_op_node_back(new_node)
if isinstance(target_dag.global_phase, ParameterExpression):
old_phase = target_dag.global_phase
bind_dict = {x: parameter_map[x] for x in old_phase.parameters}
Expand All @@ -353,7 +364,7 @@ def _replace_node(self, dag, node, instr_map):
dag_op = bound_target_dag.op_nodes()[0].op
# dag_op may be the same instance as other ops in the dag,
# so if there is a condition, need to copy
if getattr(node.op, "condition", None):
if getattr(node, "condition", None):
dag_op = dag_op.copy()
dag.substitute_node(node, dag_op, inplace=True)

Expand All @@ -370,8 +381,8 @@ def _extract_basis(self, circuit):
def _(self, dag: DAGCircuit):
for node in dag.op_nodes():
if not dag.has_calibration_for(node) and len(node.qargs) >= self._min_qubits:
yield (node.name, node.op.num_qubits)
if isinstance(node.op, ControlFlowOp):
yield (node.name, node.num_qubits)
if node.name in CONTROL_FLOW_OP_NAMES:
for block in node.op.blocks:
yield from self._extract_basis(block)

Expand Down Expand Up @@ -412,10 +423,10 @@ def _extract_basis_target(
frozenset(qargs).issuperset(incomplete_qargs)
for incomplete_qargs in self._qargs_with_non_global_operation
):
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.num_qubits))
else:
source_basis.add((node.name, node.op.num_qubits))
if isinstance(node.op, ControlFlowOp):
source_basis.add((node.name, node.num_qubits))
if node.name in CONTROL_FLOW_OP_NAMES:
for block in node.op.blocks:
block_dag = circuit_to_dag(block)
source_basis, qargs_local_source_basis = self._extract_basis_target(
Expand Down Expand Up @@ -628,7 +639,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
doomed_nodes = [
node
for node in dag.op_nodes()
if (node.op.name, node.op.num_qubits) == (gate_name, gate_num_qubits)
if (node.name, node.num_qubits) == (gate_name, gate_num_qubits)
]

if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
Expand All @@ -642,9 +653,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):

for node in doomed_nodes:

replacement = equiv.assign_parameters(
dict(zip_longest(equiv_params, node.op.params))
)
replacement = equiv.assign_parameters(dict(zip_longest(equiv_params, node.params)))

replacement_dag = circuit_to_dag(replacement)

Expand All @@ -666,8 +675,8 @@ def _get_example_gates(source_dag):
def recurse(dag, example_gates=None):
example_gates = example_gates or {}
for node in dag.op_nodes():
example_gates[(node.op.name, node.op.num_qubits)] = node.op
if isinstance(node.op, ControlFlowOp):
example_gates[(node.name, node.num_qubits)] = node
if node.name in CONTROL_FLOW_OP_NAMES:
for block in node.op.blocks:
example_gates = recurse(circuit_to_dag(block), example_gates)
return example_gates
Expand Down

0 comments on commit a276e0d

Please sign in to comment.