
How to mark parameters as trainable or not? #866

Open
tomsch420 opened this issue Sep 28, 2024 · 5 comments
Labels: question (User queries)

Comments

@tomsch420

Greetings!

I have custom layers in Equinox that look approximately like this:

from abc import ABC
from typing import List, Union

import jax
from jax.experimental.sparse import BCOO


class ProductLayer(InnerLayer):

    child_layers: List[Union[SumLayer, InputLayer]]
    edges: BCOO

class SumLayer(InnerLayer):

    log_weights: List[BCOO]
    child_layers: Union[List[ProductLayer], List[InputLayer]]

class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC):
    interval: jax.Array

I now want to exclude ProductLayer.edges from the parameters of a model, since it cannot be adjusted by gradient descent.
Furthermore, SumLayer.log_weights.indices cannot be adjusted either, and neither can ContinuousLayerWithFiniteSupport.interval. How can I best filter these out for the eqx.partition method?
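(For reference, one common Equinox pattern for this kind of thing is a per-leaf filter spec passed to eqx.partition. A minimal sketch with a toy module; Toy, weight and edges here are illustrative stand-ins for the layers above, not the actual model:)

import equinox as eqx
import jax
import jax.numpy as jnp


class Toy(eqx.Module):
    weight: jax.Array
    edges: jax.Array  # stand-in for a non-trainable field

    def __call__(self, x):
        return self.weight * x + self.edges


model = Toy(jnp.ones(3), jnp.arange(3.0))

# Mark every leaf trainable, then flip the non-trainable ones to False.
filter_spec = jax.tree_util.tree_map(lambda _: True, model)
filter_spec = eqx.tree_at(lambda m: m.edges, filter_spec, replace=False)

# Only the True leaves end up in `params`; everything else is static.
params, static = eqx.partition(model, filter_spec)


@eqx.filter_value_and_grad
def loss(params, static, x):
    model = eqx.combine(params, static)
    return jnp.sum(model(x))


value, grads = loss(params, static, jnp.ones(3))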

@patrick-kidger (Owner)

patrick-kidger added the question (User queries) label on Sep 29, 2024
@danielward27 (Contributor)

There is a risk to the suggested approach that should at least be highlighted in the docs: the parameters may still be penalized by regularization, even when their gradients are stopped.

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
from optax import adamw


class Model(eqx.Module):
    buffer: Array
    param: Array

    def __call__(self, x):
        return self.param * x + jax.lax.stop_gradient(self.buffer)


@eqx.filter_value_and_grad
def loss(model, x):
    return model(x)


model = Model(jnp.ones(()), jnp.ones(()))
loss_value, grad = loss(model, 2)
optimizer = adamw(1e-1)  # optimizer with weight-decay regularization
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
updates, opt_state = optimizer.update(grad, opt_state, eqx.filter(model, eqx.is_array))
model = eqx.apply_updates(model, updates)
# stop_gradient zeroes buffer's gradient, but adamw's weight decay still shrinks it.
assert model.buffer == jnp.ones(())  # Fails!

Unless I am missing a downside, the approach I think should be recommended is to wrap non-trainable nodes in a wrapper class (NonTrainable) and partition parameters, e.g. with:

params, static = eqx.partition(
    model,
    eqx.is_inexact_array,
    is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
)
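(For concreteness, NonTrainable is not defined in this thread and is not part of Equinox at this point; a minimal sketch of what such a wrapper might look like, with an illustrative name and field:)

import equinox as eqx
from jaxtyping import Array


class NonTrainable(eqx.Module):
    # Wraps a leaf so it can be recognised, and kept static, at partition time.
    value: Array

A model would then store e.g. edges: NonTrainable and read self.edges.value in its forward pass.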

@patrick-kidger (Owner)

Ah! That really isn't very good, you're right.

Hmm, I'm trying to figure out if there's a way to handle this ergonomically. The best I can come up with is to wrap the Optax calls (like we already do for eqx.apply_updates) with something that respects such a NonTrainable wrapper. This is just such an easy footgun!

@dlwh (Contributor)

dlwh commented Oct 22, 2024

FWIW I've landed on the optax wrapper approach. I have a trainable/non_trainable mask that I create early on and partition that way. I don't even bother with stop_grad most of the time and pray that XLA does the dead-code elimination (DCE) for me (it seems to).

For things that are really constants (e.g. rotary embeddings) I just materialize those in the kernel with ensure_compile_time_eval.
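(For reference, one way to express such a trainable/non-trainable mask with plain Optax is optax.masked, which skips the wrapped transform, weight decay included, wherever the mask is False. A rough sketch reusing the Model instance from the earlier snippet, not necessarily the exact setup described above:)

import equinox as eqx
import jax
import optax

params = eqx.filter(model, eqx.is_inexact_array)

# True = trainable, False = frozen; same structure as `params`.
trainable_mask = jax.tree_util.tree_map(lambda _: True, params)
trainable_mask = eqx.tree_at(lambda m: m.buffer, trainable_mask, replace=False)

# adamw (and its weight decay) only touches leaves where the mask is True.
optimizer = optax.masked(optax.adamw(1e-1), trainable_mask)
opt_state = optimizer.init(params)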

@patrick-kidger (Owner)

Ah, nice! Okay, I think I'm convinced.

I'd be happy to take a PR implementing this, then.
