Skip to content

Commit

Permalink
Merge pull request #20 from mtreinish/collect_single_qubit_gate_runs_…
Browse files Browse the repository at this point in the history
…for_kevins_pr

Add implementation of collect_runs() and collect_1q_runs()
  • Loading branch information
kevinhartman authored Jul 9, 2024
2 parents d5a33f7 + 58c3d8f commit 1a535ff
Showing 1 changed file with 89 additions and 25 deletions.
114 changes: 89 additions & 25 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3042,34 +3042,59 @@ def _format(operand):
/// in the circuit's basis.
///
/// Nodes must have only one successor to continue the run.
fn collect_runs(&self, namelist: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
// def filter_fn(node):
// return (
// isinstance(node, DAGOpNode)
// and node.op.name in namelist
// and getattr(node.op, "condition", None) is None
// )
//
// group_list = rx.collect_runs(self._multi_graph, filter_fn)
// return {tuple(x) for x in group_list}
todo!()
#[pyo3(name = "collect_runs")]
fn py_collect_runs(&self, py: Python, namelist: &Bound<PyList>) -> PyResult<Py<PySet>> {
let mut name_list_set = HashSet::with_capacity(namelist.len());
for name in namelist.iter() {
name_list_set.insert(name.extract::<String>()?);
}
match self.collect_runs(name_list_set) {
Some(runs) => {
let run_iter = runs.map(|node_indices| {
PyTuple::new_bound(
py,
node_indices
.into_iter()
.map(|node_index| self.get_node(py, node_index).unwrap()),
)
.unbind()
});
let out_set = PySet::empty_bound(py)?;
for run_tuple in run_iter {
out_set.add(run_tuple)?;
}
Ok(out_set.unbind())
}
None => Err(PyRuntimeError::new_err(
"Invalid DAGCircuit, cycle encountered",
)),
}
}

/// Return a set of non-conditional runs of 1q "op" nodes.
fn collect_1q_runs(&self) -> PyResult<Py<PyList>> {
// def filter_fn(node):
// return (
// isinstance(node, DAGOpNode)
// and len(node.qargs) == 1
// and len(node.cargs) == 0
// and isinstance(node.op, Gate)
// and hasattr(node.op, "__array__")
// and getattr(node.op, "condition", None) is None
// and not node.op.is_parameterized()
// )
//
// return rx.collect_runs(self._multi_graph, filter_fn)
todo!()
#[pyo3(name = "collect_1q_runs")]
fn py_collect_1q_runs(&self, py: Python) -> PyResult<Py<PyList>> {
match self.collect_1q_runs() {
Some(runs) => {
let runs_iter = runs.map(|node_indices| {
PyList::new_bound(
py,
node_indices
.into_iter()
.map(|node_index| self.get_node(py, node_index).unwrap()),
)
.unbind()
});
let out_list = PyList::empty_bound(py);
for run_list in runs_iter {
out_list.append(run_list)?;
}
Ok(out_list.unbind())
}
None => Err(PyRuntimeError::new_err(
"Invalid DAGCircuit, cycle encountered",
)),
}
}

/// Return a set of non-conditional runs of 2q "op" nodes.
Expand Down Expand Up @@ -3319,6 +3344,45 @@ def _format(operand):
}

impl DAGCircuit {
/// Return an iterator of gate runs with non-conditional op nodes of given names
pub fn collect_runs(
&self,
namelist: HashSet<String>,
) -> Option<impl Iterator<Item = Vec<NodeIndex>> + '_> {
let filter_fn = move |node_index: NodeIndex| -> Result<bool, Infallible> {
let node = &self.dag[node_index];
match node {
NodeType::Operation(inst) => Ok(namelist.contains(inst.op.name())
&& match &inst.extra_attrs {
None => true,
Some(attrs) => attrs.condition.is_none(),
}),
_ => Ok(false),
}
};
rustworkx_core::dag_algo::collect_runs(&self.dag, filter_fn)
.map(|node_iter| node_iter.map(|x| x.unwrap()))
}

/// Return a set of non-conditional runs of 1q "op" nodes.
pub fn collect_1q_runs(&self) -> Option<impl Iterator<Item = Vec<NodeIndex>> + '_> {
let filter_fn = move |node_index: NodeIndex| -> Result<bool, Infallible> {
let node = &self.dag[node_index];
match node {
NodeType::Operation(inst) => Ok(inst.op.num_qubits() == 1
&& inst.op.num_clbits() == 0
&& inst.op.matrix(&inst.params).is_some()
&& match &inst.extra_attrs {
None => true,
Some(attrs) => attrs.condition.is_none(),
}),
_ => Ok(false),
}
};
rustworkx_core::dag_algo::collect_runs(&self.dag, filter_fn)
.map(|node_iter| node_iter.map(|x| x.unwrap()))
}

fn increment_op(&mut self, op: String) {
match self.op_names.entry(op) {
hash_map::Entry::Occupied(mut o) => {
Expand Down

0 comments on commit 1a535ff

Please sign in to comment.