Skip to content

Commit

Permalink
feat(python): ConnectionStrategy supports the pickle module (#486)
Browse files Browse the repository at this point in the history
* feat(python): ConnectionStrategy supports the pickle module

* less constricted quil requirement

* fix missing assert, remove stale code
  • Loading branch information
MarquessV authored Jul 29, 2024
1 parent 00e927e commit 4df4ce3
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 18 deletions.
2 changes: 1 addition & 1 deletion crates/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Operating System :: OS Independent",
]
dependencies = ["quil==0.11.1"]
dependencies = ["quil>=0.11.2"]

# PEP 621 specifies the [project] table as the source for project metadata. However, Poetry only supports [tool.poetry]
# We can remove this table once this issue is resolved: https://github.com/python-poetry/poetry/issues/3332
Expand Down
68 changes: 51 additions & 17 deletions crates/python/src/qpu/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use pyo3::{
pyclass,
pyclass::CompareOp,
pyfunction, pymethods,
types::{PyComplex, PyDict, PyInt},
IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
types::{PyComplex, PyInt, PyTuple},
IntoPy, Py, PyObject, PyResult, Python, ToPyObject,
};
use qcs::qpu::api::{
ApiExecutionOptions, ApiExecutionOptionsBuilder, ConnectionStrategy, ExecutionOptions,
Expand Down Expand Up @@ -382,6 +382,7 @@ py_function_sync_async! {

py_wrap_type! {
#[derive(Debug, Default)]
#[pyo3(module = "qcs_sdk.qpu.api")]
PyExecutionOptions(ExecutionOptions) as "ExecutionOptions"
}
impl_repr!(PyExecutionOptions);
Expand Down Expand Up @@ -432,21 +433,22 @@ impl PyExecutionOptions {
}
}

fn __getstate__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
let dict = PyDict::new(py);
dict.set_item("connection_strategy", self.connection_strategy())?;
dict.set_item("timeout_seconds", self.timeout_seconds())?;
dict.set_item("api_options", self.api_options())?;
Ok(dict.into())
}

fn __setstate__(&mut self, py: Python<'_>, state: Py<PyAny>) -> PyResult<()> {
*self = Self::_from_parts(
state.getattr(py, "connection_strategy")?.extract(py)?,
state.getattr(py, "timeout_seconds")?.extract(py)?,
state.getattr(py, "api_options")?.extract(py)?,
)?;
Ok(())
fn __reduce__<'py>(&mut self, py: Python<'py>) -> PyResult<&'py PyTuple> {
let callable = py.get_type::<Self>().getattr("_from_parts")?;
Ok(PyTuple::new(
py,
[
callable,
PyTuple::new(
py,
&[
self.connection_strategy().into_py(py),
self.timeout_seconds().into_py(py),
self.api_options().into_py(py),
],
),
],
))
}

#[staticmethod]
Expand Down Expand Up @@ -575,6 +577,7 @@ impl PyApiExecutionOptionsBuilder {

py_wrap_type! {
#[derive(Default)]
#[pyo3(module = "qcs_sdk.qpu.api")]
PyConnectionStrategy(ConnectionStrategy) as "ConnectionStrategy"
}
impl_repr!(PyConnectionStrategy);
Expand Down Expand Up @@ -608,4 +611,35 @@ impl PyConnectionStrategy {
_ => py.NotImplemented(),
}
}

fn __reduce__(&self, py: Python<'_>) -> PyResult<PyObject> {
Ok(match self.as_inner() {
ConnectionStrategy::Gateway => PyTuple::new(
py,
&[
py.get_type::<Self>().getattr("gateway")?.to_object(py),
PyTuple::empty(py).to_object(py),
],
)
.to_object(py),
ConnectionStrategy::DirectAccess => PyTuple::new(
py,
&[
py.get_type::<Self>()
.getattr("direct_access")?
.to_object(py),
PyTuple::empty(py).to_object(py),
],
)
.to_object(py),
ConnectionStrategy::EndpointId(endpoint_id) => PyTuple::new(
py,
&[
py.get_type::<Self>().getattr("endpoint_id")?.to_object(py),
PyTuple::new(py, [endpoint_id]).to_object(py),
],
)
.to_object(py),
})
}
}
16 changes: 16 additions & 0 deletions crates/python/tests/qpu/test_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import pickle
import pytest

from qcs_sdk.qpu.translation import (
translate,
)

from qcs_sdk.qpu.api import (
ConnectionStrategy,
ExecutionOptions,
Register,
retrieve_results,
submit,
Expand Down Expand Up @@ -44,3 +47,16 @@ def test_submit_retrieve(

job_id = submit(program, memory, quantum_processor_id)
results = retrieve_results(job_id)

class TestPickle():
@pytest.mark.parametrize("strategy", [ConnectionStrategy.gateway(), ConnectionStrategy.direct_access(), ConnectionStrategy.endpoint_id("endpoint_id")])
def test_connection_strategy(self, strategy: ConnectionStrategy):
pickled = pickle.dumps(strategy)
unpickled = pickle.loads(pickled)
assert unpickled == strategy

def test_execution_options(self):
options = ExecutionOptions.default()
pickled = pickle.dumps(options)
unpickled = pickle.loads(pickled)
assert unpickled == options

0 comments on commit 4df4ce3

Please sign in to comment.