Skip to content

Commit

Permalink
Add complex rosenbrock test
Browse files Browse the repository at this point in the history
  • Loading branch information
gautierronan committed Dec 4, 2024
1 parent 777a8ce commit 8c83a48
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions optax/_src/alias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,5 +904,25 @@ def f_real(x):
sol_complex, to_complex(sol_real), atol=tol, rtol=tol
)

def test_lbfgs_complex_rosenbrock(self):
# Taken from previous jax tests
tol = 1e-5
complex_dim = 5

fun_real = _get_problem('rosenbrock')['fun']
init_real = jnp.zeros((2 * complex_dim,), dtype=complex)
expected_real = jnp.ones((2 * complex_dim,), dtype=complex)

def fun(z):
x_real = jnp.concatenate([jnp.real(z), jnp.imag(z)])
return fun_real(x_real)

init = init_real[:complex_dim] + 1.j * init_real[complex_dim:]
expected = expected_real[:complex_dim] + 1.j * expected_real[complex_dim:]

opt = alias.lbfgs()
got, _ = _run_opt(opt, fun, init, maxiter=500, tol=tol)
chex.assert_trees_all_close(got, expected)

if __name__ == '__main__':
absltest.main()

0 comments on commit 8c83a48

Please sign in to comment.