Added support for sharing layers #476

Merged 1 commit on Sep 13, 2023

patrick-kidger commented Sep 7, 2023

Added support for sharing layers between different parts of a model.

This has the API eqx.nn.Shared(pytree, where, get). At init time, everything in where(pytree) is deleted: these are the duplicate nodes we want to remove. At call time, everything in get(pytree) is inserted in their place: these are the nodes we're keeping, and in doing so we restore layers that we can evaluate.
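To make the delete-at-init, reinsert-at-call idea concrete, here is a toy sketch in plain Python. It is not how Equinox implements this: `ToyShared`, `REMOVED`, `get_at` and `set_at` are hypothetical names, nested dicts stand in for a pytree, and `where` is simplified to a tuple of dict keys rather than a function.

```python
import copy

REMOVED = "<removed>"  # placeholder marking a deleted duplicate node


def get_at(tree, path):
    """Follow a tuple of dict keys down into a nested dict."""
    for key in path:
        tree = tree[key]
    return tree


def set_at(tree, path, value):
    """Overwrite the entry at `path` in a nested dict, in place."""
    get_at(tree, path[:-1])[path[-1]] = value


class ToyShared:
    """Toy stand-in for the idea only; not the Equinox implementation."""

    def __init__(self, tree, where, get):
        self.tree = copy.deepcopy(tree)
        self.where = where  # simplification: a key path, not a function
        self.get = get      # function selecting the node we keep
        # Init time: delete the duplicate node.
        set_at(self.tree, where, REMOVED)

    def __call__(self):
        # Call time: rebuild a full tree by inserting the kept node
        # where the duplicate used to be.
        restored = copy.deepcopy(self.tree)
        set_at(restored, self.where, self.get(restored))
        return restored


model = {
    "embedding": {"weight": [[1.0, 2.0], [3.0, 4.0]]},
    "linear": {"weight": [[0.0, 0.0], [0.0, 0.0]]},
}
shared = ToyShared(
    model,
    where=("linear", "weight"),
    get=lambda tree: tree["embedding"]["weight"],
)
restored = shared()
```

Only one copy of the weight is stored inside `shared`; after the call, both entries of `restored` refer to the same object.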

Example usage:

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Int

class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self):
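        embedding = eqx.nn.Embedding(...)
        linear = eqx.nn.Linear(...)
        # `where` selects the duplicate node to delete at init time.
        where = lambda pair: pair[1].weight
        # `get` selects the node to keep, which is inserted at call time.
        get = lambda pair: pair[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)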

    def __call__(self, tokens: Int[Array, "sequence"]):
        # Restore both layers, now using the same weight.
        embedding, linear = self.shared()
        values = jax.vmap(embedding)(tokens)
        ...  # the rest of your language model goes here!
        return jax.vmap(linear)(values)

If we need to perform any kind of transform, such as transposition, then that can be done just by placing it in get, e.g. get = lambda pair: jnp.transpose(pair[0].weight).

patrick-kidger changed the base branch from main to dev September 7, 2023 13:25
patrick-kidger marked this pull request as draft September 7, 2023 13:25
patrick-kidger changed the title from "shared" to "Added support for sharing layers" Sep 7, 2023
patrick-kidger marked this pull request as ready for review September 12, 2023 22:58