From 06ae800761e8943d96adc4751c9a90cdf3d6ce3a Mon Sep 17 00:00:00 2001
From: Ben Hoover
Date: Thu, 14 Dec 2023 11:30:06 -0600
Subject: [PATCH] Create neurons

---
 pyproject.toml        |  5 ++++-
 src/neurons.py        | 40 ++++++++++++++++++++++++++++++++++++++++
 tests/test_neurons.py | 18 ++++++++++++++++++
 3 files changed, 62 insertions(+), 1 deletion(-)
 create mode 100644 src/neurons.py
 create mode 100644 tests/test_neurons.py

diff --git a/pyproject.toml b/pyproject.toml
index f18720b..a42245e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,4 +29,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.pytest.ini_options]
 pythonpath = [
   "src"
-]
\ No newline at end of file
+]
+
+[tool.ruff]
+indent-width = 2
diff --git a/src/neurons.py b/src/neurons.py
new file mode 100644
index 0000000..a210ccf
--- /dev/null
+++ b/src/neurons.py
@@ -0,0 +1,40 @@
+import jax
+import jax.numpy as jnp
+import equinox as eqx
+from typing import Callable, Optional, Tuple, Union
+
+class Neurons(eqx.Module):
+  """Neurons represent the dynamic variables in a HAM that are evolved during inference (i.e., memory retrieval/error correction).
+
+  The evolving state (created with the `.init` method) is stored outside the neuron layer itself.
+  """
+  lagrangian: Callable
+  shape: Tuple[int, ...]
+
+  def __init__(
+    self, lagrangian: Union[Callable, eqx.Module], shape: Union[int, Tuple[int, ...]]
+  ):
+    super().__init__()
+    self.lagrangian = lagrangian
+    if isinstance(shape, int):
+      shape = (shape,)
+    self.shape = shape
+
+  def activations(self, x: jax.Array) -> jax.Array:
+    return jax.grad(self.lagrangian)(x)
+
+  def g(self, x: jax.Array) -> jax.Array:
+    return self.activations(x)
+
+  def energy(self, g: jax.Array, x: jax.Array) -> jax.Array:
+    """Layer energy <g, x> - L(x). Assumes `g` and `x` are unbatched (vectorized) arrays."""
+    return jnp.multiply(g, x).sum() - self.lagrangian(x)
+
+  def init(self, bs: Optional[int] = None) -> jax.Array:
+    """Return an empty state of the correct shape."""
+    if bs is None or bs == 0:
+      return jnp.zeros(self.shape)
+    return jnp.zeros((bs, *self.shape))
+
+  def __repr__(self) -> str:
+    return f"Neurons(lagrangian={self.lagrangian}, shape={self.shape})"
diff --git a/tests/test_neurons.py b/tests/test_neurons.py
new file mode 100644
index 0000000..63662e8
--- /dev/null
+++ b/tests/test_neurons.py
@@ -0,0 +1,18 @@
+from neurons import Neurons
+from lagrangians import lagr_softmax
+import jax.numpy as jnp
+import jax
+
+neuron_shape = (5,)
+beta = 3.0
+neuron = Neurons(lagrangian=lambda x: lagr_softmax(x, beta=beta), shape=neuron_shape)
+act_fn = lambda x: jax.nn.softmax(beta * x)
+
+def test_init():
+  assert neuron.init().shape == neuron_shape
+  assert neuron.init(bs=3).shape == (3, *neuron_shape)
+
+def test_activations():
+  x = neuron.init()
+  assert jnp.all(neuron.activations(x) == neuron.g(x))
+  assert jnp.allclose(act_fn(x), neuron.g(x))
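
Usage sketch (not part of the patch): a minimal example of how the new `Neurons` layer might be driven. It assumes `src` is on the pythonpath (as configured in pyproject.toml above) and inlines a softmax Lagrangian in place of the repo's `lagr_softmax`; since the gradient of (1/beta) * logsumexp(beta * x) is softmax(beta * x), `g` recovers the softmax activation.

import jax
import jax.numpy as jnp
from neurons import Neurons

beta = 3.0

# Softmax Lagrangian: L(x) = (1 / beta) * logsumexp(beta * x).
# Its gradient is softmax(beta * x), mirroring `lagr_softmax` in the tests.
lagr = lambda x: jax.nn.logsumexp(beta * x) / beta

neuron = Neurons(lagrangian=lagr, shape=5)  # an int shape is promoted to (5,)

x = neuron.init()        # zero state of shape (5,)
g = neuron.g(x)          # softmax(beta * x), computed via jax.grad
e = neuron.energy(g, x)  # <g, x> - L(x); at x = 0 this is -log(5) / beta

print(g.shape, float(e))  # (5,) -0.536...

Because `energy` assumes unbatched inputs, a batched state produced by `init(bs=...)` would be handled by mapping over the leading axis with `jax.vmap`.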