-
Notifications
You must be signed in to change notification settings - Fork 368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementing SamplerV2 #2090
Merged
Merged
Implementing SamplerV2 #2090
Changes from 3 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
f9fdcb8
Implementing SamplerV2
doichanj 3411e04
Merge remote-tracking branch 'upstream/main' into aer_primitives_v2
doichanj b7eda90
fix lint error
doichanj 0ab7a85
fix test
doichanj df621a8
fix test
doichanj 7c4c2b4
fix test
doichanj af09282
add options and function from_backend
doichanj 7f03171
lint
doichanj 1a8bc41
build test
doichanj 926627a
build test
doichanj File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2022. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" | ||
Sampler V2 class. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Iterable | ||
import warnings | ||
|
||
import numpy as np | ||
from numpy.typing import NDArray | ||
|
||
from qiskit import ClassicalRegister, QiskitError, QuantumCircuit | ||
from qiskit.circuit import ControlFlowOp | ||
|
||
from qiskit.primitives.base import BaseSamplerV2 | ||
from qiskit.primitives.base.validation import _has_measure | ||
from qiskit.primitives.containers import ( | ||
BitArray, | ||
PrimitiveResult, | ||
PubResult, | ||
SamplerPubLike, | ||
make_data_bin, | ||
) | ||
from qiskit.primitives.containers.sampler_pub import SamplerPub | ||
from qiskit.primitives.containers.bit_array import _min_num_bytes | ||
from qiskit.primitives.primitive_job import PrimitiveJob | ||
|
||
from qiskit_aer import AerSimulator | ||
|
||
|
||
@dataclass | ||
class _MeasureInfo: | ||
creg_name: str | ||
num_bits: int | ||
num_bytes: int | ||
qreg_indices: list[int] | ||
|
||
|
||
class SamplerV2(BaseSamplerV2): | ||
""" | ||
Aer implementation of SamplerV2 class. | ||
|
||
Each tuple of ``(circuit, <optional> parameter values, <optional> shots)``, called a sampler | ||
primitive unified bloc (PUB), produces its own array-valued result. The :meth:`~run` method can | ||
be given many pubs at once. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
backend: AerSimulator = None, | ||
default_shots: int = 1024, | ||
seed: np.random.Generator | int | None = None, | ||
): | ||
""" | ||
Args: | ||
backend: AerSimulator object to be used for sampler | ||
default_shots: The default shots for the sampler if not specified during run. | ||
seed: The seed or Generator object for random number generation. | ||
If None, a random seeded default RNG will be used. | ||
""" | ||
self._default_shots = default_shots | ||
self._seed = seed | ||
|
||
if backend is None: | ||
self._backend = AerSimulator() | ||
else: | ||
self._backend = backend | ||
|
||
@property | ||
def default_shots(self) -> int: | ||
"""Return the default shots""" | ||
return self._default_shots | ||
|
||
@property | ||
def seed(self) -> np.random.Generator | int | None: | ||
"""Return the seed or Generator object for random number generation.""" | ||
return self._seed | ||
|
||
def run( | ||
self, pubs: Iterable[SamplerPubLike], *, shots: int | None = None | ||
) -> PrimitiveJob[PrimitiveResult[PubResult]]: | ||
if shots is None: | ||
shots = self._default_shots | ||
coerced_pubs = [SamplerPub.coerce(pub, shots) for pub in pubs] | ||
if any(len(pub.circuit.cregs) == 0 for pub in coerced_pubs): | ||
warnings.warn( | ||
"One of your circuits has no output classical registers and so the result " | ||
"will be empty. Did you mean to add measurement instructions?", | ||
UserWarning, | ||
) | ||
|
||
job = PrimitiveJob(self._run, coerced_pubs) | ||
job._submit() | ||
return job | ||
|
||
def _run(self, pubs: Iterable[SamplerPub]) -> PrimitiveResult[PubResult]: | ||
results = [self._run_pub(pub) for pub in pubs] | ||
return PrimitiveResult(results) | ||
|
||
def _run_pub(self, pub: SamplerPub) -> PubResult: | ||
circuit, qargs, meas_info = _preprocess_circuit(pub.circuit) | ||
|
||
# convert to parameter bindings | ||
parameter_values = pub.parameter_values | ||
parameter_binds = {} | ||
param_array = parameter_values.as_array(circuit.parameters) | ||
parameter_binds = {p: param_array[..., i].ravel() for i, p in enumerate(circuit.parameters)} | ||
|
||
arrays = { | ||
item.creg_name: np.zeros( | ||
parameter_values.shape + (pub.shots, item.num_bytes), dtype=np.uint8 | ||
) | ||
for item in meas_info | ||
} | ||
|
||
if qargs: | ||
circuit.measure_all() | ||
result = self._backend.run( | ||
circuit, | ||
shots=pub.shots, | ||
seed_simulator=self._seed, | ||
parameter_binds=[parameter_binds], | ||
).result() | ||
all_counts = result.get_counts() | ||
|
||
for index, counts in np.ndenumerate(all_counts): | ||
samples = [] | ||
for k, v in counts.items(): | ||
k = k.replace(" ", "") | ||
kk = "" | ||
for q in qargs: | ||
kk = k[circuit.num_qubits - 1 - q] + kk | ||
for _ in range(0, v): | ||
samples.append(kk) | ||
|
||
samples_array = np.array( | ||
[np.fromiter(sample, dtype=np.uint8) for sample in samples] | ||
) | ||
for item in meas_info: | ||
ary = _samples_to_packed_array(samples_array, item.num_bits, item.qreg_indices) | ||
arrays[item.creg_name][index] = ary | ||
else: | ||
for index in np.ndenumerate(parameter_values.shape): | ||
samples = [""] * pub.shots | ||
samples_array = np.array( | ||
[np.fromiter(sample, dtype=np.uint8) for sample in samples] | ||
) | ||
for item in meas_info: | ||
ary = _samples_to_packed_array(samples_array, item.num_bits, item.qreg_indices) | ||
arrays[item.creg_name][index] = ary | ||
|
||
data_bin_cls = make_data_bin( | ||
[(item.creg_name, BitArray) for item in meas_info], | ||
shape=parameter_values.shape, | ||
) | ||
meas = { | ||
item.creg_name: BitArray(arrays[item.creg_name], item.num_bits) for item in meas_info | ||
} | ||
data_bin = data_bin_cls(**meas) | ||
return PubResult(data_bin, metadata={"shots": pub.shots}) | ||
|
||
|
||
def _preprocess_circuit(circuit: QuantumCircuit): | ||
num_bits_dict = {creg.name: creg.size for creg in circuit.cregs} | ||
mapping = _final_measurement_mapping(circuit) | ||
qargs = sorted(set(mapping.values())) | ||
qargs_index = {v: k for k, v in enumerate(qargs)} | ||
circuit = circuit.remove_final_measurements(inplace=False) | ||
if _has_control_flow(circuit): | ||
raise QiskitError("StatevectorSampler cannot handle ControlFlowOp") | ||
if _has_measure(circuit): | ||
raise QiskitError("StatevectorSampler cannot handle mid-circuit measurements") | ||
# num_qubits is used as sentinel to fill 0 in _samples_to_packed_array | ||
sentinel = len(qargs) | ||
indices = {key: [sentinel] * val for key, val in num_bits_dict.items()} | ||
for key, qreg in mapping.items(): | ||
creg, ind = key | ||
indices[creg.name][ind] = qargs_index[qreg] | ||
meas_info = [ | ||
_MeasureInfo( | ||
creg_name=name, | ||
num_bits=num_bits, | ||
num_bytes=_min_num_bytes(num_bits), | ||
qreg_indices=indices[name], | ||
) | ||
for name, num_bits in num_bits_dict.items() | ||
] | ||
return circuit, qargs, meas_info | ||
|
||
|
||
def _samples_to_packed_array( | ||
samples: NDArray[np.uint8], num_bits: int, indices: list[int] | ||
) -> NDArray[np.uint8]: | ||
# samples of `Statevector.sample_memory` will be in the order of | ||
# qubit_last, ..., qubit_1, qubit_0. | ||
# reverse the sample order into qubit_0, qubit_1, ..., qubit_last and | ||
# pad 0 in the rightmost to be used for the sentinel introduced by _preprocess_circuit. | ||
ary = np.pad(samples[:, ::-1], ((0, 0), (0, 1)), constant_values=0) | ||
# place samples in the order of clbit_last, ..., clbit_1, clbit_0 | ||
ary = ary[:, indices[::-1]] | ||
# pad 0 in the left to align the number to be mod 8 | ||
# since np.packbits(bitorder='big') pads 0 to the right. | ||
pad_size = -num_bits % 8 | ||
ary = np.pad(ary, ((0, 0), (pad_size, 0)), constant_values=0) | ||
# pack bits in big endian order | ||
ary = np.packbits(ary, axis=-1) | ||
return ary | ||
|
||
|
||
def _final_measurement_mapping(circuit: QuantumCircuit) -> dict[tuple[ClassicalRegister, int], int]: | ||
"""Return the final measurement mapping for the circuit. | ||
|
||
Parameters: | ||
circuit: Input quantum circuit. | ||
|
||
Returns: | ||
Mapping of classical bits to qubits for final measurements. | ||
""" | ||
active_qubits = set(range(circuit.num_qubits)) | ||
active_cbits = set(range(circuit.num_clbits)) | ||
|
||
# Find final measurements starting in back | ||
mapping = {} | ||
for item in circuit[::-1]: | ||
if item.operation.name == "measure": | ||
loc = circuit.find_bit(item.clbits[0]) | ||
cbit = loc.index | ||
qbit = circuit.find_bit(item.qubits[0]).index | ||
if cbit in active_cbits and qbit in active_qubits: | ||
for creg in loc.registers: | ||
mapping[creg] = qbit | ||
active_cbits.remove(cbit) | ||
elif item.operation.name not in ["barrier", "delay"]: | ||
for qq in item.qubits: | ||
_temp_qubit = circuit.find_bit(qq).index | ||
if _temp_qubit in active_qubits: | ||
active_qubits.remove(_temp_qubit) | ||
|
||
if not active_cbits or not active_qubits: | ||
break | ||
|
||
return mapping | ||
|
||
|
||
def _has_control_flow(circuit: QuantumCircuit) -> bool: | ||
return any(isinstance(instruction.operation, ControlFlowOp) for instruction in circuit) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--- | ||
features: | ||
- | | ||
Adding SamplerV2 implementation based on BaseSamplerV2 using AerSimulator. | ||
This is implemented based on qiskit.primitives.StatevectorSampler, but | ||
this can simulate using any simulation methods other than statevector | ||
by passing AerSimulator object with any methods. |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you consider the options class and backend_options adopted by estimator for consistency. I am not sure which is better. Perhaps the Estimator side should be changed to match this specification.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this option class comes from BackendSampler/Estimator.