ScipyBoundedMinimize and syntax of params with args in a pytree #478
Sorry for the late reply. Since you edited the question, I assume you managed to resolve your issue by yourself? Most solvers in JAXopt don't work out of the box with scalars (maybe they should?), so typically one needs to use an array of size 1 instead.

import jax
import jax.numpy as jnp
from jaxopt import ScipyBoundedMinimize

# Function to minimize by optimizing w, while x, y, z are kept as precomputed constants
def example_function(w, pytree):
    x, y, z = pytree
    return w[0] ** 2 + x * y + jnp.sin(z)  # fake function

# Initial pytree values (excluding w)
initial_pytree = (2.0, 3.0, 4.0)

# Bounds are a (lower, upper) tuple with the same structure as w; here w >= 0.0
bounds = (jnp.array([0.0]), jnp.array([jnp.inf]))

# Wrapper that squares the objective; the pytree is passed through as an extra argument
def new_function(w, pytree):
    return example_function(w, pytree) ** 2

# Initial guess for w, as an array of size 1 rather than a scalar
w_initial = jnp.array([1.0])

# Create the ScipyBoundedMinimize object (bounds are passed to run, not to the constructor)
optimizer = ScipyBoundedMinimize(fun=new_function, method="l-bfgs-b")

# Perform the optimization; arguments after bounds are forwarded to new_function
opt_result = optimizer.run(w_initial, bounds, initial_pytree)