
Commit

Create neurons
bhoov committed Dec 14, 2023
1 parent f3302fb commit 06ae800
Showing 3 changed files with 62 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -29,4 +29,7 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
pythonpath = [
"src"
]

[tool.ruff]
indent-width = 2
40 changes: 40 additions & 0 deletions src/neurons.py
@@ -0,0 +1,40 @@
import jax
import jax.numpy as jnp
import equinox as eqx
from typing import Callable, Optional, Tuple, Union

class Neurons(eqx.Module):
"""Neurons represent dynamic variables in the HAM that are evolved during inference (i.e., memory retrieval/error correction)
They have an evolving state (created using the `.init` function) that is stored outside the neuron layer itself
"""
lagrangian: Callable
  shape: Tuple[int, ...]

def __init__(
    self, lagrangian: Union[Callable, eqx.Module], shape: Union[int, Tuple[int, ...]]
):
super().__init__()
self.lagrangian = lagrangian
if isinstance(shape, int):
shape = (shape,)
self.shape = shape

def activations(self, x: jax.Array) -> jax.Array:
return jax.grad(self.lagrangian)(x)

def g(self, x: jax.Array) -> jax.Array:
return self.activations(x)

def energy(self, g: jax.Array, x: jax.Array) -> jax.Array:
"""Assume vectorized"""
return jnp.multiply(g, x).sum() - self.lagrangian(x)

def init(self, bs: Optional[int] = None) -> jax.Array:
"""Return an empty state of the correct shape"""
if bs is None or bs == 0:
return jnp.zeros(self.shape)
return jnp.zeros((bs, *self.shape))

  def __repr__(self):
return f"Neurons(lagrangian={self.lagrangian}, shape={self.shape})"
18 changes: 18 additions & 0 deletions tests/test_neurons.py
@@ -0,0 +1,18 @@
from neurons import Neurons
from lagrangians import lagr_softmax
import jax.numpy as jnp
import jax

neuron_shape = (5,)
beta = 3.0
neuron = Neurons(lagrangian=lambda x: lagr_softmax(x, beta=beta), shape=neuron_shape)
act_fn = lambda x: jax.nn.softmax(beta * x)

def test_init():
assert neuron.init().shape == neuron_shape
assert neuron.init(bs=3).shape == (3, *neuron_shape)

def test_activations():
x = neuron.init()
assert jnp.all(neuron.activations(x) == neuron.g(x))
assert jnp.allclose(act_fn(x), neuron.g(x))
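
A sketch of an additional check one might add (not in this commit): by definition, the energy should equal <g, x> - L(x) for the layer's Lagrangian.

def test_energy():
  # Sketch, not in the commit: energy is <g, x> - L(x) by construction
  x = jnp.arange(5, dtype=jnp.float32)
  g = neuron.g(x)
  expected = jnp.dot(g, x) - lagr_softmax(x, beta=beta)
  assert jnp.allclose(neuron.energy(g, x), expected)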
