Skip to content

Commit

Permalink
Fix the AmplitudeEstimator algorithms for primitive usage (Qiskit…
Browse files Browse the repository at this point in the history
…#9394)

* Fix the Amplitude Estimation algorithms with primitives

* fix bernoulli tests

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
Cryoris and mergify[bot] authored Jan 20, 2023
1 parent d29345e commit 66c884e
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 56 deletions.
9 changes: 3 additions & 6 deletions qiskit/algorithms/amplitude_estimators/ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,15 +387,12 @@ def estimate(self, estimation_problem: EstimationProblem) -> "AmplitudeEstimatio

shots = ret.metadata[0].get("shots")
if shots is None:
result.circuit_results = {
np.binary_repr(k, circuit.num_qubits): v
for k, v in ret.quasi_dists[0].items()
}
result.circuit_results = ret.quasi_dists[0].binary_probabilities()
shots = 1
else:
result.circuit_results = {
np.binary_repr(k, circuit.num_qubits): round(v * shots)
for k, v in ret.quasi_dists[0].items()
k: round(v * shots)
for k, v in ret.quasi_dists[0].binary_probabilities().items()
}

# store shots
Expand Down
16 changes: 0 additions & 16 deletions qiskit/algorithms/amplitude_estimators/ae_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,6 @@
# pylint: disable=invalid-name


def _probabilities_from_sampler_result(num_qubits, result, estimation_problem):
"""calculate probabilities from sampler result"""
prob = 0
for bit, probabilities in result.quasi_dists[0].items():
i = int(bit)
# get bitstring of objective qubits
full_state = bin(i)[2:].zfill(num_qubits)[::-1]
state = "".join([full_state[i] for i in estimation_problem.objective_qubits])

# check if it is a good state
if estimation_problem.is_good_state(state[::-1]):
prob += probabilities

return prob


def bisect_max(f, a, b, steps=50, minwidth=1e-12, retval=False):
"""Find the maximum of the real-valued function f in the interval [a, b] using bisection.
Expand Down
11 changes: 7 additions & 4 deletions qiskit/algorithms/amplitude_estimators/fae.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from .amplitude_estimator import AmplitudeEstimator, AmplitudeEstimatorResult
from .estimation_problem import EstimationProblem
from .ae_utils import _probabilities_from_sampler_result


class FasterAmplitudeEstimation(AmplitudeEstimator):
Expand Down Expand Up @@ -156,10 +155,14 @@ def _cos_estimate(self, estimation_problem, k, shots):
if shots is None:
shots = 1
self._num_oracle_calls += (2 * k + 1) * shots

# sum over all probabilities where the objective qubits are 1
prob = _probabilities_from_sampler_result(
circuit.num_qubits, result, estimation_problem
)
prob = 0
for bit, probabilities in result.quasi_dists[0].binary_probabilities().items():
# check if it is a good state
if estimation_problem.is_good_state(bit):
prob += probabilities

cos_estimate = 1 - 2 * prob
elif self._quantum_instance.is_statevector:
circuit = self.construct_circuit(estimation_problem, k, measurement=False)
Expand Down
16 changes: 7 additions & 9 deletions qiskit/algorithms/amplitude_estimators/iae.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from .amplitude_estimator import AmplitudeEstimator, AmplitudeEstimatorResult
from .estimation_problem import EstimationProblem
from .ae_utils import _probabilities_from_sampler_result
from ..exceptions import AlgorithmError


Expand Down Expand Up @@ -426,12 +425,11 @@ def estimate(
) from exc

# calculate the probability of measuring '1'
prob = _probabilities_from_sampler_result(
circuit.num_qubits, ret, estimation_problem
)
prob = cast(
float, prob
) # tell MyPy it's a float and not Tuple[int, float ]
prob = 0.0
for bit, probabilities in ret.quasi_dists[0].binary_probabilities().items():
# check if it is a good state
if estimation_problem.is_good_state(bit):
prob += probabilities

a_confidence_interval = [prob, prob] # type: list[float]
a_intervals.append(a_confidence_interval)
Expand All @@ -444,8 +442,8 @@ def estimate(
break

counts = {
np.binary_repr(k, circuit.num_qubits): round(v * shots)
for k, v in ret.quasi_dists[0].items()
k: round(v * shots)
for k, v in ret.quasi_dists[0].binary_probabilities().items()
}

# calculate the probability of measuring '1', 'prob' is a_i in the paper
Expand Down
13 changes: 5 additions & 8 deletions qiskit/algorithms/amplitude_estimators/mlae.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,19 +364,16 @@ def estimate(
result.circuit_results = []
shots = ret.metadata[0].get("shots")
if shots is None:
for i, quasi_dist in enumerate(ret.quasi_dists):
circuit_result = {
np.binary_repr(k, circuits[i].num_qubits): v
for k, v in quasi_dist.items()
}
for quasi_dist in ret.quasi_dists:
circuit_result = quasi_dist.binary_probabilities()
result.circuit_results.append(circuit_result)
shots = 1
else:
# get counts and construct MLE input
for circuit in circuits:
for quasi_dist in ret.quasi_dists:
counts = {
np.binary_repr(k, circuit.num_qubits): round(v * shots)
for k, v in ret.quasi_dists[0].items()
k: round(v * shots)
for k, v in quasi_dist.binary_probabilities().items()
}
result.circuit_results.append(counts)

Expand Down
26 changes: 13 additions & 13 deletions test/python/algorithms/test_amplitude_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,21 +196,21 @@ def test_qasm(self, prob, shots, qae, expect):

@idata(
[
[0.2, 100, AmplitudeEstimation(4), {"estimation": 0.500000, "mle": 0.562783}],
[0.2, 100, AmplitudeEstimation(4), {"estimation": 0.14644, "mle": 0.198783}],
[0.0, 1000, AmplitudeEstimation(2), {"estimation": 0.0, "mle": 0.0}],
[
0.2,
100,
MaximumLikelihoodAmplitudeEstimation([0, 1, 2, 4, 8]),
{"estimation": 0.474790},
{"estimation": 0.200308},
],
[0.8, 10, IterativeAmplitudeEstimation(0.1, 0.05), {"estimation": 0.811711}],
[0.2, 1000, FasterAmplitudeEstimation(0.1, 3, rescale=False), {"estimation": 0.199073}],
[0.2, 1000, FasterAmplitudeEstimation(0.1, 3, rescale=False), {"estimation": 0.198640}],
[
0.12,
100,
FasterAmplitudeEstimation(0.01, 3, rescale=False),
{"estimation": 0.120016},
{"estimation": 0.120017},
],
]
)
Expand Down Expand Up @@ -422,10 +422,10 @@ def test_statevector(self, n, qae, expect):

@idata(
[
[2, AmplitudeEstimation(2), {"estimation": 0.5, "mle": 0.270290}],
[4, MaximumLikelihoodAmplitudeEstimation(4), {"estimation": 0.0}],
[3, IterativeAmplitudeEstimation(0.1, 0.1), {"estimation": 0.0}],
[3, FasterAmplitudeEstimation(0.01, 1), {"estimation": 0.017687}],
[2, AmplitudeEstimation(2), {"estimation": 0.5, "mle": 0.2702}],
[4, MaximumLikelihoodAmplitudeEstimation(4), {"estimation": 0.2725}],
[3, IterativeAmplitudeEstimation(0.1, 0.1), {"estimation": 0.2721}],
[3, FasterAmplitudeEstimation(0.01, 1), {"estimation": 0.2792}],
]
)
@unpack
Expand All @@ -444,7 +444,7 @@ def test_sampler(self, n, qae, expect):

@idata(
[
[4, 10, AmplitudeEstimation(2), {"estimation": 0.5, "mle": 0.333333}],
[4, 100, AmplitudeEstimation(2), {"estimation": 0.5, "mle": 0.281196}],
[3, 10, MaximumLikelihoodAmplitudeEstimation(2), {"estimation": 0.256878}],
[3, 1000, IterativeAmplitudeEstimation(0.01, 0.01), {"estimation": 0.271790}],
[3, 1000, FasterAmplitudeEstimation(0.1, 4), {"estimation": 0.274168}],
Expand All @@ -465,10 +465,10 @@ def test_qasm(self, n, shots, qae, expect):

@idata(
[
[4, 10, AmplitudeEstimation(2), {"estimation": 0.0, "mle": 0.0}],
[3, 10, MaximumLikelihoodAmplitudeEstimation(2), {"estimation": 0.0}],
[3, 1000, IterativeAmplitudeEstimation(0.01, 0.01), {"estimation": 0.0}],
[3, 1000, FasterAmplitudeEstimation(0.1, 4), {"estimation": 0.000551}],
[4, 1000, AmplitudeEstimation(2), {"estimation": 0.5, "mle": 0.2636}],
[3, 10, MaximumLikelihoodAmplitudeEstimation(2), {"estimation": 0.2904}],
[3, 1000, IterativeAmplitudeEstimation(0.01, 0.01), {"estimation": 0.2706}],
[3, 1000, FasterAmplitudeEstimation(0.1, 4), {"estimation": 0.2764}],
]
)
@unpack
Expand Down

0 comments on commit 66c884e

Please sign in to comment.