diff --git a/crates/accelerate/src/basis/basis_translator/basis_search.rs b/crates/accelerate/src/basis/basis_translator/basis_search.rs index 5b87cd68f429..815d8d460c73 100644 --- a/crates/accelerate/src/basis/basis_translator/basis_search.rs +++ b/crates/accelerate/src/basis/basis_translator/basis_search.rs @@ -12,29 +12,33 @@ use std::cell::RefCell; -use crate::equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData}; -use ahash::RandomState; -use indexmap::{IndexMap, IndexSet}; +use hashbrown::{HashMap, HashSet}; use pyo3::prelude::*; -use pyo3::types::PySet; use qiskit_circuit::operations::{Operation, Param}; use rustworkx_core::petgraph::stable_graph::{EdgeReference, NodeIndex, StableDiGraph}; use rustworkx_core::petgraph::visit::Control; use rustworkx_core::traversal::{dijkstra_search, DijkstraEvent}; use smallvec::SmallVec; +use crate::equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData}; + #[pyfunction] #[pyo3(name = "basis_search")] pub(crate) fn py_basis_search( + py: Python, equiv_lib: &mut EquivalenceLibrary, - source_basis: &Bound, - target_basis: &Bound, -) -> PyResult { - let new_source_basis: PyResult> = - source_basis.iter().map(|item| item.extract()).collect(); - let new_target_basis: PyResult> = - target_basis.iter().map(|item| item.extract()).collect(); - Ok(basis_search(equiv_lib, new_source_basis?, new_target_basis?).into_py(source_basis.py())) + source_basis: HashSet<(String, u32)>, + target_basis: HashSet, +) -> PyObject { + basis_search( + equiv_lib, + source_basis + .iter() + .map(|(name, num_qubits)| (name.as_str(), *num_qubits)) + .collect(), + target_basis.iter().map(|name| name.as_str()).collect(), + ) + .into_py(py) } type BasisTransforms = Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)>; @@ -51,24 +55,20 @@ type BasisTransforms = Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)>; /// was found. pub(crate) fn basis_search( equiv_lib: &mut EquivalenceLibrary, - source_basis: IndexSet<(String, u32), RandomState>, - target_basis: IndexSet, + source_basis: HashSet<(&str, u32)>, + target_basis: HashSet<&str>, ) -> Option { // Build the visitor attributes: - let mut num_gates_remaining_for_rule: IndexMap = IndexMap::default(); - let predecessors: IndexMap<(&str, u32), Equivalence, RandomState> = IndexMap::default(); - let predecessors_cell: RefCell> = - RefCell::new(predecessors); - let opt_cost_map: IndexMap<(&str, u32), u32, RandomState> = IndexMap::default(); - let opt_cost_map_cell: RefCell> = - RefCell::new(opt_cost_map); + let mut num_gates_remaining_for_rule: HashMap = HashMap::default(); + let predecessors: RefCell> = RefCell::new(HashMap::default()); + let opt_cost_map: RefCell> = RefCell::new(HashMap::default()); let mut basis_transforms: Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)> = vec![]; // Initialize visitor attributes: initialize_num_gates_remain_for_rule(&equiv_lib.graph, &mut num_gates_remaining_for_rule); // TODO: Logs - let mut source_basis_remain: IndexSet = source_basis + let mut source_basis_remain: HashSet = source_basis .iter() .filter_map(|(gate_name, gate_num_qubits)| { if !target_basis.contains(gate_name) { @@ -92,7 +92,7 @@ pub(crate) fn basis_search( let target_basis_keys: Vec = equiv_lib .keys() .cloned() - .filter(|key| target_basis.contains(&key.name)) + .filter(|key| target_basis.contains(key.name.as_str())) .collect(); // Dummy node is inserted in the graph. Which is where the search will start @@ -123,7 +123,7 @@ pub(crate) fn basis_search( } let edge_data = edge.weight().as_ref().unwrap(); let mut cost_tot = 0; - let borrowed_cost = opt_cost_map_cell.borrow(); + let borrowed_cost = opt_cost_map.borrow(); for instruction in edge_data.rule.circuit.data.iter() { cost_tot += borrowed_cost[&(instruction.op.name(), instruction.op.num_qubits())]; } @@ -140,13 +140,13 @@ pub(crate) fn basis_search( DijkstraEvent::Discover(n, score) => { let gate_key = &equiv_lib.graph[n].key; let gate = &(gate_key.name.as_str(), gate_key.num_qubits); - source_basis_remain.swap_remove(gate_key); - let mut borrowed_cost_map = opt_cost_map_cell.borrow_mut(); + source_basis_remain.remove(gate_key); + let mut borrowed_cost_map = opt_cost_map.borrow_mut(); borrowed_cost_map .entry(*gate) .and_modify(|cost_ref| *cost_ref = score) .or_insert(score); - if let Some(rule) = predecessors_cell.borrow().get(gate) { + if let Some(rule) = predecessors.borrow().get(gate) { // TODO: Logger basis_transforms.push(( gate_key.name.to_string(), @@ -163,7 +163,7 @@ pub(crate) fn basis_search( } DijkstraEvent::EdgeRelaxed(_, target, Some(edata)) => { let gate = &equiv_lib.graph[target].key; - predecessors_cell + predecessors .borrow_mut() .entry((gate.name.as_str(), gate.num_qubits)) .and_modify(|value| *value = edata.rule.clone()) @@ -200,7 +200,7 @@ pub(crate) fn basis_search( fn initialize_num_gates_remain_for_rule( graph: &StableDiGraph>, - source: &mut IndexMap, + source: &mut HashMap, ) { let mut save_index = usize::MAX; for edge_data in graph.edge_weights().flatten() {