Skip to content

Commit 073d4d8

Browse files
committed
update sample_most_likely
1 parent 8c43a0d commit 073d4d8

File tree

2 files changed

+79
-8
lines changed

2 files changed

+79
-8
lines changed

qiskit_optimization/applications/optimization_application.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# that they have been altered from the originals.
1212

1313
"""An abstract class for optimization application classes."""
14-
from typing import Union, Dict
15-
from collections import OrderedDict
1614
from abc import ABC, abstractmethod
15+
from collections import OrderedDict
16+
from typing import Dict, Union
1717

1818
import numpy as np
19-
2019
from qiskit.opflow import StateFn
20+
from qiskit.quantum_info import Statevector
21+
from qiskit.result import QuasiDistribution
22+
2123
from qiskit_optimization.algorithms import OptimizationResult
2224
from qiskit_optimization.problems.quadratic_program import QuadraticProgram
2325

@@ -59,29 +61,54 @@ def _result_to_x(self, result: Union[OptimizationResult, np.ndarray]) -> np.ndar
5961
return x
6062

6163
@staticmethod
62-
def sample_most_likely(state_vector: Union[np.ndarray, Dict]) -> np.ndarray:
64+
def sample_most_likely(
65+
state_vector: Union[QuasiDistribution, Statevector, np.ndarray, Dict]
66+
) -> np.ndarray:
6367
"""Compute the most likely binary string from state vector.
6468
6569
Args:
66-
state_vector: state vector or counts.
70+
state_vector: state vector or counts or quasi-probabilities.
6771
6872
Returns:
6973
binary string as numpy.ndarray of ints.
74+
75+
Raises:
76+
ValueError: if state_vector is not QuasiDistribution, Statevector,
77+
np.ndarray, or dict.
7078
"""
71-
if isinstance(state_vector, (OrderedDict, dict)):
79+
if isinstance(state_vector, QuasiDistribution):
80+
probabilities = state_vector.binary_probabilities()
81+
binary_string = max(probabilities.items(), key=lambda kv: kv[1])[0]
82+
x = np.asarray([int(y) for y in reversed(list(binary_string))])
83+
return x
84+
elif isinstance(state_vector, Statevector):
85+
probabilities = state_vector.probabilities()
86+
n = state_vector.num_qubits
87+
k = np.argmax(np.abs(probabilities))
88+
x = np.zeros(n)
89+
for i in range(n):
90+
x[i] = k % 2
91+
k >>= 1
92+
return x
93+
elif isinstance(state_vector, (OrderedDict, dict)):
7294
# get the binary string with the largest count
73-
binary_string = sorted(state_vector.items(), key=lambda kv: kv[1])[-1][0]
95+
binary_string = max(state_vector.items(), key=lambda kv: kv[1])[0]
7496
x = np.asarray([int(y) for y in reversed(list(binary_string))])
7597
return x
7698
elif isinstance(state_vector, StateFn):
7799
binary_string = list(state_vector.sample().keys())[0]
78100
x = np.asarray([int(y) for y in reversed(list(binary_string))])
79101
return x
80-
else:
102+
elif isinstance(state_vector, np.ndarray):
81103
n = int(np.log2(state_vector.shape[0]))
82104
k = np.argmax(np.abs(state_vector))
83105
x = np.zeros(n)
84106
for i in range(n):
85107
x[i] = k % 2
86108
k >>= 1
87109
return x
110+
else:
111+
raise ValueError(
112+
"state vector should be QuasiDistribution, Statevector, ndarray, or dict. "
113+
f"But it is {type(state_vector)}."
114+
)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# This code is part of Qiskit.
2+
#
3+
# (C) Copyright IBM 2022.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
13+
"""Test OptimizationApplication class"""
14+
15+
import unittest
16+
from test.optimization_test_case import QiskitOptimizationTestCase
17+
18+
import numpy as np
19+
from ddt import data, ddt
20+
from qiskit.opflow import StateFn
21+
from qiskit.result import QuasiDistribution
22+
23+
from qiskit_optimization.applications import OptimizationApplication
24+
25+
26+
@ddt
27+
class TestOptimizationApplication(QiskitOptimizationTestCase):
28+
"""Test OptimizationApplication class"""
29+
30+
@data(
31+
np.array([0, 0, 1, 0]),
32+
StateFn([0, 0, 1, 0]),
33+
{"10": 0.8, "01": 0.2},
34+
QuasiDistribution({"10": 0.8, "01": 0.2}),
35+
)
36+
def test_sample_most_likely(self, state_vector):
37+
"""Test sample_most_likely"""
38+
39+
result = OptimizationApplication.sample_most_likely(state_vector)
40+
np.testing.assert_allclose(result, [0, 1])
41+
42+
43+
if __name__ == "__main__":
44+
unittest.main()

0 commit comments

Comments
 (0)