Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rust friendly assign parameters methods #12913

Merged
merged 11 commits into from
Aug 30, 2024
89 changes: 60 additions & 29 deletions crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -949,45 +949,42 @@ impl CircuitData {
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
let owned_iter: Vec<Param> = array.iter().map(|value| Param::Float(*value)).collect();
self.assign_parameters_inner(
sequence.py(),
array
owned_iter
.iter()
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, Param::Float(*value), uses)),
.map(|(value, (obj, uses))| (obj, value, uses)),
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
)
} else {
let values = sequence
.iter()?
.map(|ob| Param::extract_no_coerce(&ob?))
.collect::<PyResult<Vec<_>>>()?;
if values.len() != self.param_table.num_parameters() {
return Err(PyValueError::new_err(concat!(
"Mismatching number of values and parameters. For partial binding ",
"please pass a dictionary of {parameter: value} pairs."
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
self.assign_parameters_inner(
sequence.py(),
values
.into_iter()
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, value, uses)),
)
self.assign_parameters_from_slice(sequence.py(), &values)
}
}

/// Assign all uses of the circuit parameters as keys `mapping` to their corresponding values.
fn assign_parameters_mapping(&mut self, mapping: Bound<PyAny>) -> PyResult<()> {
let py = mapping.py();
let mut items = Vec::new();
let mut objs = Vec::new();
for item in mapping.call_method0("items")?.iter()? {
let (param_ob, value) = item?.extract::<(Py<PyAny>, AssignParam)>()?;
let uuid = ParameterUuid::from_parameter(param_ob.bind(py))?;
items.push((param_ob, value.0, self.param_table.pop(uuid)?));
items.push(value);
objs.push(param_ob); // We need to separate the objects to avoid cloning.
}
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
self.assign_parameters_inner(py, items)
let borrowed_iterator: PyResult<Vec<_>> = items
.iter()
.zip(objs.into_iter())
.map(|(value, param_obj)| -> PyResult<_> {
let uuid = ParameterUuid::from_parameter(param_obj.bind(py))?;
Ok((param_obj, &value.0, self.param_table.pop(uuid)?))
})
.collect();
self.assign_parameters_inner(py, borrowed_iterator?)
}

pub fn clear(&mut self) {
Expand Down Expand Up @@ -1135,6 +1132,44 @@ impl CircuitData {
self.data.iter()
}

/// Assigns parameters to circuit data based on a slice of `Param`.
pub fn assign_parameters_from_slice(&mut self, py: Python, slice: &[Param]) -> PyResult<()> {
if slice.len() != self.param_table.num_parameters() {
return Err(PyValueError::new_err(concat!(
"Mismatching number of values and parameters. For partial binding ",
"please pass a dictionary of {parameter: value} pairs."
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
self.assign_parameters_inner(
py,
slice
.iter()
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, value, uses)),
)
}

/// Assigns parameters to circuit data based on a mapping of `ParameterUuid` : `Param`.
/// This mapping assumes that the provided `ParameterUuid` keys are instances
/// of `ParameterExpression`.
Comment on lines +1144 to +1146
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's conventional in Rust documentation to put the assumptions that cause panics in a section called # Panics to really make them stand out - idiomatic Rust is far stricter about errors than Python, where the conventions are more like "just raise an exception". Panics are serious business.

(But also see the other comment - imo it'd be cleaner just to handle the error properly, since it won't cost us anything.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

pub fn assign_parameters_from_mapping<'a, I>(&mut self, py: Python, iter: I) -> PyResult<()>
where
I: IntoIterator<Item = (ParameterUuid, &'a Param)>,
{
let mut items = Vec::new();
for (param_uuid, value) in iter {
// Assume all the Parameters are already in the circuit
let param_obj = self.get_parameter_by_uuid(param_uuid);
items.push((
param_obj.unwrap().clone_ref(py),
value,
self.param_table.pop(param_uuid)?,
));
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
}
self.assign_parameters_inner(py, items)
}

/// Returns an immutable view of the Interner used for Qargs
pub fn qargs_interner(&self) -> &Interner<[Qubit]> {
&self.qargs_interner
Expand Down Expand Up @@ -1170,9 +1205,9 @@ impl CircuitData {
self.cargs_interner().get(index)
}

fn assign_parameters_inner<I>(&mut self, py: Python, iter: I) -> PyResult<()>
fn assign_parameters_inner<'a, I>(&mut self, py: Python, iter: I) -> PyResult<()>
where
I: IntoIterator<Item = (Py<PyAny>, Param, HashSet<ParameterUse>)>,
I: IntoIterator<Item = (Py<PyAny>, &'a Param, HashSet<ParameterUse>)>,
{
let inconsistent =
|| PyRuntimeError::new_err("internal error: circuit parameter table is inconsistent");
Expand Down Expand Up @@ -1220,7 +1255,7 @@ impl CircuitData {
};
self.set_global_phase(
py,
bind_expr(expr.bind_borrowed(py), &param_ob, &value, true)?,
bind_expr(expr.bind_borrowed(py), &param_ob, value, true)?,
)?;
}
ParameterUse::Index {
Expand All @@ -1235,7 +1270,7 @@ impl CircuitData {
return Err(inconsistent());
};
params[parameter] =
match bind_expr(expr.bind_borrowed(py), &param_ob, &value, true)? {
match bind_expr(expr.bind_borrowed(py), &param_ob, value, true)? {
Param::Obj(obj) => {
return Err(CircuitError::new_err(format!(
"bad type after binding for gate '{}': '{}'",
Expand Down Expand Up @@ -1273,12 +1308,8 @@ impl CircuitData {
Param::ParameterExpression(expr) => {
// For user gates, we don't coerce floats to integers in `Param`
// so that users can use them if they choose.
let new_param = bind_expr(
expr.bind_borrowed(py),
&param_ob,
&value,
false,
)?;
let new_param =
bind_expr(expr.bind_borrowed(py), &param_ob, value, false)?;
// Historically, `assign_parameters` called `validate_parameter`
// only when a `ParameterExpression` became fully bound. Some
// "generalised" (or user) gates fail without this, though
Expand Down
Loading