Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add termination callback to SPSA #6839

Merged
merged 44 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7601679
add termination callback to SPSA
eendebakpt Jul 30, 2021
f5c204c
add release notes
eendebakpt Jul 30, 2021
c0b92bc
fix linting
eendebakpt Jul 30, 2021
540cc46
update QNSPSA for termination_callback
eendebakpt Jul 30, 2021
72fae89
fix linting
eendebakpt Jul 30, 2021
f9b6ffd
fix linting
eendebakpt Jul 30, 2021
f0c1e04
Merge branch 'main' into feat/spsa_termination
peendebak Jul 30, 2021
d1db10c
fix typo
eendebakpt Jul 31, 2021
a86f720
Merge branch 'main' into feat/spsa_termination
peendebak Aug 1, 2021
f418df0
add termination callback to settings
eendebakpt Aug 1, 2021
47c4414
Merge branch 'feat/spsa_termination' of github.com:peendebak/qiskit-t…
eendebakpt Aug 1, 2021
9448d4b
fix tests
eendebakpt Aug 1, 2021
8c18b3a
add example to release notes
eendebakpt Aug 2, 2021
a569c5f
add example
eendebakpt Aug 2, 2021
fdb506f
fix release notes
eendebakpt Aug 2, 2021
7878123
address review comments
eendebakpt Aug 3, 2021
c709689
whitespae
eendebakpt Aug 3, 2021
fd1d576
address review comments
eendebakpt Aug 3, 2021
c178d9b
address review comments
eendebakpt Aug 3, 2021
50c02bd
Merge branch 'main' into feat/spsa_termination
peendebak Aug 3, 2021
4afc5b9
rename to termination_checker; pass Optimizer as argument
eendebakpt Aug 5, 2021
718cd43
complete renames
eendebakpt Aug 5, 2021
f91395b
Merge branch 'main' into feat/spsa_termination
peendebak Aug 5, 2021
b3367d7
Merge branch 'main' into feat/spsa_termination
peendebak Aug 17, 2021
ebdb6b2
Merge branch 'main' into feat/spsa_termination
peendebak Aug 18, 2021
149d074
Merge branch 'main' into feat/spsa_termination
peendebak Aug 20, 2021
96b2ec7
Merge branch 'main' into feat/spsa_termination
peendebak Aug 20, 2021
2ba0381
Merge branch 'main' into feat/spsa_termination
peendebak Aug 22, 2021
97cdbac
Merge branch 'main' into feat/spsa_termination
peendebak Aug 25, 2021
36a6246
Update qiskit/algorithms/optimizers/spsa.py
peendebak Aug 26, 2021
8f24135
address review comments
eendebakpt Aug 26, 2021
6bf0f61
fix pylint
eendebakpt Aug 26, 2021
9d72aec
Merge branch 'main' into feat/spsa_termination
peendebak Aug 26, 2021
019300d
trigger build
eendebakpt Aug 27, 2021
72dfca4
Merge branch 'feat/spsa_termination' of github.com:peendebak/qiskit-t…
eendebakpt Aug 27, 2021
fade083
fix signature of callback
eendebakpt Sep 1, 2021
2ec3f15
Merge branch 'main' into feat/spsa_termination
peendebak Sep 1, 2021
0c94e9a
fix pylint
eendebakpt Sep 1, 2021
8873f0e
fix pylint
eendebakpt Sep 1, 2021
1e9e910
Update qiskit/algorithms/optimizers/spsa.py
peendebak Sep 3, 2021
87a363b
update name of callback signature
eendebakpt Sep 3, 2021
02c63e3
Merge branch 'feat/spsa_termination' of github.com:peendebak/qiskit-t…
eendebakpt Sep 3, 2021
c54988a
fix pylint
eendebakpt Sep 6, 2021
6a13903
Merge branch 'main' into feat/spsa_termination
mergify[bot] Sep 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions qiskit/algorithms/optimizers/qnspsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,11 @@
from qiskit.opflow import StateFn, CircuitSampler, ExpectationBase
from qiskit.utils import QuantumInstance

from .spsa import SPSA, _batch_evaluate
from .spsa import SPSA, CALLBACK, TERMINATIONCALLBACK, _batch_evaluate

# the function to compute the fidelity
FIDELITY = Callable[[np.ndarray, np.ndarray], float]

# parameters, loss, stepsize, number of function evaluations, accepted
CALLBACK = Callable[[np.ndarray, float, float, int, bool], None]


class QNSPSA(SPSA):
r"""The Quantum Natural SPSA (QN-SPSA) optimizer.
Expand Down Expand Up @@ -95,6 +92,7 @@ def __init__(
lse_solver: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
initial_hessian: Optional[np.ndarray] = None,
callback: Optional[CALLBACK] = None,
termination_callback: Optional[TERMINATIONCALLBACK] = None,
) -> None:
r"""
Args:
Expand Down Expand Up @@ -141,6 +139,9 @@ def __init__(
callback: A callback function passed information in each iteration step. The
information is, in this order: the parameters, the function value, the number
of function evaluations, the stepsize, whether the step was accepted.
termination_callback: A callback function executed at the end of each iteration step. The
arguments are, in this order: current parameters, estimate of the objective.
If the callback returns True, the optimization is aborted.
peendebak marked this conversation as resolved.
Show resolved Hide resolved
"""
super().__init__(
maxiter,
Expand All @@ -158,6 +159,7 @@ def __init__(
regularization=regularization,
perturbation_dims=perturbation_dims,
initial_hessian=initial_hessian,
termination_callback=termination_callback,
)

self.fidelity = fidelity
Expand Down Expand Up @@ -187,7 +189,7 @@ def _point_sample(self, loss, x, eps, delta1, delta2):
# -0.5 factor comes from the fact that we need -0.5 * fidelity
hessian_estimate = -0.5 * diff * (rank_one + rank_one.T) / 2

return gradient_estimate, hessian_estimate
return np.mean(loss_values), gradient_estimate, hessian_estimate

@property
def settings(self) -> Dict[str, Any]:
Expand Down
73 changes: 65 additions & 8 deletions qiskit/algorithms/optimizers/spsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
This implementation allows both, standard first-order as well as second-order SPSA.
"""

from typing import Iterator, Optional, Union, Callable, Tuple, Dict
from typing import Iterator, Optional, Union, Callable, Tuple, Dict, Any
import logging
import warnings
from time import time
Expand All @@ -30,6 +30,7 @@

# number of function evaluations, parameters, loss, stepsize, accepted
CALLBACK = Callable[[int, np.ndarray, float, float, bool], None]
TERMINATIONCALLBACK = Callable[[np.ndarray, float], bool]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(
lse_solver: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
initial_hessian: Optional[np.ndarray] = None,
callback: Optional[CALLBACK] = None,
termination_callback: Optional[TERMINATIONCALLBACK] = None,
) -> None:
r"""
Args:
Expand Down Expand Up @@ -190,17 +192,57 @@ def __init__(
callback: A callback function passed information in each iteration step. The
information is, in this order: the number of function evaluations, the parameters,
the function value, the stepsize, whether the step was accepted.
termination_callback: A callback function executed at the end of each iteration step. The
arguments are, in this order: current parameters, estimate of the objective.
peendebak marked this conversation as resolved.
Show resolved Hide resolved
If the callback returns True, the optimization is aborted.
To prevent additional evaluations of the objective method, objective is estimated by
taking the mean of the objective evaluations used in the estimate of the gradient.

Raises:
ValueError: If ``learning_rate`` or ``perturbation`` is an array with less elements
than the number of iterations.

Example:
.. jupyter-execute::
peendebak marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np
from qiskit.algorithms.optimizers import SPSA

def objective(x):
return np.linalg.norm(x) + .04*np.random.rand(1)

class TerminationCallback:

def __init__(self, N : int):
self.N = N
self.values = []

def __call__(self, parameters, value) -> bool:
self.values.append(value)

if len(self.values) > self.N:
last_values = self.values[-self.N:]
pp = np.polyfit(range(self.N), last_values, 1)
slope = pp[0] / self.N

if slope > 0:
return True
return False

spsa = SPSA(maxiter=200, termination_callback=TerminationCallback(10))
parameters, value, niter = spsa.optimize(2, objective, initial_point=[0.5, 0.5])
print(f'SPSA completed after {niter} iterations')



"""
super().__init__()

# general optimizer arguments
self.maxiter = maxiter
self.trust_region = trust_region
self.callback = callback
self.termination_callback = termination_callback

# if learning rate and perturbation are arrays, check they are sufficiently long
for attr, name in zip([learning_rate, perturbation], ["learning_rate", "perturbation"]):
Expand Down Expand Up @@ -323,7 +365,8 @@ def estimate_stddev(
return np.std(losses)

@property
def settings(self):
def settings(self) -> Dict[str, Any]:
"""Return dictonary containing the settings of the optimizer"""
peendebak marked this conversation as resolved.
Show resolved Hide resolved
# if learning rate or perturbation are custom iterators expand them
if callable(self.learning_rate):
iterator = self.learning_rate()
Expand Down Expand Up @@ -352,6 +395,7 @@ def settings(self):
"lse_solver": self.lse_solver,
"initial_hessian": self.initial_hessian,
"callback": self.callback,
"termination_callback": self.termination_callback,
}

def _point_sample(self, loss, x, eps, delta1, delta2):
Expand Down Expand Up @@ -379,11 +423,12 @@ def _point_sample(self, loss, x, eps, delta1, delta2):
rank_one = np.outer(delta1, delta2)
hessian_sample = diff * (rank_one + rank_one.T) / 2

return gradient_sample, hessian_sample
return np.mean(values), gradient_sample, hessian_sample

def _point_estimate(self, loss, x, eps, num_samples):
"""The gradient estimate at point x."""
# set up variables to store averages
value_estimate = 0
gradient_estimate = np.zeros(x.size)
hessian_estimate = np.zeros((x.size, x.size))

Expand All @@ -403,13 +448,20 @@ def _point_estimate(self, loss, x, eps, num_samples):
delta1 = deltas1[i]
delta2 = deltas2[i] if self.second_order else None

gradient_sample, hessian_sample = self._point_sample(loss, x, eps, delta1, delta2)
value_sample, gradient_sample, hessian_sample = self._point_sample(
loss, x, eps, delta1, delta2
)
value_estimate += value_sample
gradient_estimate += gradient_sample

if self.second_order:
hessian_estimate += hessian_sample

return gradient_estimate / num_samples, hessian_estimate / num_samples
return (
value_estimate / num_samples,
gradient_estimate / num_samples,
hessian_estimate / num_samples,
)

def _compute_update(self, loss, x, k, eps, lse_solver):
# compute the perturbations
Expand All @@ -419,7 +471,7 @@ def _compute_update(self, loss, x, k, eps, lse_solver):
num_samples = self.resamplings

# accumulate the number of samples
gradient, hessian = self._point_estimate(loss, x, eps, num_samples)
value, gradient, hessian = self._point_estimate(loss, x, eps, num_samples)

# precondition gradient with inverse Hessian, if specified
if self.second_order:
Expand All @@ -432,7 +484,7 @@ def _compute_update(self, loss, x, k, eps, lse_solver):
# solve for the gradient update
gradient = np.real(lse_solver(spd_hessian, gradient))

return gradient
return value, gradient

def _minimize(self, loss, initial_point):
# ensure learning rate and perturbation are correctly set: either none or both
Expand Down Expand Up @@ -481,7 +533,7 @@ def _minimize(self, loss, initial_point):
for k in range(1, self.maxiter + 1):
iteration_start = time()
# compute update
update = self._compute_update(loss, x, k, next(eps), lse_solver)
f_estimate, update = self._compute_update(loss, x, k, next(eps), lse_solver)

# trust region
if self.trust_region:
Expand Down Expand Up @@ -544,6 +596,11 @@ def _minimize(self, loss, initial_point):
if len(last_steps) > self.last_avg:
last_steps.popleft()

if self.termination_callback is not None:
if self.termination_callback(x, f_estimate):
logger.info("aborting optimization at {k}/{self.maxiter")
break

logger.info("SPSA finished in %s", time() - start)
logger.info("=" * 30)

Expand Down
41 changes: 41 additions & 0 deletions releasenotes/notes/SPSA-termination-callback-a1ec14892f553982.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
---
features:
- |
Add `termination_callback` argument to :class:`qiskit.algorithms.optimizers.spsa.SPSA` optimizer.
This allows the user to implement a custom termination criterion.
peendebak marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

import numpy as np
from qiskit.algorithms.optimizers import SPSA

def objective(x):
return np.linalg.norm(x) + .04*np.random.rand(1)

class TerminationCallback:

def __init__(self, N : int):
""" Callback to terminate optmization when the average decrease over the last N data points is smaller than the specified tolerance """
peendebak marked this conversation as resolved.
Show resolved Hide resolved
self.N = N
self.values = []

def __call__(self, parameters, value) -> bool:
"""
Returns:
True if the optimization loop should be aborted
"""
self.values.append(value)

if len(self.values) > self.N:
last_values = self.values[-self.N:]
pp = np.polyfit(range(self.N), last_values, 1)
slope = pp[0] / self.N

if slope > 0:
return True
return False

maxiter = 400
spsa = SPSA(maxiter=maxiter, termination_callback=TerminationCallback(10))
parameters, value, niter = spsa.optimize(2, objective, initial_point=np.array([0.5, 0.5]))
print(f'SPSA completed after {niter} iterations')
2 changes: 2 additions & 0 deletions test/python/algorithms/optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def test_spsa(self):
"lse_solver": None,
"hessian_delay": 0,
"callback": None,
"termination_callback": None,
}
spsa = SPSA(**options)

Expand Down Expand Up @@ -327,6 +328,7 @@ def test_qnspsa(self):
"lse_solver": None,
"initial_hessian": None,
"callback": None,
"termination_callback": None,
"hessian_delay": 0,
}
spsa = QNSPSA(**options)
Expand Down
25 changes: 25 additions & 0 deletions test/python/algorithms/optimizers/test_spsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,31 @@ def objective(x):

self.assertAlmostEqual(np.linalg.norm(result), 2, places=2)

def test_termination_callback(self):
"""Test the termination_callback"""

def objective(x):
return np.linalg.norm(x) + np.random.rand(1)

class TerminationCallback:
"""Example termination callback"""

def __init__(self):
self.values = []

def __call__(self, parameters, value) -> bool:
self.values.append(value)

if len(self.values) > 10:
return True
return False

maxiter = 400
spsa = SPSA(maxiter=maxiter, termination_callback=TerminationCallback())
_, _, niter = spsa.optimize(2, objective, initial_point=[0.5, 0.5])

self.assertLess(niter, maxiter)

def test_callback(self):
"""Test using the callback."""

Expand Down