Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port RemoveDiagonalGatesBeforeMeasure to rust #13065

Merged
merged 17 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub mod isometry;
pub mod nlayout;
pub mod optimize_1q_gates;
pub mod pauli_exp_val;
pub mod remove_diagonal_gates_before_measure;
pub mod results;
pub mod sabre;
pub mod sampled_exp_val;
Expand Down
103 changes: 103 additions & 0 deletions crates/accelerate/src/remove_diagonal_gates_before_measure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

/// Remove diagonal gates (including diagonal 2Q gates) before a measurement.
use pyo3::prelude::*;
use std::collections::HashSet;

use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType};
use qiskit_circuit::operations::Operation;
use qiskit_circuit::operations::StandardGate;

/// Run the RemoveDiagonalGatesBeforeMeasure pass on `dag`.
/// Args:
/// dag (DAGCircuit): the DAG to be optimized.
/// Returns:
/// DAGCircuit: the optimized DAG.
#[pyfunction]
#[pyo3(name = "remove_diagonal_gates_before_measure")]
fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> {
let diagonal_1q_gates = HashSet::from([
ShellyGarion marked this conversation as resolved.
Show resolved Hide resolved
StandardGate::RZGate,
StandardGate::ZGate,
StandardGate::TGate,
StandardGate::SGate,
StandardGate::TdgGate,
StandardGate::SdgGate,
StandardGate::U1Gate,
StandardGate::PhaseGate,
]);
let diagonal_2q_gates = HashSet::from([
StandardGate::CZGate,
StandardGate::CRZGate,
StandardGate::CU1Gate,
StandardGate::RZZGate,
StandardGate::CPhaseGate,
StandardGate::CSGate,
StandardGate::CSdgGate,
]);

Copy link
Member Author

@ShellyGarion ShellyGarion Sep 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that there are some 1-qubit and 2-qubit diagonal gates that did not appear in the original list, so I added them here (I'll add them to the tests later).

There is also a 3-qubit diagonal gate: CCZGate (which was not added in this PR)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, there is also an n-qubit diagonal gate: MCPhaseGate. This gate was not handled this PR.
This is since the algorithm given here that goes over all the successors of each of the predecessors would be O(n^2) and not O(n).

let mut nodes_to_remove = HashSet::new();
ShellyGarion marked this conversation as resolved.
Show resolved Hide resolved
for index in dag.op_nodes(true) {
let node = &dag.dag[index];
let NodeType::Operation(inst) = node else {panic!()};

if inst.op.name() == "measure" {
let predecessor = (dag.quantum_predecessors(index))
.next()
.expect("index is an operation node, so it must have a predecessor.");

match &dag.dag[predecessor] {
NodeType::Operation(pred_inst) => match pred_inst.standard_gate() {
Some(gate) => {
if diagonal_1q_gates.contains(&gate) {
nodes_to_remove.insert(predecessor);
} else if diagonal_2q_gates.contains(&gate) {
let successors = dag.quantum_successors(predecessor);
let remove_s = successors
.map(|s| {
let node_s = &dag.dag[s];
if let NodeType::Operation(inst_s) = node_s {
inst_s.op.name() == "measure"
} else {
false
}
})
.all(|ok_to_remove| ok_to_remove);
if remove_s {
nodes_to_remove.insert(predecessor);
}
}
}
None => {
continue;
}
},
_ => {
continue;
}
}
}
}

for node_to_remove in nodes_to_remove {
dag.remove_op_node(node_to_remove)
}

Ok(())
}

#[pymodule]
pub fn remove_diagonal_gates_before_measure(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(run_remove_diagonal_before_measure))?;
Ok(())
}
4 changes: 2 additions & 2 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5115,7 +5115,7 @@ impl DAGCircuit {
}
}

fn quantum_predecessors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
pub fn quantum_predecessors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
self.dag
.edges_directed(node, Incoming)
.filter_map(|e| match e.weight() {
Expand All @@ -5125,7 +5125,7 @@ impl DAGCircuit {
.unique()
}

fn quantum_successors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
pub fn quantum_successors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
self.dag
.edges_directed(node, Outgoing)
.filter_map(|e| match e.weight() {
Expand Down
8 changes: 7 additions & 1 deletion crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use qiskit_accelerate::{
circuit_library::circuit_library, convert_2q_block_matrix::convert_2q_block_matrix,
dense_layout::dense_layout, error_map::error_map,
euler_one_qubit_decomposer::euler_one_qubit_decomposer, isometry::isometry, nlayout::nlayout,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval, results::results,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval,
remove_diagonal_gates_before_measure::remove_diagonal_gates_before_measure, results::results,
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis,
target_transpiler::target, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate,
Expand Down Expand Up @@ -49,6 +50,11 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(m, optimize_1q_gates, "optimize_1q_gates")?;
add_submodule(m, pauli_expval, "pauli_expval")?;
add_submodule(m, synthesis, "synthesis")?;
add_submodule(
m,
remove_diagonal_gates_before_measure,
"remove_diagonal_gates_before_measure",
)?;
add_submodule(m, results, "results")?;
add_submodule(m, sabre, "sabre")?;
add_submodule(m, sampled_exp_val, "sampled_exp_val")?;
Expand Down
3 changes: 3 additions & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
sys.modules["qiskit._accelerate.pauli_expval"] = _accelerate.pauli_expval
sys.modules["qiskit._accelerate.qasm2"] = _accelerate.qasm2
sys.modules["qiskit._accelerate.qasm3"] = _accelerate.qasm3
sys.modules["qiskit._accelerate.remove_diagonal_gates_before_measure"] = (
_accelerate.remove_diagonal_gates_before_measure
)
sys.modules["qiskit._accelerate.results"] = _accelerate.results
sys.modules["qiskit._accelerate.sabre"] = _accelerate.sabre
sys.modules["qiskit._accelerate.sampled_exp_val"] = _accelerate.sampled_exp_val
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,13 @@

"""Remove diagonal gates (including diagonal 2Q gates) before a measurement."""

from qiskit.circuit import Measure
from qiskit.circuit.library.standard_gates import (
RZGate,
ZGate,
TGate,
SGate,
TdgGate,
SdgGate,
U1Gate,
CZGate,
CRZGate,
CU1Gate,
RZZGate,
)
from qiskit.dagcircuit import DAGOpNode
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.passes.utils import control_flow

from qiskit._accelerate.remove_diagonal_gates_before_measure import (
remove_diagonal_gates_before_measure,
)


class RemoveDiagonalGatesBeforeMeasure(TransformationPass):
"""Remove diagonal gates (including diagonal 2Q gates) before a measurement.
Expand All @@ -48,22 +37,5 @@ def run(self, dag):
Returns:
DAGCircuit: the optimized DAG.
"""
diagonal_1q_gates = (RZGate, ZGate, TGate, SGate, TdgGate, SdgGate, U1Gate)
diagonal_2q_gates = (CZGate, CRZGate, CU1Gate, RZZGate)

nodes_to_remove = set()
for measure in dag.op_nodes(Measure):
predecessor = next(dag.quantum_predecessors(measure))

if isinstance(predecessor, DAGOpNode) and isinstance(predecessor.op, diagonal_1q_gates):
nodes_to_remove.add(predecessor)

if isinstance(predecessor, DAGOpNode) and isinstance(predecessor.op, diagonal_2q_gates):
successors = dag.quantum_successors(predecessor)
if all(isinstance(s, DAGOpNode) and isinstance(s.op, Measure) for s in successors):
nodes_to_remove.add(predecessor)

for node_to_remove in nodes_to_remove:
dag.remove_op_node(node_to_remove)

remove_diagonal_gates_before_measure(dag)
ShellyGarion marked this conversation as resolved.
Show resolved Hide resolved
return dag
Loading