diff --git a/qiskit/dagcircuit/dagcircuit.py b/qiskit/dagcircuit/dagcircuit.py index 6bec21bb9db6..c5a97c09113f 100644 --- a/qiskit/dagcircuit/dagcircuit.py +++ b/qiskit/dagcircuit/dagcircuit.py @@ -866,60 +866,6 @@ def _check_wires_list(self, wires, node): if len(wires) != wire_tot: raise DAGCircuitError("expected %d wires, got %d" % (wire_tot, len(wires))) - def _make_pred_succ_maps(self, node): - """Return predecessor and successor dictionaries. - - Args: - node (DAGOpNode): reference to multi_graph node - - Returns: - tuple(dict): tuple(predecessor_map, successor_map) - These map from wire (Register, int) to the node ids for the - predecessor (successor) nodes of the input node. - """ - - pred_map = {e[2]: e[0] for e in self._multi_graph.in_edges(node._node_id)} - succ_map = {e[2]: e[1] for e in self._multi_graph.out_edges(node._node_id)} - return pred_map, succ_map - - def _full_pred_succ_maps(self, pred_map, succ_map, input_circuit, wire_map): - """Map all wires of the input circuit. - - Map all wires of the input circuit to predecessor and - successor nodes in self, keyed on wires in self. - - Args: - pred_map (dict): comes from _make_pred_succ_maps - succ_map (dict): comes from _make_pred_succ_maps - input_circuit (DAGCircuit): the input circuit - wire_map (dict): the map from wires of input_circuit to wires of self - - Returns: - tuple: full_pred_map, full_succ_map (dict, dict) - - Raises: - DAGCircuitError: if more than one predecessor for output nodes - """ - full_pred_map = {} - full_succ_map = {} - for w in input_circuit.input_map: - # If w is wire mapped, find the corresponding predecessor - # of the node - if w in wire_map: - full_pred_map[wire_map[w]] = pred_map[wire_map[w]] - full_succ_map[wire_map[w]] = succ_map[wire_map[w]] - else: - # Otherwise, use the corresponding output nodes of self - # and compute the predecessor. - full_succ_map[w] = self.output_map[w] - full_pred_map[w] = self._multi_graph.predecessors(self.output_map[w])[0] - if len(self._multi_graph.predecessors(self.output_map[w])) != 1: - raise DAGCircuitError( - "too many predecessors for %s[%d] output node" % (w.register, w.index) - ) - - return full_pred_map, full_succ_map - def __eq__(self, other): # Try to convert to float, but in case of unbound ParameterExpressions # a TypeError will be raise, fallback to normal equality in those @@ -1022,7 +968,7 @@ def substitute_node_with_dag(self, node, input_dag, wires=None): if wires is None: wires = in_dag.wires - + wire_set = set(wires) self._check_wires_list(wires, node) # Create a proxy wire_map to identify fragments and duplicates @@ -1044,12 +990,14 @@ def substitute_node_with_dag(self, node, input_dag, wires=None): condition_bit_list = self._bits_in_condition(node.op.condition) - wire_map = dict(zip(wires, list(node.qargs) + list(node.cargs) + list(condition_bit_list))) + new_wires = list(node.qargs) + list(node.cargs) + list(condition_bit_list) + + wire_map = {} + reverse_wire_map = {} + for wire, new_wire in zip(wires, new_wires): + wire_map[wire] = new_wire + reverse_wire_map[new_wire] = wire self._check_wiremap_validity(wire_map, wires, self.input_map) - pred_map, succ_map = self._make_pred_succ_maps(node) - full_pred_map, full_succ_map = self._full_pred_succ_maps( - pred_map, succ_map, in_dag, wire_map - ) if condition_bit_list: # If we are replacing a conditional node, map input dag through @@ -1065,40 +1013,79 @@ def substitute_node_with_dag(self, node, input_dag, wires=None): "Mapped DAG would alter clbits on which it would be conditioned." ) - # Now that we know the connections, delete node - self._multi_graph.remove_node(node._node_id) - - # Iterate over nodes of input_circuit - for sorted_node in in_dag.topological_op_nodes(): - # Insert a new node - condition = self._map_condition(wire_map, sorted_node.op.condition, self.cregs.values()) - m_qargs = list(map(lambda x: wire_map.get(x, x), sorted_node.qargs)) - m_cargs = list(map(lambda x: wire_map.get(x, x), sorted_node.cargs)) - node_index = self._add_op_node(sorted_node.op, m_qargs, m_cargs) - - # Add edges from predecessor nodes to new node - # and update predecessor nodes that change - all_cbits = self._bits_in_condition(condition) - all_cbits.extend(m_cargs) - al = [m_qargs, all_cbits] - for q in itertools.chain(*al): - self._multi_graph.add_edge(full_pred_map[q], node_index, q) - full_pred_map[q] = node_index - - # Connect all predecessors and successors, and remove - # residual edges between input and output nodes - for w in full_pred_map: - self._multi_graph.add_edge(full_pred_map[w], full_succ_map[w], w) - o_pred = self._multi_graph.predecessors(self.output_map[w]._node_id) - if len(o_pred) > 1: - if len(o_pred) != 2: - raise DAGCircuitError("expected 2 predecessors here") - - p = [x for x in o_pred if x != full_pred_map[w]] - if len(p) != 1: - raise DAGCircuitError("expected 1 predecessor to pass filter") - - self._multi_graph.remove_edge(p[0], self.output_map[w]) + # Add wire from pred to succ if no ops on mapped wire on ``in_dag`` + # retworkx's substitute_node_with_subgraph lacks the DAGCircuit + # context to know what to do in this case (the method won't even see + # these nodes because they're filtered) so we manually retain the + # edges prior to calling substitute_node_with_subgraph and set the + # edge_map_fn callback kwarg to skip these edges when they're + # encountered. + for wire in wires: + input_node = in_dag.input_map[wire] + output_node = in_dag.output_map[wire] + if in_dag._multi_graph.has_edge(input_node._node_id, output_node._node_id): + self_wire = wire_map[wire] + pred = self._multi_graph.find_predecessors_by_edge( + node._node_id, lambda edge, wire=self_wire: edge == wire + )[0] + succ = self._multi_graph.find_successors_by_edge( + node._node_id, lambda edge, wire=self_wire: edge == wire + )[0] + self._multi_graph.add_edge(pred._node_id, succ._node_id, self_wire) + + # Exlude any nodes from in_dag that are not a DAGOpNode or are on + # bits outside the set specified by the wires kwarg + def filter_fn(node): + if not isinstance(node, DAGOpNode): + return False + for qarg in node.qargs: + if qarg not in wire_set: + return False + return True + + # Map edges into and out of node to the appropriate node from in_dag + def edge_map_fn(source, _target, self_wire): + wire = reverse_wire_map[self_wire] + # successor edge + if source == node._node_id: + wire_output_id = in_dag.output_map[wire]._node_id + out_index = in_dag._multi_graph.predecessor_indices(wire_output_id)[0] + # Edge directly from from input nodes to output nodes in in_dag are + # already handled prior to calling retworkx. Don't map these edges + # in retworkx. + if not isinstance(in_dag._multi_graph[out_index], DAGOpNode): + return None + # predecessor edge + else: + wire_input_id = in_dag.input_map[wire]._node_id + out_index = in_dag._multi_graph.successor_indices(wire_input_id)[0] + # Edge directly from from input nodes to output nodes in in_dag are + # already handled prior to calling retworkx. Don't map these edges + # in retworkx. + if not isinstance(in_dag._multi_graph[out_index], DAGOpNode): + return None + return out_index + + # Adjust edge weights from in_dag + def edge_weight_map(wire): + return wire_map[wire] + + node_map = self._multi_graph.substitute_node_with_subgraph( + node._node_id, in_dag._multi_graph, edge_map_fn, filter_fn, edge_weight_map + ) + + # Iterate over nodes of input_circuit and update wiires in node objects migrated + # from in_dag + for old_node_index, new_node_index in node_map.items(): + # update node attributes + old_node = in_dag._multi_graph[old_node_index] + condition = self._map_condition(wire_map, old_node.op.condition, self.cregs.values()) + m_qargs = [wire_map.get(x, x) for x in old_node.qargs] + m_cargs = [wire_map.get(x, x) for x in old_node.cargs] + new_node = DAGOpNode(old_node.op, qargs=m_qargs, cargs=m_cargs) + new_node._node_id = new_node_index + new_node.op.condition = condition + self._multi_graph[new_node_index] = new_node def substitute_node(self, node, op, inplace=False): """Replace an DAGOpNode with a single instruction. qargs, cargs and diff --git a/releasenotes/notes/retworkx-substitute_node_with_dag-speedup-d7d1f0d33716131d.yaml b/releasenotes/notes/retworkx-substitute_node_with_dag-speedup-d7d1f0d33716131d.yaml new file mode 100644 index 000000000000..bac8c876fd6a --- /dev/null +++ b/releasenotes/notes/retworkx-substitute_node_with_dag-speedup-d7d1f0d33716131d.yaml @@ -0,0 +1,9 @@ +--- +features: + - | + Various transpilation internals now use new features in `retworkx + `__ 0.10 when operating on the internal + circuit representation. This can often result in speedups in calls to + :obj:`~qiskit.transpile` of around 10-40%, with greater effects at higher + optimisation levels. See `#6302 + `__ for more details. diff --git a/test/python/dagcircuit/test_dagcircuit.py b/test/python/dagcircuit/test_dagcircuit.py index 5cbf2fe739fa..178c0ef74189 100644 --- a/test/python/dagcircuit/test_dagcircuit.py +++ b/test/python/dagcircuit/test_dagcircuit.py @@ -1186,14 +1186,60 @@ def test_substitute_circuit_one_middle(self): self.dag.substitute_node_with_dag(cx_node, flipped_cx_circuit, wires=[v[0], v[1]]) self.assertEqual(self.dag.count_ops()["h"], 5) + expected = DAGCircuit() + qreg = QuantumRegister(3, "qr") + creg = ClassicalRegister(2, "cr") + expected.add_qreg(qreg) + expected.add_creg(creg) + expected.apply_operation_back(HGate(), [qreg[0]], []) + expected.apply_operation_back(HGate(), [qreg[0]], []) + expected.apply_operation_back(HGate(), [qreg[1]], []) + expected.apply_operation_back(CXGate(), [qreg[1], qreg[0]], []) + expected.apply_operation_back(HGate(), [qreg[0]], []) + expected.apply_operation_back(HGate(), [qreg[1]], []) + expected.apply_operation_back(XGate(), [qreg[1]], []) + self.assertEqual(self.dag, expected) def test_substitute_circuit_one_front(self): """The method substitute_node_with_dag() replaces a leaf-in-the-front node with a DAG.""" - pass + circuit = DAGCircuit() + v = QuantumRegister(1, "v") + circuit.add_qreg(v) + circuit.apply_operation_back(HGate(), [v[0]], []) + circuit.apply_operation_back(XGate(), [v[0]], []) + + self.dag.substitute_node_with_dag(next(self.dag.topological_op_nodes()), circuit) + expected = DAGCircuit() + qreg = QuantumRegister(3, "qr") + creg = ClassicalRegister(2, "cr") + expected.add_qreg(qreg) + expected.add_creg(creg) + expected.apply_operation_back(HGate(), [qreg[0]], []) + expected.apply_operation_back(XGate(), [qreg[0]], []) + expected.apply_operation_back(CXGate(), [qreg[0], qreg[1]], []) + expected.apply_operation_back(XGate(), [qreg[1]], []) + self.assertEqual(self.dag, expected) def test_substitute_circuit_one_back(self): """The method substitute_node_with_dag() replaces a leaf-in-the-back node with a DAG.""" - pass + circuit = DAGCircuit() + v = QuantumRegister(1, "v") + circuit.add_qreg(v) + circuit.apply_operation_back(HGate(), [v[0]], []) + circuit.apply_operation_back(XGate(), [v[0]], []) + + self.dag.substitute_node_with_dag(list(self.dag.topological_op_nodes())[2], circuit) + expected = DAGCircuit() + qreg = QuantumRegister(3, "qr") + creg = ClassicalRegister(2, "cr") + expected.add_qreg(qreg) + expected.add_creg(creg) + expected.apply_operation_back(HGate(), [qreg[0]], []) + expected.apply_operation_back(CXGate(), [qreg[0], qreg[1]], []) + expected.apply_operation_back(HGate(), [qreg[1]], []) + expected.apply_operation_back(XGate(), [qreg[1]], []) + + self.assertEqual(self.dag, expected) def test_raise_if_substituting_dag_modifies_its_conditional(self): """Verify that we raise if the input dag modifies any of the bits in node.op.condition.""" diff --git a/test/python/transpiler/test_basis_translator.py b/test/python/transpiler/test_basis_translator.py index 540df7d20ae9..6236438a4d13 100644 --- a/test/python/transpiler/test_basis_translator.py +++ b/test/python/transpiler/test_basis_translator.py @@ -12,6 +12,7 @@ """Test the BasisTranslator pass""" + import os from numpy import pi