Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rust friendly assign parameters methods #12913

Merged
merged 11 commits into from
Aug 30, 2024
103 changes: 70 additions & 33 deletions crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,28 +953,16 @@ impl CircuitData {
sequence.py(),
array
.iter()
.map(|value| Param::Float(*value))
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, Param::Float(*value), uses)),
.map(|(value, (obj, uses))| (obj, value, uses)),
)
} else {
let values = sequence
.iter()?
.map(|ob| Param::extract_no_coerce(&ob?))
.collect::<PyResult<Vec<_>>>()?;
if values.len() != self.param_table.num_parameters() {
return Err(PyValueError::new_err(concat!(
"Mismatching number of values and parameters. For partial binding ",
"please pass a dictionary of {parameter: value} pairs."
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
self.assign_parameters_inner(
sequence.py(),
values
.into_iter()
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, value, uses)),
)
self.assign_parameters_from_slice(sequence.py(), &values)
}
}

Expand Down Expand Up @@ -1135,6 +1123,50 @@ impl CircuitData {
self.data.iter()
}

/// Assigns parameters to circuit data based on a slice of `Param`.
pub fn assign_parameters_from_slice(&mut self, py: Python, slice: &[Param]) -> PyResult<()> {
if slice.len() != self.param_table.num_parameters() {
return Err(PyValueError::new_err(concat!(
"Mismatching number of values and parameters. For partial binding ",
"please pass a mapping of {parameter: value} pairs."
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
self.assign_parameters_inner(
py,
slice
.iter()
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, value.clone_ref(py), uses)),
)
}

/// Assigns parameters to circuit data based on a mapping of `ParameterUuid` : `Param`.
/// This mapping assumes that the provided `ParameterUuid` keys are instances
/// of `ParameterExpression`.
Comment on lines +1144 to +1146
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's conventional in Rust documentation to put the assumptions that cause panics in a section called # Panics to really make them stand out - idiomatic Rust is far stricter about errors than Python, where the conventions are more like "just raise an exception". Panics are serious business.

(But also see the other comment - imo it'd be cleaner just to handle the error properly, since it won't cost us anything.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

pub fn assign_parameters_from_mapping<I, T>(&mut self, py: Python, iter: I) -> PyResult<()>
where
I: IntoIterator<Item = (ParameterUuid, T)>,
T: AsRef<Param>,
{
let mut items = Vec::new();
for (param_uuid, value) in iter {
// Assume all the Parameters are already in the circuit
let param_obj = self.get_parameter_by_uuid(param_uuid);
if let Some(param_obj) = param_obj {
// Copy or increase ref_count for Parameter, avoid acquiring the GIL.
items.push((
param_obj.clone_ref(py),
value.as_ref().clone_ref(py),
self.param_table.pop(param_uuid)?,
));
} else {
return Err(PyValueError::new_err("An invalid parameter was provided."));
}
}
self.assign_parameters_inner(py, items)
}

/// Returns an immutable view of the Interner used for Qargs
pub fn qargs_interner(&self) -> &Interner<[Qubit]> {
&self.qargs_interner
Expand Down Expand Up @@ -1170,9 +1202,10 @@ impl CircuitData {
self.cargs_interner().get(index)
}

fn assign_parameters_inner<I>(&mut self, py: Python, iter: I) -> PyResult<()>
fn assign_parameters_inner<I, T>(&mut self, py: Python, iter: I) -> PyResult<()>
where
I: IntoIterator<Item = (Py<PyAny>, Param, HashSet<ParameterUse>)>,
I: IntoIterator<Item = (Py<PyAny>, T, HashSet<ParameterUse>)>,
T: AsRef<Param> + Clone,
{
let inconsistent =
|| PyRuntimeError::new_err("internal error: circuit parameter table is inconsistent");
Expand Down Expand Up @@ -1209,7 +1242,7 @@ impl CircuitData {
for (param_ob, value, uses) in iter {
debug_assert!(!uses.is_empty());
uuids.clear();
for inner_param_ob in value.iter_parameters(py)? {
for inner_param_ob in value.as_ref().iter_parameters(py)? {
uuids.push(self.param_table.track(&inner_param_ob?, None)?)
}
for usage in uses {
Expand All @@ -1220,7 +1253,7 @@ impl CircuitData {
};
self.set_global_phase(
py,
bind_expr(expr.bind_borrowed(py), &param_ob, &value, true)?,
bind_expr(expr.bind_borrowed(py), &param_ob, value.as_ref(), true)?,
)?;
}
ParameterUse::Index {
Expand All @@ -1234,17 +1267,21 @@ impl CircuitData {
let Param::ParameterExpression(expr) = &params[parameter] else {
return Err(inconsistent());
};
params[parameter] =
match bind_expr(expr.bind_borrowed(py), &param_ob, &value, true)? {
Param::Obj(obj) => {
return Err(CircuitError::new_err(format!(
"bad type after binding for gate '{}': '{}'",
standard.name(),
obj.bind(py).repr()?,
)))
}
param => param,
};
params[parameter] = match bind_expr(
expr.bind_borrowed(py),
&param_ob,
value.as_ref(),
true,
)? {
Param::Obj(obj) => {
return Err(CircuitError::new_err(format!(
"bad type after binding for gate '{}': '{}'",
standard.name(),
obj.bind(py).repr()?,
)))
}
param => param,
};
for uuid in uuids.iter() {
self.param_table.add_use(*uuid, usage)?
}
Expand All @@ -1264,7 +1301,7 @@ impl CircuitData {
user_operations
.entry(instruction)
.or_insert_with(Vec::new)
.push((param_ob.clone_ref(py), value.clone()));
.push((param_ob.clone_ref(py), value.as_ref().clone_ref(py)));

let op = previous.unpack_py_op(py)?.into_bound(py);
let previous_param = &previous.params_view()[parameter];
Expand All @@ -1276,7 +1313,7 @@ impl CircuitData {
let new_param = bind_expr(
expr.bind_borrowed(py),
&param_ob,
&value,
value.as_ref(),
false,
)?;
// Historically, `assign_parameters` called `validate_parameter`
Expand Down Expand Up @@ -1305,7 +1342,7 @@ impl CircuitData {
Param::extract_no_coerce(
&obj.call_method(
assign_parameters_attr,
([(&param_ob, &value)].into_py_dict_bound(py),),
([(&param_ob, value.as_ref())].into_py_dict_bound(py),),
Some(
&[("inplace", false), ("flat_input", true)]
.into_py_dict_bound(py),
Expand Down
15 changes: 15 additions & 0 deletions crates/circuit/src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,21 @@ impl Param {
Param::Obj(ob.clone().unbind())
})
}

/// Clones the [Param] object safely by reference count or copying.
pub fn clone_ref(&self, py: Python) -> Self {
match self {
Param::ParameterExpression(exp) => Param::ParameterExpression(exp.clone_ref(py)),
Param::Float(float) => Param::Float(*float),
Param::Obj(obj) => Param::Obj(obj.clone_ref(py)),
}
}
}

impl AsRef<Param> for Param {
fn as_ref(&self) -> &Param {
self
}
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
}

/// Struct to provide iteration over Python-space `Parameter` instances within a `Param`.
Expand Down
Loading