SymbolicNeuralNetworks.jl


In a perfect world we probably would not need SymbolicNeuralNetworks. Its motivation mainly comes from Zygote's inability to handle second-order derivatives reliably [1]. If Enzyme matures further, SymbolicNeuralNetworks may no longer be needed. For now (December 2024), SymbolicNeuralNetworks offers a good way to incorporate derivatives into the loss function.

SymbolicNeuralNetworks was created to take advantage of Symbolics for training neural networks by accelerating their evaluation and by simplifying the computation of arbitrary derivatives of the neural network. This package is based on AbstractNeuralNetworks and can be used together with GeometricMachineLearning.

SymbolicNeuralNetworks creates a symbolic expression of the neural network, computes arbitrary combinations of derivatives and uses RuntimeGeneratedFunctions to compile a Julia function.
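The general idea can be illustrated with Symbolics alone. The following is a minimal sketch, not the internals of SymbolicNeuralNetworks: the one-neuron expression `nn_expr` and the names `d1_fn` and `d2_fn` are hypothetical and chosen only for illustration. It builds a symbolic expression, differentiates it twice, and compiles the results into ordinary Julia functions:

```julia
using Symbolics

# Hypothetical one-neuron "network": x is the input, w and b are parameters.
@variables x w b
nn_expr = tanh(w * x + b)

# Differentiate symbolically, as often as needed.
d1 = Symbolics.derivative(nn_expr, x)
d2 = Symbolics.derivative(d1, x)

# Compile the symbolic expressions into callable Julia functions.
# `expression = Val{false}` asks for a callable instead of an `Expr`.
d1_fn = build_function(d1, x, w, b; expression = Val{false})
d2_fn = build_function(d2, x, w, b; expression = Val{false})

# Evaluate the first and second derivative at x = 0.3, w = 2.0, b = 0.1.
d1_fn(0.3, 2.0, 0.1)
d2_fn(0.3, 2.0, 0.1)
```

Because the derivatives are computed once at the symbolic level, higher-order derivatives come at no extra conceptual cost; the compiled functions are then evaluated like any other Julia function.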

To create a symbolic neural network, we first design a model with AbstractNeuralNetworks:

using AbstractNeuralNetworks

c = Chain(Dense(2, 2, tanh), Linear(2, 1))

We now call SymbolicNeuralNetwork:

using SymbolicNeuralNetworks

nn = SymbolicNeuralNetwork(c)

Example

We now train the neural network by using SymbolicPullback [2]:

pb = SymbolicPullback(nn)

using GeometricMachineLearning

# we generate the data and process them with `GeometricMachineLearning.DataLoader`
x_vec = -1.:.1:1.
y_vec = -1.:.1:1.
xy_data = hcat([[x, y] for x in x_vec, y in y_vec]...)
f(x::Vector) = exp.(-sum(x.^2))
z_data = mapreduce(i -> f(xy_data[:, i]), hcat, axes(xy_data, 2))

dl = DataLoader(xy_data, z_data)

nn_cpu = NeuralNetwork(c, CPU())
o = Optimizer(AdamOptimizer(), nn_cpu)
n_epochs = 1000
batch = Batch(10)
o(nn_cpu, dl, batch, n_epochs, pb.loss, pb)
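When adapting this example, it can help to confirm the layout of the data before it is passed to `DataLoader`. The sketch below rebuilds the same grid using only Base Julia and inspects the array shapes; it assumes nothing beyond the data-generation lines above and does not call into GeometricMachineLearning:

```julia
# Rebuild the grid from the example using only Base Julia.
x_vec = -1.:.1:1.
y_vec = -1.:.1:1.
xy_data = hcat([[x, y] for x in x_vec, y in y_vec]...)
f(x::Vector) = exp(-sum(x.^2))
z_data = mapreduce(i -> f(xy_data[:, i]), hcat, axes(xy_data, 2))

# One column per grid point: 21 × 21 = 441 samples.
@show size(xy_data)   # (2, 441)
@show size(z_data)    # (1, 441)
```

Each column of `xy_data` is one input point and the matching column of `z_data` is its target value, so the two arrays must agree in their second dimension.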

We can also train the neural network with Zygote-based [3] automatic differentiation (AD):

pb_zygote = GeometricMachineLearning.ZygotePullback(FeedForwardLoss())
o(nn_cpu, dl, batch, n_epochs, pb_zygote.loss, pb_zygote)

Development

We use git hooks, e.g., to enforce that all tests pass before pushing. To activate these hooks, run the following command once:

git config core.hooksPath .githooks

Footnotes

  1. In some cases it is possible to perform second-order differentiation with Zygote, but it is not entirely clear when this works and when it does not.

  2. This example is discussed in detail in the docs.

  3. Note that here we can actually use Zygote without problems as it does not involve any complicated derivatives.