-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
62 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
import equinox as eqx | ||
from typing import * | ||
|
||
class Neurons(eqx.Module): | ||
"""Neurons represent dynamic variables in the HAM that are evolved during inference (i.e., memory retrieval/error correction) | ||
They have an evolving state (created using the `.init` function) that is stored outside the neuron layer itself | ||
""" | ||
lagrangian: Callable | ||
shape: Tuple[int] | ||
|
||
def __init__( | ||
self, lagrangian: Union[Callable, eqx.Module], shape: Union[int, Tuple[int]] | ||
): | ||
super().__init__() | ||
self.lagrangian = lagrangian | ||
if isinstance(shape, int): | ||
shape = (shape,) | ||
self.shape = shape | ||
|
||
def activations(self, x: jax.Array) -> jax.Array: | ||
return jax.grad(self.lagrangian)(x) | ||
|
||
def g(self, x: jax.Array) -> jax.Array: | ||
return self.activations(x) | ||
|
||
def energy(self, g: jax.Array, x: jax.Array) -> jax.Array: | ||
"""Assume vectorized""" | ||
return jnp.multiply(g, x).sum() - self.lagrangian(x) | ||
|
||
def init(self, bs: Optional[int] = None) -> jax.Array: | ||
"""Return an empty state of the correct shape""" | ||
if bs is None or bs == 0: | ||
return jnp.zeros(self.shape) | ||
return jnp.zeros((bs, *self.shape)) | ||
|
||
def __repr__(self: jax.Array): | ||
return f"Neurons(lagrangian={self.lagrangian}, shape={self.shape})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from neurons import Neurons | ||
from lagrangians import lagr_softmax | ||
import jax.numpy as jnp | ||
import jax | ||
|
||
neuron_shape = (5,) | ||
beta = 3. | ||
neuron = Neurons(lagrangian=lambda x: lagr_softmax(x, beta=beta), shape=neuron_shape) | ||
act_fn = lambda x: jax.nn.softmax(beta * x) | ||
|
||
def test_init(): | ||
assert neuron.init().shape == neuron_shape | ||
assert neuron.init(bs=3).shape == (3, *neuron_shape) | ||
|
||
def test_activations(): | ||
x = neuron.init() | ||
assert jnp.all(neuron.activations(x) == neuron.g(x)) | ||
assert jnp.allclose(act_fn(x), neuron.g(x)) |