Added callback function support for adam-amsgrad optimizer. #869

Merged
9 changes: 9 additions & 0 deletions qiskit_machine_learning/optimizers/adam_amsgrad.py
@@ -21,6 +21,8 @@
import numpy as np
from .optimizer import Optimizer, OptimizerSupportLevel, OptimizerResult, POINT

CALLBACK = Callable[[int, POINT, float], None]

# pylint: disable=invalid-name


@@ -69,6 +71,7 @@ def __init__(
eps: float = 1e-10,
amsgrad: bool = False,
snapshot_dir: str | None = None,
callback: CALLBACK | None = None,
) -> None:
"""
Args:
@@ -83,8 +86,11 @@
amsgrad: True to use AMSGRAD, False if not
snapshot_dir: If not None save the optimizer's parameter
after every step to the given directory
callback: A callback function that is passed information at each iteration step.
The information is, in this order: the current time step, the parameters, and the function value.
"""
super().__init__()
self.callback = callback
for k, v in list(locals().items()):
if k in self._OPTIONS:
self._options[k] = v
@@ -233,6 +239,9 @@ def minimize(
if self._snapshot_dir is not None:
self.save_params(self._snapshot_dir)

if self.callback is not None:
self.callback(self._t, params_new, fun(params_new))

# check termination
if np.linalg.norm(params - params_new) < self._tol:
break
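With this change, `minimize` invokes the callback once per iteration, passing the current time step `self._t`, the updated parameters, and the objective value at those parameters. Note that `fun(params_new)` is an extra objective evaluation incurred only when a callback is registered. A minimal sketch of a conforming callback (illustrative names, not from this PR):

```python
import numpy as np
from qiskit_machine_learning.optimizers import ADAM

# Illustrative history container, mirroring the pattern used in the test below.
history = {"steps": [], "weights": [], "fvals": []}

def record(step: int, weights: np.ndarray, fval: float) -> None:
    # Signature matches CALLBACK = Callable[[int, POINT, float], None].
    history["steps"].append(step)
    history["weights"].append(weights)
    history["fvals"].append(fval)

adam = ADAM(maxiter=100, lr=1e-1, callback=record)
```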
@@ -0,0 +1,20 @@
---
features:
- |
The :class:`~qiskit_machine_learning.optimizers.ADAM` class now supports a callback function.
This feature allows users to pass a custom callback function that will be called with
information at each iteration step during the optimization process.
The information passed to the callback includes the current time step, the parameters,
and the function value. The callback function should be of the type
``Callable[[int, Union[float, np.ndarray], float], None]``.

Example of a callback function:

.. code-block:: python

    def callback(iteration: int, weights: np.ndarray, loss: float):
        ...
        acc = calculate_accuracy(weights)
        print(acc)
        print(loss)
        ...
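As a hedged end-to-end sketch (not part of the release note), the callback is attached at construction time and fires on every iteration of `minimize`; `quadratic` below is a stand-in objective:

```python
import numpy as np
from qiskit_machine_learning.optimizers import ADAM

def quadratic(x: np.ndarray) -> float:
    # Stand-in objective: squared distance from the origin.
    return float(np.sum(x**2))

losses = []

def callback(iteration: int, weights: np.ndarray, loss: float) -> None:
    # Record the loss trajectory, e.g. for a convergence plot.
    losses.append(loss)

adam = ADAM(maxiter=100, tol=1e-6, lr=1e-1, callback=callback)
result = adam.minimize(quadratic, np.array([1.0, -0.5]))
print(result.x, result.fun, len(losses))
```

Since the callback returns `None` and the optimizer ignores its output, it cannot alter the optimization; it is purely an observation hook.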
18 changes: 18 additions & 0 deletions test/optimizers/test_adam_amsgrad.py
@@ -79,6 +79,24 @@ def test_settings(self):
self.assertEqual(settings["amsgrad"], False)
self.assertEqual(settings["snapshot_dir"], None)

def test_callback(self):
"""Test using the callback."""

history = {"ite": [], "weights": [], "fvals": []}

def callback(n_t, weight, fval):
history["ite"].append(n_t)
history["weights"].append(weight)
history["fvals"].append(fval)

adam = ADAM(maxiter=100, tol=1e-6, lr=1e-1, callback=callback)
adam.minimize(self.quadratic_objective, self.initial_point)

expected_types = [int, np.ndarray, float]
for i, (key, values) in enumerate(history.items()):
self.assertTrue(all(isinstance(value, expected_types[i]) for value in values))
self.assertEqual(len(history[key]), 100)


if __name__ == "__main__":
unittest.main()