Skip to content

Commit

Permalink
Merge 7ddad4b into aa37b5d
Browse files Browse the repository at this point in the history
  • Loading branch information
swernli authored Mar 14, 2024
2 parents aa37b5d + 7ddad4b commit 5441bdc
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 43 deletions.
5 changes: 4 additions & 1 deletion pip/qsharp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
set_quantum_seed,
set_classical_seed,
dump_machine,
StateDump,
ShotResult,
)

from ._native import Result, Pauli, QSharpError, TargetProfile, StateDump
from ._native import Result, Pauli, QSharpError, TargetProfile

# IPython notebook specific features
try:
Expand All @@ -39,4 +41,5 @@
"QSharpError",
"TargetProfile",
"StateDump",
"ShotResult",
]
10 changes: 4 additions & 6 deletions pip/qsharp/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class Interpreter:
the seed will be generated from entropy.
"""
...
def dump_machine(self) -> StateDump:
def dump_machine(self) -> StateDumpData:
"""
Returns the sparse state vector of the simulator as a StateDump object.
Expand Down Expand Up @@ -141,9 +141,9 @@ class Output:
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def _repr_html_(self) -> str: ...
def state_dump(self) -> Optional[StateDump]: ...
def state_dump(self) -> Optional[StateDumpData]: ...

class StateDump:
class StateDumpData:
"""
A state dump returned from the Q# interpreter.
"""
Expand All @@ -155,11 +155,9 @@ class StateDump:

"""
Get the amplitudes of the state vector as a dictionary from state integer to
pair of real and imaginary amplitudes.
complex amplitudes.
"""
def get_dict(self) -> dict: ...
def __getitem__(self, index: int) -> Optional[Tuple[float, float]]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def _repr_html_(self) -> str: ...
Expand Down
42 changes: 39 additions & 3 deletions pip/qsharp/_qsharp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from ._native import Interpreter, TargetProfile, StateDump, QSharpError, Output
from ._native import Interpreter, TargetProfile, StateDumpData, QSharpError, Output
from typing import Any, Callable, Dict, Optional, TypedDict, Union, List
from .estimator._estimator import EstimatorResult, EstimatorParams
import json
Expand Down Expand Up @@ -82,7 +82,7 @@ def init(
raise QSharpError(
f"Error parsing {qsharp_json}. qsharp.json should exist at the project root and be a valid JSON file."
) from e

# if no features were passed in as an argument, use the features from the manifest.
# this way we prefer the features from the argument over those from the manifest.
if language_features == [] and manifest_descriptor != None:
Expand Down Expand Up @@ -273,10 +273,46 @@ def set_classical_seed(seed: Optional[int]) -> None:
"""
get_interpreter().set_classical_seed(seed)

class StateDump:
"""
A state dump returned from the Q# interpreter.
"""

"""
The number of allocated qubits at the time of the dump.
"""
qubit_count: int

__inner: dict
__data: StateDumpData

def __init__(self, data: StateDumpData):
self.__data = data
self.__inner = data.get_dict()
self.qubit_count = data.qubit_count

def __getitem__(self, index: int) -> complex:
return self.__inner.__getitem__(index)

def __iter__(self):
return self.__inner.__iter__()

def __len__(self) -> int:
return len(self.__inner)

def __repr__(self) -> str:
return self.__data.__repr__()

def __str__(self) -> str:
return self.__data.__str__()

def _repr_html_(self) -> str:
return self.__data._repr_html_()

def dump_machine() -> StateDump:
"""
Returns the sparse state vector of the simulator as a StateDump object.
:returns: The state of the simulator.
"""
return get_interpreter().dump_machine()
return StateDump(get_interpreter().dump_machine())
7 changes: 4 additions & 3 deletions pip/qsharp/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ def dump_operation(operation: str, num_qubits: int) -> List[List[complex]]:
:returns: The matrix representing the operation.
"""
code = f"""{{
let op : (Qubit[] => Unit) = {operation};
let op = {operation};
use (targets, extra) = (Qubit[{num_qubits}], Qubit[{num_qubits}]);
for i in 0..{num_qubits}-1 {{
H(targets[i]);
CNOT(targets[i], extra[i]);
}}
(op)(targets);
operation ApplyOp (op : (Qubit[] => Unit), targets : Qubit[]) : Unit {{ op(targets); }}
ApplyOp(op, targets);
Microsoft.Quantum.Diagnostics.DumpMachine();
ResetAll(targets + extra);
}}"""
Expand All @@ -39,5 +40,5 @@ def dump_operation(operation: str, num_qubits: int) -> List[List[complex]]:
if entry is None:
matrix[i] += [complex(0, 0)]
else:
matrix[i] += [complex(round(factor * entry[0], ndigits), round(factor *entry[1], ndigits))]
matrix[i] += [complex(round(factor * entry.real, ndigits), round(factor * entry.imag, ndigits))]
return matrix
34 changes: 12 additions & 22 deletions pip/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ use pyo3::{
exceptions::PyException,
prelude::*,
pyclass::CompareOp,
types::PyList,
types::{PyDict, PyString, PyTuple},
types::{PyComplex, PyDict, PyList, PyString, PyTuple},
};
use qsc::{
fir,
Expand All @@ -37,7 +36,7 @@ fn _native(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Result>()?;
m.add_class::<Pauli>()?;
m.add_class::<Output>()?;
m.add_class::<StateDump>()?;
m.add_class::<StateDumpData>()?;
m.add_function(wrap_pyfunction!(physical_estimates, m)?)?;
m.add("QSharpError", py.get_type::<QSharpError>())?;

Expand Down Expand Up @@ -179,9 +178,9 @@ impl Interpreter {
/// Dumps the quantum state of the interpreter.
/// Returns a tuple of (amplitudes, num_qubits), where amplitudes is a dictionary from integer indices to
/// pairs of real and imaginary amplitudes.
fn dump_machine(&mut self) -> StateDump {
fn dump_machine(&mut self) -> StateDumpData {
let (state, qubit_count) = self.interpreter.get_quantum_state();
StateDump(DisplayableState(state, qubit_count))
StateDumpData(DisplayableState(state, qubit_count))
}

fn run(
Expand Down Expand Up @@ -303,20 +302,20 @@ impl Output {
}
}

fn state_dump(&self) -> Option<StateDump> {
fn state_dump(&self) -> Option<StateDumpData> {
match &self.0 {
DisplayableOutput::State(state) => Some(StateDump(state.clone())),
DisplayableOutput::State(state) => Some(StateDumpData(state.clone())),
DisplayableOutput::Message(_) => None,
}
}
}

#[pyclass(unsendable)]
/// Captured simlation state dump.
pub(crate) struct StateDump(pub(crate) DisplayableState);
pub(crate) struct StateDumpData(pub(crate) DisplayableState);

#[pymethods]
impl StateDump {
impl StateDumpData {
fn get_dict(&self, py: Python) -> PyResult<Py<PyDict>> {
Ok(PyDict::from_sequence(
py,
Expand All @@ -328,7 +327,10 @@ impl StateDump {
.map(|(k, v)| {
PyTuple::new(
py,
&[k.clone().into_py(py), PyTuple::new(py, [v.re, v.im]).into()],
&[
k.clone().into_py(py),
PyComplex::from_doubles(py, v.re, v.im).into(),
],
)
})
.collect::<Vec<_>>(),
Expand All @@ -343,18 +345,6 @@ impl StateDump {
self.0 .1
}

// Pass by value is needed for compatiblity with the pyo3 API.
#[allow(clippy::needless_pass_by_value)]
fn __getitem__(&self, key: BigUint) -> Option<(f64, f64)> {
self.0 .0.iter().find_map(|state| {
if state.0 == key {
Some((state.1.re, state.1.im))
} else {
None
}
})
}

fn __len__(&self) -> usize {
self.0 .0.len()
}
Expand Down
8 changes: 3 additions & 5 deletions pip/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,10 @@ def callback(output):
)
state_dump = e.dump_machine()
assert state_dump.qubit_count == 2
state_dump = state_dump.get_dict()
assert len(state_dump) == 1
assert state_dump[2][0] == 1.0
assert state_dump[2][1] == 0.0
state_dict = state_dump.get_dict()
assert state_dict[2][0] == 1.0
assert state_dict[2][1] == 0.0
assert state_dump[2].real == 1.0
assert state_dump[2].imag == 0.0


def test_error() -> None:
Expand Down
13 changes: 10 additions & 3 deletions pip/tests/test_qsharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,19 @@ def test_dump_machine() -> None:
state_dump = qsharp.dump_machine()
assert state_dump.qubit_count == 2
assert len(state_dump) == 1
assert state_dump[2] == (1.0, 0.0)
assert state_dump[2] == complex(1.0, 0.0)
qsharp.eval("X(q2);")
state_dump = qsharp.dump_machine()
assert state_dump.qubit_count == 2
assert len(state_dump) == 1
assert state_dump[3] == (1.0, 0.0)
assert state_dump[3] == complex(1.0, 0.0)
qsharp.eval("H(q1);")
state_dump = qsharp.dump_machine()
assert state_dump.qubit_count == 2
assert len(state_dump) == 2
# Check that the state dump correctly supports iteration and membership checks
for idx in state_dump:
assert idx in state_dump

def test_dump_operation() -> None:
qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted)
Expand All @@ -101,7 +108,7 @@ def test_dump_operation() -> None:
[complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(1.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0)],
[complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(1.0, 0.0)],
[complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(1.0, 0.0), complex(0.0, 0.0)]]
qsharp.eval("operation ApplySWAP(qs : Qubit[]) : Unit { SWAP(qs[0], qs[1]); }")
qsharp.eval("operation ApplySWAP(qs : Qubit[]) : Unit is Ctl + Adj { SWAP(qs[0], qs[1]); }")
res = qsharp.utils.dump_operation("ApplySWAP", 2)
assert res == [[complex(1.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0), complex(0.0, 0.0)],
[complex(0.0, 0.0), complex(0.0, 0.0), complex(1.0, 0.0), complex(0.0, 0.0)],
Expand Down

0 comments on commit 5441bdc

Please sign in to comment.