
AttributeError in TransitFit.optimize_transit_params, maybe in jaxopt #6

Closed
HajimeKawahara opened this issue Aug 20, 2024 · 4 comments
Labels
bug Something isn't working

Comments

@HajimeKawahara
Collaborator

While writing unit tests for the transit module (unittest_3 branch), I encountered the following error. The same error also occurs in the tf.optimize_transit_params cell when running examples/transit.ipynb.

Environment: python==3.10.9, jaxopt==0.8.2, jax==0.4.31

  • unittest_3 branch :: ~/jkepler/tests/unittests/transit(unittest_3)>python transit_test.py
hirochan:~/jkepler/tests/unittests/transit(unittest_3)>python transit_test.py
# initial objective function: -24523.9

# optimizing t0 and period...
/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py:343: OptimizeWarning: Unknown solver options: maxiter
  res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
Traceback (most recent call last):
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 223, in __getattr__
    return self[name]
KeyError: 'njev'

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/kawahara/jkepler/tests/unittests/transit/transit_test.py", line 50, in <module>
    test_compute_prediction()
  File "/home/kawahara/jkepler/tests/unittests/transit/transit_test.py", line 41, in test_compute_prediction
    popt = tf.optimize_transit_params(flux, error, t0, period, ecc, omega, b, rstar, rp_over_r, fit_ttvs=False)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jkepler-0.0.1-py3.10.egg/jkepler/transit/transitfit.py", line 203, in optimize_transit_params
    res = solver.run(p_init, bounds=bounds)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 207, in solver_fun_flat
    return solver_fun(*args, **kwargs)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py", line 457, in run
    return self._run(init_params, bounds, *args, **kwargs)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py", line 373, in _run
    num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE),
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 225, in __getattr__
    raise AttributeError(name) from e
AttributeError: njev. Did you mean: 'nfev'?
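Reading the traceback: jaxopt 0.8.2 reads res.njev from the scipy OptimizeResult, and OptimizeResult (a dict subclass) converts a missing key into an AttributeError. Depending on the solver method and scipy version, the result may simply not contain an 'njev' entry. The sketch below reproduces that behaviour without running an optimizer and shows a defensive lookup. The helper num_jac_eval is hypothetical and is not jaxopt's actual patch.

```python
from scipy.optimize import OptimizeResult


def num_jac_eval(res, default=0):
    """Return res.njev if the result has it, else `default`.

    Hypothetical helper for illustration only. getattr with a default
    works because OptimizeResult raises AttributeError for missing keys.
    """
    return getattr(res, "njev", default)


# A result without an 'njev' entry (assumed shape, built by hand).
res = OptimizeResult(x=[1.0, 2.0], fun=0.5, nfev=12, nit=3)

try:
    # Same attribute access that fails in jaxopt 0.8.2's _run per the traceback.
    _ = res.njev
except AttributeError as err:
    print("AttributeError:", err)

# The defensive lookup falls back to the default instead of raising.
print("num_jac_eval:", num_jac_eval(res))
```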
@HajimeKawahara HajimeKawahara added the bug Something isn't working label Aug 20, 2024
@kemasuda
Owner

I ran transit_test.py and couldn't reproduce this error in my environment: python==3.10.12, jaxopt==0.8.3, jax==0.4.23, scipy==1.12.0

  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 223, in __getattr__
    return self[name]
KeyError: 'njev'

Could this be a scipy issue?

@HajimeKawahara
Collaborator Author

HajimeKawahara commented Aug 20, 2024

I see, thanks. I was using scipy==1.11.2.
I also tried scipy==1.14.0 and got the same error.
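Since the outcome depended on the scipy version, which the first report did not list, a small script that prints all relevant versions in the same "name==version" style can make such reports complete. This is a sketch using only the standard library; the function name environment_line is hypothetical.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def environment_line(packages=("jaxopt", "jax", "scipy")) -> str:
    """Format versions like 'python==3.10.9, jaxopt==0.8.2, ...'."""
    parts = [f"python=={platform.python_version()}"]
    for name in packages:
        try:
            parts.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            # Report missing packages explicitly instead of failing.
            parts.append(f"{name}==<not installed>")
    return ", ".join(parts)


if __name__ == "__main__":
    print("environment:", environment_line())
```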

@kemasuda
Owner

The code worked with scipy==1.14.0 for me, so this may already be fixed in jaxopt==0.8.3 (google/jaxopt#542). Could you try that version? That change was made to fix this error: google/jaxopt#536

@HajimeKawahara
Collaborator Author

@kemasuda Thanks! jaxopt==0.8.3 solves this issue. I will add jaxopt>=0.8.3 to the requirements.
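Besides the requirement pin, a runtime guard can turn the obscure njev AttributeError into a clear message when an older jaxopt is installed. This is a minimal sketch assuming a check at import time is acceptable; the names MIN_JAXOPT, parse_release, jaxopt_is_new_enough and check_jaxopt are hypothetical, and pre-release suffixes are ignored as a simplification.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# First release reported in this thread to fix the njev error.
MIN_JAXOPT = (0, 8, 3)


def parse_release(ver: str) -> tuple:
    """Extract the leading numeric release, e.g. '0.8.3.dev1' -> (0, 8, 3)."""
    m = re.match(r"(\d+(?:\.\d+)*)", ver)
    if m is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(p) for p in m.group(1).split("."))


def jaxopt_is_new_enough(ver: str, minimum: tuple = MIN_JAXOPT) -> bool:
    # Tuple comparison: (0, 8, 2) < (0, 8, 3) <= (0, 10, 0).
    return parse_release(ver) >= minimum


def check_jaxopt() -> None:
    """Raise ImportError early if the installed jaxopt predates the fix."""
    try:
        installed = version("jaxopt")
    except PackageNotFoundError as err:
        raise ImportError("jaxopt is not installed") from err
    if not jaxopt_is_new_enough(installed):
        required = ".".join(map(str, MIN_JAXOPT))
        raise ImportError(f"jaxopt {installed} found; jaxopt>={required} is required")
```

The declarative equivalent is the line jaxopt>=0.8.3 in the package requirements; the runtime guard only adds a clearer error for environments that bypass them.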


2 participants