Skip to content

Commit

Permalink
Fix: Remove IndexMap and duplicate declarations.
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss committed Aug 29, 2024
1 parent 3f2f64c commit b7a40d7
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions crates/accelerate/src/basis/basis_translator/basis_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,33 @@

use std::cell::RefCell;

use crate::equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData};
use ahash::RandomState;
use indexmap::{IndexMap, IndexSet};
use hashbrown::{HashMap, HashSet};
use pyo3::prelude::*;
use pyo3::types::PySet;
use qiskit_circuit::operations::{Operation, Param};
use rustworkx_core::petgraph::stable_graph::{EdgeReference, NodeIndex, StableDiGraph};
use rustworkx_core::petgraph::visit::Control;
use rustworkx_core::traversal::{dijkstra_search, DijkstraEvent};
use smallvec::SmallVec;

use crate::equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData};

#[pyfunction]
#[pyo3(name = "basis_search")]
pub(crate) fn py_basis_search(
py: Python,
equiv_lib: &mut EquivalenceLibrary,
source_basis: &Bound<PySet>,
target_basis: &Bound<PySet>,
) -> PyResult<PyObject> {
let new_source_basis: PyResult<IndexSet<(String, u32), RandomState>> =
source_basis.iter().map(|item| item.extract()).collect();
let new_target_basis: PyResult<IndexSet<String, RandomState>> =
target_basis.iter().map(|item| item.extract()).collect();
Ok(basis_search(equiv_lib, new_source_basis?, new_target_basis?).into_py(source_basis.py()))
source_basis: HashSet<(String, u32)>,
target_basis: HashSet<String>,
) -> PyObject {
basis_search(
equiv_lib,
source_basis
.iter()
.map(|(name, num_qubits)| (name.as_str(), *num_qubits))
.collect(),
target_basis.iter().map(|name| name.as_str()).collect(),
)
.into_py(py)
}

type BasisTransforms = Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)>;
Expand All @@ -51,24 +55,20 @@ type BasisTransforms = Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)>;
/// was found.
pub(crate) fn basis_search(
equiv_lib: &mut EquivalenceLibrary,
source_basis: IndexSet<(String, u32), RandomState>,
target_basis: IndexSet<String, RandomState>,
source_basis: HashSet<(&str, u32)>,
target_basis: HashSet<&str>,
) -> Option<BasisTransforms> {
// Build the visitor attributes:
let mut num_gates_remaining_for_rule: IndexMap<usize, usize, RandomState> = IndexMap::default();
let predecessors: IndexMap<(&str, u32), Equivalence, RandomState> = IndexMap::default();
let predecessors_cell: RefCell<IndexMap<(&str, u32), Equivalence, RandomState>> =
RefCell::new(predecessors);
let opt_cost_map: IndexMap<(&str, u32), u32, RandomState> = IndexMap::default();
let opt_cost_map_cell: RefCell<IndexMap<(&str, u32), u32, RandomState>> =
RefCell::new(opt_cost_map);
let mut num_gates_remaining_for_rule: HashMap<usize, usize> = HashMap::default();
let predecessors: RefCell<HashMap<(&str, u32), Equivalence>> = RefCell::new(HashMap::default());
let opt_cost_map: RefCell<HashMap<(&str, u32), u32>> = RefCell::new(HashMap::default());
let mut basis_transforms: Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)> = vec![];

// Initialize visitor attributes:
initialize_num_gates_remain_for_rule(&equiv_lib.graph, &mut num_gates_remaining_for_rule);

// TODO: Logs
let mut source_basis_remain: IndexSet<Key> = source_basis
let mut source_basis_remain: HashSet<Key> = source_basis
.iter()
.filter_map(|(gate_name, gate_num_qubits)| {
if !target_basis.contains(gate_name) {
Expand All @@ -92,7 +92,7 @@ pub(crate) fn basis_search(
let target_basis_keys: Vec<Key> = equiv_lib
.keys()
.cloned()
.filter(|key| target_basis.contains(&key.name))
.filter(|key| target_basis.contains(key.name.as_str()))
.collect();

// Dummy node is inserted in the graph. Which is where the search will start
Expand Down Expand Up @@ -123,7 +123,7 @@ pub(crate) fn basis_search(
}
let edge_data = edge.weight().as_ref().unwrap();
let mut cost_tot = 0;
let borrowed_cost = opt_cost_map_cell.borrow();
let borrowed_cost = opt_cost_map.borrow();
for instruction in edge_data.rule.circuit.data.iter() {
cost_tot += borrowed_cost[&(instruction.op.name(), instruction.op.num_qubits())];
}
Expand All @@ -140,13 +140,13 @@ pub(crate) fn basis_search(
DijkstraEvent::Discover(n, score) => {
let gate_key = &equiv_lib.graph[n].key;
let gate = &(gate_key.name.as_str(), gate_key.num_qubits);
source_basis_remain.swap_remove(gate_key);
let mut borrowed_cost_map = opt_cost_map_cell.borrow_mut();
source_basis_remain.remove(gate_key);
let mut borrowed_cost_map = opt_cost_map.borrow_mut();
borrowed_cost_map
.entry(*gate)
.and_modify(|cost_ref| *cost_ref = score)
.or_insert(score);
if let Some(rule) = predecessors_cell.borrow().get(gate) {
if let Some(rule) = predecessors.borrow().get(gate) {
// TODO: Logger
basis_transforms.push((
gate_key.name.to_string(),
Expand All @@ -163,7 +163,7 @@ pub(crate) fn basis_search(
}
DijkstraEvent::EdgeRelaxed(_, target, Some(edata)) => {
let gate = &equiv_lib.graph[target].key;
predecessors_cell
predecessors
.borrow_mut()
.entry((gate.name.as_str(), gate.num_qubits))
.and_modify(|value| *value = edata.rule.clone())
Expand Down Expand Up @@ -200,7 +200,7 @@ pub(crate) fn basis_search(

fn initialize_num_gates_remain_for_rule(
graph: &StableDiGraph<NodeData, Option<EdgeData>>,
source: &mut IndexMap<usize, usize, RandomState>,
source: &mut HashMap<usize, usize>,
) {
let mut save_index = usize::MAX;
for edge_data in graph.edge_weights().flatten() {
Expand Down

0 comments on commit b7a40d7

Please sign in to comment.