Replies: 2 comments
-
Here is a runnable example that is somewhat robust and generic, using jax.tree_transpose:

import jax
import jax.numpy as jnp
import optax
# optimizer whose per-param state (mu/nu) has the same tree structure as the params
def optimizer(learning_rate, decay=0.9):
    return optax.chain(
        optax.scale_by_stddev(decay=decay),
        optax.scale(-learning_rate)
    )
params = {'a': jnp.ones((3, 3)), 'b': {'c': jnp.ones((4, 4)), 'd': jnp.ones((5, 5))}}
tx = optimizer(0.1, decay=0.9)
opt_state = tx.init(params)
# example "model surgery"
new_params = jax.tree_map(lambda x: x, params) # params we will do surgery on
opt_treedef = jax.tree_structure(tx.init(jnp.ones(1))) # the opt_state structure by itself (init on a single dummy param)
param_treedef = jax.tree_structure(params) # only param structure
transposed_state = jax.tree_transpose(opt_treedef, param_treedef, opt_state)
# the actual surgery can then be applied symmetrically to the params and the transposed opt_state, which keeps it from being brittle
new_params['d'] = new_params['b'].pop('d')
transposed_state['d'] = transposed_state['b'].pop('d')
new_param_treedef = jax.tree_structure(new_params)
new_opt_state = jax.tree_transpose(new_param_treedef, opt_treedef, transposed_state)
# test updating with new structure
updates = jax.tree_map(jnp.zeros_like, params)
new_updates = jax.tree_map(jnp.zeros_like, new_params)
tx.update(updates, opt_state, params)
tx.update(new_updates, new_opt_state, new_params)

Unfortunately, this does not work with optimizers like optax.adam, whose state contains a scalar count that does not mirror the param structure. Any ideas on how to account for the count?
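To make the failure concrete, here is a minimal sketch with optax.adam; the assumption is that the scalar count in its state is what trips up jax.tree_transpose, since it is not a params-shaped sub-tree (same jax.tree_* aliases as above; newer JAX moves these under jax.tree_util):

import jax
import jax.numpy as jnp
import optax

params = {'a': jnp.ones((3, 3)), 'b': {'c': jnp.ones((4, 4)), 'd': jnp.ones((5, 5))}}
tx = optax.adam(0.1)
opt_state = tx.init(params)

opt_treedef = jax.tree_structure(tx.init(jnp.ones(1)))
param_treedef = jax.tree_structure(params)

# adam's state has 1 (count) + 3 (mu) + 3 (nu) = 7 leaves, while tree_transpose
# expects 3 outer * 3 inner = 9, so it raises a structure-mismatch error
try:
    jax.tree_transpose(opt_treedef, param_treedef, opt_state)
except (TypeError, ValueError) as e:
    print('tree_transpose rejects the adam state:', e)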
-
Here's another simple idea for how this could be done. This solution works with arbitrary tree structures, including optax.adam's state. Note: the code below is quite inefficient because it calls jax.tree_structure on every sub-tree it visits.

import jax
import jax.numpy as jnp
import optax
params = {'a': jnp.ones((3, 3)), 'b': {'c': jnp.ones((4, 4)), 'd': jnp.ones((5, 5))}}
tx = optax.adam(0.1)
opt_state = tx.init(params)
opt_state  # inspect the original optimizer state

# the surgery: move 'd' out of 'b' to the top level of the tree
def surgery(d):
    ret = jax.tree_map(lambda x: x, d)  # copy, so the original tree is untouched
    ret['d'] = ret['b'].pop('d')
    return ret

# is_leaf predicate that stops flattening at real leaves or at any sub-tree
# whose structure matches the param tree
def leaf_or_treedef(target_treedef):
    def is_leaf(x):
        treedef = jax.tree_structure(x)
        return jax.treedef_is_leaf(treedef) or treedef == target_treedef
    return is_leaf
param_treedef = jax.tree_structure(params)
# flatten the opt_state, treating any params-shaped sub-tree as a single "leaf"
opt_state_values, opt_state_treedef = jax.tree_flatten(
    opt_state,
    is_leaf=leaf_or_treedef(param_treedef),
)
# apply the surgery only to the params-shaped sub-trees (mu, nu, ...);
# other leaves such as adam's count pass through unchanged
modified_opt_state_values = tuple(
    surgery(value) if jax.tree_structure(value) == param_treedef else value
    for value in opt_state_values
)
modified_opt_state = jax.tree_unflatten(
    opt_state_treedef, modified_opt_state_values)
modified_opt_state  # inspect the modified optimizer state
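As a quick sanity check (a sketch that assumes the params get the same surgery applied), the modified state should accept updates for the modified params:

new_params = surgery(params)
updates = jax.tree_map(jnp.zeros_like, new_params)
tx.update(updates, modified_opt_state, new_params)  # no tree-structure mismatch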
-
Let's assume I have some model and want to manually modify the parameter tree - as described in
Flax's "Model Surgery" HOWTO.
Then I have some pytree of params that I can modify manually (e.g. removing or renaming weights), for instance:
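For concreteness, a toy stand-in for such a modification (the shapes and the "move 'd' out of 'b'" surgery below are just placeholders):

import jax
import jax.numpy as jnp

params = {'a': jnp.ones((3, 3)), 'b': {'c': jnp.ones((4, 4)), 'd': jnp.ones((5, 5))}}

# example surgery: move the sub-tree 'd' from under 'b' to the top level
new_params = jax.tree_map(lambda x: x, params)  # copy the containers
new_params['d'] = new_params['b'].pop('d')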
How do I best update the corresponding pytrees inside an Optax optimizer state?
I could do the following, but that feels verbose and brittle:
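Something along these lines, continuing from the params above and assuming optax.adam for concreteness (its state here is a (ScaleByAdamState, ScaleState) pair, though the exact layout depends on the optax version):

import optax

tx = optax.adam(0.1)
opt_state = tx.init(params)

# reach into the state namedtuples by hand and apply the same move to mu and nu
adam_state, scale_state = opt_state
new_mu = jax.tree_map(lambda x: x, adam_state.mu)
new_mu['d'] = new_mu['b'].pop('d')
new_nu = jax.tree_map(lambda x: x, adam_state.nu)
new_nu['d'] = new_nu['b'].pop('d')
new_opt_state = (adam_state._replace(mu=new_mu, nu=new_nu), scale_state)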