diff --git a/qiskit_algorithms/optimizers/spsa.py b/qiskit_algorithms/optimizers/spsa.py index 707409ef..10d81dfd 100644 --- a/qiskit_algorithms/optimizers/spsa.py +++ b/qiskit_algorithms/optimizers/spsa.py @@ -727,7 +727,7 @@ def _batch_evaluate(function, points, max_evals_grouped): num_batches += 1 # split the points - batched_points = np.split(np.asarray(points), num_batches) + batched_points = np.array_split(np.asarray(points), num_batches) results = [] for batch in batched_points: diff --git a/test/optimizers/test_spsa.py b/test/optimizers/test_spsa.py index 55e4b7b3..5f31d69d 100644 --- a/test/optimizers/test_spsa.py +++ b/test/optimizers/test_spsa.py @@ -181,3 +181,17 @@ def callback(nfev, point, fval, update, accepted): for i, (key, values) in enumerate(history.items()): self.assertTrue(all(isinstance(value, expected_types[i]) for value in values)) self.assertEqual(len(history[key]), maxiter) + + @data(1, 2, 3, 4) + def test_estimate_stddev(self, max_evals_grouped): + """Test the estimate_stddev + See https://github.com/Qiskit/qiskit-nature/issues/797""" + + def objective(x): + if len(x.shape) == 2: + return np.array([sum(x_i) for x_i in x]) + return sum(x) + + point = np.ones(5) + result = SPSA.estimate_stddev(objective, point, avg=10, max_evals_grouped=max_evals_grouped) + self.assertAlmostEqual(result, 0)