diff --git a/crates/circuit/src/circuit_data.rs b/crates/circuit/src/circuit_data.rs index dd2618f03f9a..dc83c9dd2487 100644 --- a/crates/circuit/src/circuit_data.rs +++ b/crates/circuit/src/circuit_data.rs @@ -343,7 +343,7 @@ impl CircuitData { /// Get a (cached) sorted list of the Python-space `Parameter` instances tracked by this circuit /// data's parameter table. #[getter] - pub fn get_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> { + pub fn get_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> { self.param_table.py_parameters(py) } diff --git a/crates/circuit/src/parameter_table.rs b/crates/circuit/src/parameter_table.rs index 8825fbd71772..e637d4968a19 100644 --- a/crates/circuit/src/parameter_table.rs +++ b/crates/circuit/src/parameter_table.rs @@ -10,6 +10,8 @@ // copyright notice, and modified files need to carry a notice indicating // that they have been altered from the originals. +use std::cell::RefCell; + use hashbrown::hash_map::Entry; use hashbrown::{HashMap, HashSet}; use thiserror::Error; @@ -114,6 +116,30 @@ impl<'py> FromPyObject<'py> for VectorUuid { } } +#[derive(Clone, Default, Debug)] +struct ParameterTableOrder { + /// The Rust-space sort order. + uuids: Vec, + /// Cache of a Python-space list of the parameter objects, in order. We only generate this + /// specifically when asked. + py_parameters: Option>, +} + +impl ParameterTableOrder { + fn uuids(&self) -> Option<&[ParameterUuid]> { + (!self.uuids.is_empty()).then_some(self.uuids.as_slice()) + } + + fn py_parameters(&self) -> Option<&Py> { + self.py_parameters.as_ref() + } + + fn invalidate(&mut self) { + self.uuids.clear(); + self.py_parameters = None; + } +} + #[derive(Clone, Default, Debug)] pub struct ParameterTable { /// Mapping of the parameter key (its UUID) to the information on it tracked by this table. @@ -123,18 +149,14 @@ pub struct ParameterTable { by_name: HashMap, /// Additional information on any `ParameterVector` instances that have elements in the circuit. vectors: HashMap, - /// Sort order of the parameters. This is lexicographical for most parameters, except elements - /// of a `ParameterVector` are sorted within the vector by numerical index. We calculate this - /// on demand and cache it; an empty `order` implies it is not currently calculated. We don't - /// use `Option` so we can re-use the allocation for partial parameter bindings. - /// - /// Any method that adds or a removes a parameter is responsible for invalidating this cache. - order: Vec, - /// Cache of a Python-space list of the parameter objects, in order. We only generate this - /// specifically when asked. + /// Cache related to the sort order of the parameters. This is lexicographical for most + /// parameters, except elements of a `ParameterVector` are sorted within the vector by numerical + /// index. We calculate this on demand and cache it; an empty `order` implies it is not + /// currently calculated. We don't use `Option` so we can re-use the allocation for + /// partial parameter bindings. /// - /// Any method that adds or a removes a parameter is responsible for invalidating this cache. - py_parameters: Option>, + /// Any method that adds or removes a parameter needs to invalidate this. + order_cache: RefCell, } impl ParameterTable { @@ -194,8 +216,7 @@ impl ParameterTable { None }; self.by_name.insert(name.clone(), uuid); - self.order.clear(); - self.py_parameters = None; + self.order_cache.borrow_mut().invalidate(); let mut uses = HashSet::new(); if let Some(usage) = usage { uses.insert_unique_unchecked(usage); @@ -226,18 +247,32 @@ impl ParameterTable { } /// Get the (maybe cached) Python list of the sorted `Parameter` objects. - pub fn py_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> { - if let Some(py_parameters) = self.py_parameters.as_ref() { + pub fn py_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> { + if let Some(py_parameters) = self.order_cache.borrow().py_parameters() { return py_parameters.clone_ref(py).into_bound(py); } - self.ensure_sorted(); - let out = PyList::new_bound( - py, - self.order - .iter() - .map(|uuid| self.by_uuid[uuid].object.clone_ref(py).into_bound(py)), - ); - self.py_parameters = Some(out.clone().unbind()); + let make_parameters = |order: &[ParameterUuid]| { + PyList::new_bound( + py, + order + .iter() + .map(|uuid| self.by_uuid[uuid].object.bind(py).clone()), + ) + }; + let out = match self.order_cache.borrow().uuids() { + Some(uuids) => make_parameters(uuids), + None => { + let uuids = self.sorted_order(); + let out = make_parameters(&uuids); + if let Ok(mut cache) = self.order_cache.try_borrow_mut() { + cache.uuids = uuids; + } + out + } + }; + if let Ok(mut cache) = self.order_cache.try_borrow_mut() { + cache.py_parameters = Some(out.clone().unbind()); + } out } @@ -246,23 +281,18 @@ impl ParameterTable { PySet::new_bound(py, self.by_uuid.values().map(|info| &info.object)) } - /// Ensure that the `order` field is populated and sorted. - fn ensure_sorted(&mut self) { - // If `order` is already populated, it's sorted; it's the responsibility of the methods of - // this struct that mutate it to invalidate the cache. - if !self.order.is_empty() { - return; - } - self.order.reserve(self.by_uuid.len()); - self.order.extend(self.by_uuid.keys()); - self.order.sort_unstable_by_key(|uuid| { + /// Get the sorted order of the `ParameterTable`. This does not access the cache. + fn sorted_order(&self) -> Vec { + let mut out = self.by_uuid.keys().copied().collect::>(); + out.sort_unstable_by_key(|uuid| { let info = &self.by_uuid[uuid]; if let Some(vec) = info.element.as_ref() { (&self.vectors[&vec.vector_uuid].name, vec.index) } else { (&info.name, 0) } - }) + }); + out } /// Add a use of a parameter to the table. @@ -305,8 +335,7 @@ impl ParameterTable { vec_entry.remove_entry(); } } - self.order.clear(); - self.py_parameters = None; + self.order_cache.borrow_mut().invalidate(); entry.remove_entry(); } Ok(()) @@ -332,26 +361,30 @@ impl ParameterTable { (vector_info.refcount > 0).then_some(vector_info) }); } - self.order.clear(); - self.py_parameters = None; + self.order_cache.borrow_mut().invalidate(); Ok(info.uses) } /// Clear this table, yielding the Python parameter objects and their uses in sorted order. + /// + /// The clearing effect is eager and not dependent on the iteration. pub fn drain_ordered( - &'_ mut self, - ) -> impl Iterator, HashSet)> + '_ { - self.ensure_sorted(); + &mut self, + ) -> impl ExactSizeIterator, HashSet)> { + let mut cache = self.order_cache.borrow_mut(); + cache.py_parameters = None; + let order = if cache.uuids.is_empty() { + self.sorted_order() + } else { + ::std::mem::take(&mut cache.uuids) + }; + let by_uuid = ::std::mem::take(&mut self.by_uuid); self.by_name.clear(); self.vectors.clear(); - self.py_parameters = None; - self.order.drain(..).map(|uuid| { - let info = self - .by_uuid - .remove(&uuid) - .expect("tracked UUIDs should be consistent"); - (info.object, info.uses) - }) + ParameterTableDrain { + order: order.into_iter(), + by_uuid, + } } /// Empty this `ParameterTable` of all its contents. This does not affect the capacities of the @@ -360,8 +393,7 @@ impl ParameterTable { self.by_uuid.clear(); self.by_name.clear(); self.vectors.clear(); - self.order.clear(); - self.py_parameters = None; + self.order_cache.borrow_mut().invalidate(); } /// Expose the tracked data for a given parameter as directly as possible to Python space. @@ -396,9 +428,33 @@ impl ParameterTable { visit.call(&info.object)? } // We don't need to / can't visit the `PyBackedStr` stores. - if let Some(list) = self.py_parameters.as_ref() { + if let Some(list) = self.order_cache.borrow().py_parameters() { visit.call(list)? } Ok(()) } } + +struct ParameterTableDrain { + order: ::std::vec::IntoIter, + by_uuid: HashMap, +} +impl Iterator for ParameterTableDrain { + type Item = (Py, HashSet); + + fn next(&mut self) -> Option { + self.order.next().map(|uuid| { + let info = self + .by_uuid + .remove(&uuid) + .expect("tracked UUIDs should be consistent"); + (info.object, info.uses) + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.order.size_hint() + } +} +impl ExactSizeIterator for ParameterTableDrain {} +impl ::std::iter::FusedIterator for ParameterTableDrain {}