Skip to content

Commit f3a0ddc

Browse files
Cryorismtreinishmergify[bot]
authored
Fix default batching in variational algorithms (Qiskit#9038)
* Fix default batching in variational algorithms * fix test * reduce batching to only SPSA * fix tests * Apply suggestions from code review Co-authored-by: Matthew Treinish <mtreinish@kortar.org> Co-authored-by: Matthew Treinish <mtreinish@kortar.org> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 8d7d300 commit f3a0ddc

File tree

7 files changed

+96
-4
lines changed

7 files changed

+96
-4
lines changed

qiskit/algorithms/eigensolvers/vqd.py

+11
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
from ..exceptions import AlgorithmError
3939
from ..observables_evaluator import estimate_observables
4040

41+
# private function as we expect this to be updated in the next release
42+
from ..utils.set_batching import _set_default_batchsize
43+
4144
logger = logging.getLogger(__name__)
4245

4346

@@ -264,10 +267,18 @@ def compute_eigenvalues(
264267
fun=energy_evaluation, x0=initial_point, bounds=bounds
265268
)
266269
else:
270+
# we always want to submit as many estimations per job as possible for minimal
271+
# overhead on the hardware
272+
was_updated = _set_default_batchsize(self.optimizer)
273+
267274
opt_result = self.optimizer.minimize(
268275
fun=energy_evaluation, x0=initial_point, bounds=bounds
269276
)
270277

278+
# reset to original value
279+
if was_updated:
280+
self.optimizer.set_max_evals_grouped(None)
281+
271282
eval_time = time() - start_time
272283

273284
self._update_vqd_result(result, opt_result, eval_time, self.ansatz.copy())

qiskit/algorithms/minimum_eigensolvers/sampling_vqe.py

+11
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
from ..observables_evaluator import estimate_observables
4040
from ..utils import validate_initial_point, validate_bounds
4141

42+
# private function as we expect this to be updated in the next released
43+
from ..utils.set_batching import _set_default_batchsize
44+
4245

4346
logger = logging.getLogger(__name__)
4447

@@ -208,10 +211,18 @@ def compute_minimum_eigenvalue(
208211
# pylint: disable=not-callable
209212
optimizer_result = self.optimizer(fun=evaluate_energy, x0=initial_point, bounds=bounds)
210213
else:
214+
# we always want to submit as many estimations per job as possible for minimal
215+
# overhead on the hardware
216+
was_updated = _set_default_batchsize(self.optimizer)
217+
211218
optimizer_result = self.optimizer.minimize(
212219
fun=evaluate_energy, x0=initial_point, bounds=bounds
213220
)
214221

222+
# reset to original value
223+
if was_updated:
224+
self.optimizer.set_max_evals_grouped(None)
225+
215226
optimizer_time = time() - start_time
216227

217228
logger.info(

qiskit/algorithms/minimum_eigensolvers/vqe.py

+11
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
from ..observables_evaluator import estimate_observables
3636
from ..utils import validate_initial_point, validate_bounds
3737

38+
# private function as we expect this to be updated in the next released
39+
from ..utils.set_batching import _set_default_batchsize
40+
3841
logger = logging.getLogger(__name__)
3942

4043

@@ -181,10 +184,18 @@ def compute_minimum_eigenvalue(
181184
fun=evaluate_energy, x0=initial_point, jac=evaluate_gradient, bounds=bounds
182185
)
183186
else:
187+
# we always want to submit as many estimations per job as possible for minimal
188+
# overhead on the hardware
189+
was_updated = _set_default_batchsize(self.optimizer)
190+
184191
optimizer_result = self.optimizer.minimize(
185192
fun=evaluate_energy, x0=initial_point, jac=evaluate_gradient, bounds=bounds
186193
)
187194

195+
# reset to original value
196+
if was_updated:
197+
self.optimizer.set_max_evals_grouped(None)
198+
188199
optimizer_time = time() - start_time
189200

190201
logger.info(

qiskit/algorithms/optimizers/optimizer.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def __init__(self):
180180
self._bounds_support_level = self.get_support_level()["bounds"]
181181
self._initial_point_support_level = self.get_support_level()["initial_point"]
182182
self._options = {}
183-
self._max_evals_grouped = 1
183+
self._max_evals_grouped = None
184184

185185
@abstractmethod
186186
def get_support_level(self):
@@ -205,7 +205,7 @@ def set_options(self, **kwargs):
205205

206206
# pylint: disable=invalid-name
207207
@staticmethod
208-
def gradient_num_diff(x_center, f, epsilon, max_evals_grouped=1):
208+
def gradient_num_diff(x_center, f, epsilon, max_evals_grouped=None):
209209
"""
210210
We compute the gradient with the numeric differentiation in the parallel way,
211211
around the point x_center.
@@ -214,11 +214,14 @@ def gradient_num_diff(x_center, f, epsilon, max_evals_grouped=1):
214214
x_center (ndarray): point around which we compute the gradient
215215
f (func): the function of which the gradient is to be computed.
216216
epsilon (float): the epsilon used in the numeric differentiation.
217-
max_evals_grouped (int): max evals grouped
217+
max_evals_grouped (int): max evals grouped, defaults to 1 (i.e. no batching).
218218
Returns:
219219
grad: the gradient computed
220220
221221
"""
222+
if max_evals_grouped is None: # no batching by default
223+
max_evals_grouped = 1
224+
222225
forig = f(*((x_center,)))
223226
grad = []
224227
ei = np.zeros((len(x_center),), float)

qiskit/algorithms/optimizers/spsa.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ def _batch_evaluate(function, points, max_evals_grouped, unpack_points=False):
719719
"""
720720

721721
# if the function cannot handle lists of points as input, cover this case immediately
722-
if max_evals_grouped == 1:
722+
if max_evals_grouped is None or max_evals_grouped == 1:
723723
# support functions with multiple arguments where the points are given in a tuple
724724
return [
725725
function(*point) if isinstance(point, tuple) else function(point) for point in points
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# This code is part of Qiskit.
2+
#
3+
# (C) Copyright IBM 2022.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
13+
"""Set default batch sizes for the optimizers."""
14+
15+
from qiskit.algorithms.optimizers import Optimizer, SPSA
16+
17+
18+
def _set_default_batchsize(optimizer: Optimizer) -> bool:
19+
"""Set the default batchsize, if None is set and return whether it was updated or not."""
20+
if isinstance(optimizer, SPSA):
21+
updated = optimizer._max_evals_grouped is None
22+
if updated:
23+
optimizer.set_max_evals_grouped(50)
24+
else: # we only set a batchsize for SPSA
25+
updated = False
26+
27+
return updated

test/python/algorithms/minimum_eigensolvers/test_vqe.py

+29
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,35 @@ def run_check():
300300
vqe.optimizer = L_BFGS_B()
301301
run_check()
302302

303+
def test_default_batch_evaluation_on_spsa(self):
304+
"""Test the default batching works."""
305+
ansatz = TwoLocal(2, rotation_blocks=["ry", "rz"], entanglement_blocks="cz")
306+
307+
wrapped_estimator = Estimator()
308+
inner_estimator = Estimator()
309+
310+
callcount = {"estimator": 0}
311+
312+
def wrapped_estimator_run(*args, **kwargs):
313+
kwargs["callcount"]["estimator"] += 1
314+
return inner_estimator.run(*args, **kwargs)
315+
316+
wrapped_estimator.run = partial(wrapped_estimator_run, callcount=callcount)
317+
318+
spsa = SPSA(maxiter=5)
319+
320+
vqe = VQE(wrapped_estimator, ansatz, spsa)
321+
_ = vqe.compute_minimum_eigenvalue(Pauli("ZZ"))
322+
323+
# 1 calibration + 5 loss + 1 return loss
324+
expected_estimator_runs = 1 + 5 + 1
325+
326+
with self.subTest(msg="check callcount"):
327+
self.assertEqual(callcount["estimator"], expected_estimator_runs)
328+
329+
with self.subTest(msg="check reset to original max evals grouped"):
330+
self.assertIsNone(spsa._max_evals_grouped)
331+
303332
def test_batch_evaluate_with_qnspsa(self):
304333
"""Test batch evaluating with QNSPSA works."""
305334
ansatz = TwoLocal(2, rotation_blocks=["ry", "rz"], entanglement_blocks="cz")

0 commit comments

Comments
 (0)