Skip to content

Commit

Permalink
Use np.array_split instead of np.split to support uneven splits in sp…
Browse files Browse the repository at this point in the history
…sa._batch_evaluate (#8634)

* Use np.array_split instead of np.split to support uneven splits in spsa._batch_evaluate

* better testing

Co-authored-by: Julien Gacon <gaconju@gmail.com>

* Update test/python/algorithms/optimizers/test_spsa.py

Co-authored-by: Julien Gacon <gaconju@gmail.com>

* len of a boolean

* Apply suggestions from code review

* Fix Sphinx ref

Co-authored-by: Julien Gacon <gaconju@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 31, 2022
1 parent 219032d commit 69ad6c6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion qiskit/algorithms/optimizers/spsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def _batch_evaluate(function, points, max_evals_grouped):
num_batches += 1

# split the points
batched_points = np.split(np.asarray(points), num_batches)
batched_points = np.array_split(np.asarray(points), num_batches)

results = []
for batch in batched_points:
Expand Down
6 changes: 6 additions & 0 deletions releasenotes/notes/qiskit-nature-797-8f1b0975309b8756.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
When the class :class:`~.SPSA` was using `np.split` (from NumPy) for splitting the jobs in even batches,
resulting in an exception if a perfectly even split was not possible. Now, it uses `np.array_split`, which is safer
for these cases.
14 changes: 14 additions & 0 deletions test/python/algorithms/optimizers/test_spsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,17 @@ def callback(nfev, point, fval, update, accepted):
for i, (key, values) in enumerate(history.items()):
self.assertTrue(all(isinstance(value, expected_types[i]) for value in values))
self.assertEqual(len(history[key]), maxiter)

@data(1, 2, 3, 4)
def test_estimate_stddev(self, max_evals_grouped):
"""Test the estimate_stddev
See https://github.com/Qiskit/qiskit-nature/issues/797"""

def objective(x):
if len(x.shape) == 2:
return np.array([sum(x_i) for x_i in x])
return sum(x)

point = np.ones(5)
result = SPSA.estimate_stddev(objective, point, avg=10, max_evals_grouped=max_evals_grouped)
self.assertAlmostEqual(result, 0)

0 comments on commit 69ad6c6

Please sign in to comment.