Skip to content

Commit

Permalink
Embedding now supports initialisation with just a weight.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 5, 2023
1 parent a234840 commit 5e39522
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
47 changes: 35 additions & 12 deletions equinox/nn/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax
import jax.random as jrandom
from jaxtyping import Array, Float, PRNGKeyArray
from jaxtyping import Array, Float, Int, PRNGKeyArray

from .._module import field, Module

Expand All @@ -16,28 +16,49 @@ class Embedding(Module):

def __init__(
self,
num_embeddings: int,
embedding_size: int,
num_embeddings: Optional[int] = None, # pyright: ignore
embedding_size: Optional[int] = None, # pyright: ignore
weight: Optional[Float[Array, "num_embeddings embedding_size"]] = None,
*,
key: PRNGKeyArray,
key: Optional[PRNGKeyArray] = None,
**kwargs,
):
"""**Arguments:**
`Embedding` should be initialised with either:
- `num_embeddings`: Size of embedding dictionary. Must be non-negative.
- `embedding_size`: Size of each embedding vector. Must be non-negative.
- `weight`: If given, the embedding lookup table. Will be generated randomly
if not provided.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation. (Keyword only argument.)
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation
of the embedding lookup table. (Keyword only argument.)
Or:
- `weight`: The embedding lookup table, of shape
`(num_embeddings, embedding_size)`.
"""
super().__init__(**kwargs)
if weight is None:
assert num_embeddings >= 0, "num_embeddings must not be negative."
assert embedding_size >= 0, "embedding_size must not be negative."
if num_embeddings is None or embedding_size is None or key is None:
raise ValueError(
"Must provide `eqx.nn.Embedding(num_embeddings=..., "
"embedding_size=..., key=...)` if not providing the weight "
"directly."
)
if num_embeddings < 0:
raise ValueError("num_embeddings must not be negative.")
if embedding_size < 0:
raise ValueError("embedding_size must not be negative.")
self.weight = jrandom.normal(key, (num_embeddings, embedding_size))
else:
if weight.ndim != 2:
raise ValueError(
"weight must have shape (num_embeddings, embedding_size)."
)
if num_embeddings is None:
num_embeddings: int = weight.shape[0]
if embedding_size is None:
embedding_size: int = weight.shape[1]
if weight.shape != (num_embeddings, embedding_size):
raise ValueError(
"weight must have shape (num_embeddings, embedding_size)."
Expand All @@ -47,10 +68,12 @@ def __init__(
self.embedding_size = embedding_size

@jax.named_scope("eqx.nn.Embedding")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
def __call__(
self, x: Int[Array, ""], *, key: Optional[PRNGKeyArray] = None
) -> Array:
"""**Arguments:**
- `x`: The table index.
- `x`: The table index. Should be a scalar integer array.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,10 @@ def test_embedding(getkey):
x = jnp.array([-1])
assert jnp.allclose(emb(x), jnp.linspace(9.1, 10.0, 10))

emb = eqx.nn.Embedding(weight=jnp.linspace(0.1, 10, 100).reshape(10, 10))
x = jnp.array([-1])
assert jnp.allclose(emb(x), jnp.linspace(9.1, 10.0, 10))


def test_layer_norm(getkey):
ln = eqx.nn.LayerNorm(128)
Expand Down

0 comments on commit 5e39522

Please sign in to comment.