Unofficial implementation of SIREN neural networks in Flax, using the Linen Module system.
This repo also includes an implementation of "Modulated Periodic Activations for Generalizable Local Functional Representations".
An image-fitting example is provided in the Example notebook.
`SirenLayer` is a fully connected layer with a sinusoidal activation function, initialized according to the original SIREN paper:
```python
layer = SirenLayer(
    features=32,
    w0=1.0,
    c=6.0,
    is_first=False,
    use_bias=True,
    act=jnp.sin,
    precision=None,
    dtype=jnp.float32,
)
```
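For intuition, the `w0`, `c`, and `is_first` arguments control the weight-initialization scheme from the SIREN paper. A minimal sketch of that scheme in plain JAX (this is an illustration of the paper's rule, not this repo's internal code):

```python
import jax
import jax.numpy as jnp

def siren_init(key, fan_in, fan_out, w0=1.0, c=6.0, is_first=False):
    # Per the SIREN paper: the first layer is drawn from U(-1/fan_in, 1/fan_in);
    # later layers from U(-sqrt(c/fan_in)/w0, sqrt(c/fan_in)/w0), keeping
    # pre-activations well distributed for the sine nonlinearity.
    bound = 1.0 / fan_in if is_first else jnp.sqrt(c / fan_in) / w0
    return jax.random.uniform(key, (fan_in, fan_out), minval=-bound, maxval=bound)

# First-layer kernel for 2-D coordinate inputs and 32 features.
kernel = siren_init(jax.random.PRNGKey(0), fan_in=2, fan_out=32, is_first=True)
```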
```python
SirenNN = Siren(hidden_dim=512, output_dim=1, final_activation=sigmoid)
params = SirenNN.init(random_key, sample_input)["params"]
output = SirenNN.apply({"params": params}, sample_input)
```
This can be done easily using the built-in broadcasting of jax.numpy
functions. This repository also provides a useful initializer, grid_init,
to generate a coordinate grid that can be used as input.
```python
SirenDef = Siren(num_layers=5)
grid = grid_init(grid_dimension, jnp.float32)()
params = SirenDef.init(key, grid)["params"]
image = SirenDef.apply({"params": params}, grid)
```
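For intuition, a coordinate grid of the kind `grid_init` returns can be sketched with `jnp.meshgrid`; the sketch below assumes pixel coordinates normalized to [-1, 1] and a flattened `(H*W, 2)` output shape, which may differ from this repo's conventions:

```python
import jax.numpy as jnp

def make_grid(height, width):
    # One normalized (x, y) coordinate pair per pixel, in [-1, 1].
    ys = jnp.linspace(-1.0, 1.0, height)
    xs = jnp.linspace(-1.0, 1.0, width)
    grid_y, grid_x = jnp.meshgrid(ys, xs, indexing="ij")
    return jnp.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)

grid = make_grid(64, 64)  # shape (4096, 2)
```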
```python
SirenDef = ModulatedSiren(num_layers=5)
grid = grid_init(grid_dimension, jnp.float32)()
params = SirenDef.init(key, grid)["params"]
image = SirenDef.apply({"params": params}, grid)
```
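For intuition, the modulation idea from "Modulated Periodic Activations for Generalizable Local Functional Representations" scales each hidden layer's sinusoidal activations elementwise by a vector derived from a latent code. A minimal single-layer sketch in plain JAX (the function name, shapes, and latent handling here are illustrative assumptions, not this repo's API):

```python
import jax
import jax.numpy as jnp

def modulated_layer(x, w, b, alpha, w0=30.0):
    # Sinusoidal activation, scaled elementwise by a latent-derived
    # modulation vector alpha. alpha == 1 recovers a plain SIREN layer.
    return alpha * jnp.sin(w0 * (x @ w + b))

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
x = jax.random.normal(k1, (4096, 2))      # flattened coordinate grid
w = jax.random.normal(k2, (2, 32)) * 0.1  # toy weights for illustration
b = jnp.zeros(32)
alpha = jnp.ones(32)                      # identity modulation
out = modulated_layer(x, w, b, alpha)     # shape (4096, 32)
```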