Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shots option to primitives #8137

Merged
merged 6 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions qiskit/primitives/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,18 @@

class Estimator(BaseEstimator):
"""
Estimator class
Reference implementation of :class:`BaseEstimator`.

:Run Options:

- **shots** (None or int) --
The number of shots. If None, it calculates the exact expectation
values. Otherwise, it samples from normal distributions with standard errors as standard
deviations using normal distribution approximation.

- **seed** (np.random.Generator or int) --
Set a fixed seed or generator for the normal distribution. If shots is None,
this option is ignored.
"""

def __init__(
Expand Down Expand Up @@ -66,6 +77,18 @@ def _call(
if self._is_closed:
raise QiskitError("The primitive has been closed.")

shots = run_options.pop("shots", None)
seed = run_options.pop("seed", None)
if seed is None:
rng = np.random.default_rng()
elif isinstance(seed, np.random.Generator):
rng = seed
else:
rng = np.random.default_rng(seed)

# Initialize metadata
metadata = [{}] * len(circuits)

bound_circuits = []
for i, value in zip(circuits, parameter_values):
if len(value) != len(self._parameters[i]):
Expand All @@ -78,15 +101,28 @@ def _call(
)
sorted_observables = [self._observables[i] for i in observables]
expectation_values = []
for circ, obs in zip(bound_circuits, sorted_observables):
for circ, obs, metadatum in zip(bound_circuits, sorted_observables, metadata):
if circ.num_qubits != obs.num_qubits:
raise QiskitError(
f"The number of qubits of a circuit ({circ.num_qubits}) does not match "
f"the number of qubits of a observable ({obs.num_qubits})."
)
expectation_values.append(Statevector(circ).expectation_value(obs))
final_state = Statevector(circ)
expectation_value = final_state.expectation_value(obs)
if shots is None:
expectation_values.append(expectation_value)
else:
expectation_value = np.real_if_close(expectation_value)
sq_obs = (obs @ obs).simplify()
sq_exp_val = np.real_if_close(final_state.expectation_value(sq_obs))
variance = sq_exp_val - expectation_value**2
standard_deviation = np.sqrt(variance / shots)
ikkoham marked this conversation as resolved.
Show resolved Hide resolved
expectation_value_with_error = rng.normal(expectation_value, standard_deviation)
expectation_values.append(expectation_value_with_error)
metadatum["variance"] = variance
metadatum["shots"] = shots

return EstimatorResult(np.real_if_close(expectation_values), [{}] * len(expectation_values))
return EstimatorResult(np.real_if_close(expectation_values), metadata)

def close(self):
self._is_closed = True
36 changes: 34 additions & 2 deletions qiskit/primitives/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from collections.abc import Iterable, Sequence

import numpy as np

from qiskit.circuit import Parameter, QuantumCircuit
from qiskit.exceptions import QiskitError
from qiskit.quantum_info import Statevector
Expand All @@ -29,7 +31,19 @@

class Sampler(BaseSampler):
"""
Sampler class
Sampler class.

:class:`~Sampler` is a reference implementation of :class:`~BaseSampler`.

:Run Options:

- **shots** (None or int) --
The number of shots. If None, it calculates the probabilities.
Otherwise, it samples from multinomial distributions.

- **seed** (np.random.Generator or int) --
Set a fixed seed or generator for the multinomial distribution. If shots is None, this
option is ignored.
"""

def __init__(
Expand Down Expand Up @@ -73,6 +87,18 @@ def _call(
if self._is_closed:
raise QiskitError("The primitive has been closed.")

shots = run_options.pop("shots", None)
seed = run_options.pop("seed", None)
if seed is None:
rng = np.random.default_rng()
elif isinstance(seed, np.random.Generator):
rng = seed
else:
rng = np.random.default_rng(seed)

# Initialize metadata
metadata = [{}] * len(circuits)

bound_circuits_qargs = []
for i, value in zip(circuits, parameter_values):
if len(value) != len(self._parameters[i]):
Expand All @@ -89,9 +115,15 @@ def _call(
probabilities = [
Statevector(circ).probabilities(qargs=qargs) for circ, qargs in bound_circuits_qargs
]
if shots is not None:
probabilities = [
rng.multinomial(shots, probability) / shots for probability in probabilities
]
for metadatum in metadata:
metadatum["shots"] = shots
quasis = [QuasiDistribution(dict(enumerate(p))) for p in probabilities]

return SamplerResult(quasis, [{}] * len(circuits))
return SamplerResult(quasis, metadata)

def close(self):
self._is_closed = True
17 changes: 17 additions & 0 deletions releasenotes/notes/primitive-shots-option-ed320872d048483e.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
features:
- |
Added ``shots`` option for reference implementations of primitives.
Random numbers can be fixed by giving ``seed_primitive``. For example::

from qiskit.primitives import Sampler
from qiskit import QuantumCircuit

bell = QuantumCircuit(2)
bell.h(0)
bell.cx(0, 1)
bell.measure_all()

with Sampler(circuits=[bell]) as sampler:
result = sampler(circuits=[0], shots=1024, seed_primitive=15)
print([q.binary_probabilities() for q in result.quasi_dists])
14 changes: 14 additions & 0 deletions test/python/primitives/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,20 @@ def test_deprecated_arguments(self):
self.assertIsInstance(result, EstimatorResult)
np.testing.assert_allclose(result.values, [-1.284366511861733])

def test_with_shots_option(self):
"""test with shots option."""
with Estimator([self.ansatz], [self.observable]) as est:
result = est([0], [0], parameter_values=[[0, 1, 1, 2, 3, 5]], shots=1024, seed=15)
self.assertIsInstance(result, EstimatorResult)
np.testing.assert_allclose(result.values, [-1.307397243478641])

ikkoham marked this conversation as resolved.
Show resolved Hide resolved
def test_with_shots_option_none(self):
"""test with shots=None option. Seed is ignored then."""
with Estimator([self.ansatz], [self.observable]) as est:
result_42 = est([0], [0], parameter_values=[[0, 1, 1, 2, 3, 5]], shots=None, seed=42)
result_15 = est([0], [0], parameter_values=[[0, 1, 1, 2, 3, 5]], shots=None, seed=15)
np.testing.assert_allclose(result_42.values, result_15.values)


if __name__ == "__main__":
unittest.main()
18 changes: 16 additions & 2 deletions test/python/primitives/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

"""Tests for Sampler."""

from test import combine
import unittest
from test import combine

from ddt import ddt
import numpy as np
from ddt import ddt

from qiskit import QuantumCircuit
from qiskit.circuit import Parameter
Expand Down Expand Up @@ -407,6 +407,20 @@ def test_deprecated_circuit_indices(self, indices):
)
self._compare_probs(result.quasi_dists, target)

def test_with_shots_option(self):
"""test with shots option."""
params, target = self._generate_params_target([1])
with Sampler(circuits=self._pqc) as sampler:
result = sampler(circuits=[0], parameter_values=params, shots=1024, seed=15)
self._compare_probs(result.quasi_dists, target)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as #8137 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 1593f34

def test_with_shots_option_none(self):
"""test with shots=None option. Seed is ignored then."""
with Sampler([self._pqc]) as sampler:
result_42 = sampler([0], parameter_values=[[0, 1, 1, 2, 3, 5]], shots=None, seed=42)
result_15 = sampler([0], parameter_values=[[0, 1, 1, 2, 3, 5]], shots=None, seed=15)
self.assertDictAlmostEqual(result_42.quasi_dists, result_15.quasi_dists)


if __name__ == "__main__":
unittest.main()