diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 5bc44edf5..8412f6e86 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -904,5 +904,25 @@ def f_real(x): sol_complex, to_complex(sol_real), atol=tol, rtol=tol ) + def test_lbfgs_complex_rosenbrock(self): + # Taken from previous jax tests + tol = 1e-5 + complex_dim = 5 + + fun_real = _get_problem('rosenbrock')['fun'] + init_real = jnp.zeros((2 * complex_dim,), dtype=complex) + expected_real = jnp.ones((2 * complex_dim,), dtype=complex) + + def fun(z): + x_real = jnp.concatenate([jnp.real(z), jnp.imag(z)]) + return fun_real(x_real) + + init = init_real[:complex_dim] + 1.j * init_real[complex_dim:] + expected = expected_real[:complex_dim] + 1.j * expected_real[complex_dim:] + + opt = alias.lbfgs() + got, _ = _run_opt(opt, fun, init, maxiter=500, tol=tol) + chex.assert_trees_all_close(got, expected) + if __name__ == '__main__': absltest.main()