-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #22 from JuliaGNI/static_arrays
Added `CPUStatic` backend and implemented new `initialparameters` interface.
- Loading branch information
Showing
36 changed files
with
326 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# pre-push git hook that runs all tests before pushing | ||
|
||
red='\033[0;31m' | ||
green='\033[0;32m' | ||
no_color='\033[0m' | ||
|
||
reponame=$(basename `git rev-parse --show-toplevel`) | ||
|
||
|
||
echo "\nRunning pre-push hook\n" | ||
echo "Testing $reponame" | ||
julia --project=@. -e "using Pkg; Pkg.test(\"AbstractNeuralNetworks\")" | ||
|
||
if [[ $? -ne 0 ]]; then | ||
echo "\n${red}ERROR - Tests must pass before push!\n${no_color}" | ||
exit 1 | ||
fi | ||
|
||
echo "\n${green}Git hook was SUCCESSFUL!${no_color}\n" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
[deps] | ||
AbstractNeuralNetworks = "60874f82-5ada-4c70-bd1c-fa6be7711c8a" | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" | ||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
@inproceedings{glorot2010understanding, | ||
title={Understanding the difficulty of training deep feedforward neural networks}, | ||
author={Glorot, Xavier and Bengio, Yoshua}, | ||
booktitle={Proceedings of the thirteenth international conference on artificial intelligence and statistics}, | ||
pages={249--256}, | ||
year={2010}, | ||
organization={JMLR Workshop and Conference Proceedings} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# References | ||
|
||
```@bibliography | ||
* | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Static Neural Network Parameters | ||
|
||
We can also allocate neural network parameters using [`StaticArrays`](https://github.com/JuliaArrays/StaticArrays.jl). Therefore we simply need to set the keyword `static` to true in the [`NeuralNetwork`](@ref) constructor. | ||
|
||
!!! warning | ||
Static neural network parameters are only supported for dense CPU arrays. `AbstractNeuralNetworks` defines a type `CPUStatic`, but does not have equivalent GPU objects. | ||
|
||
```@example static_parameters | ||
using AbstractNeuralNetworks | ||
import Random | ||
Random.seed!(123) | ||
backend = AbstractNeuralNetworks.CPUStatic() | ||
input_dim = 2 | ||
n_hidden_layers = 100 | ||
c = Chain(Dense(input_dim, 10, tanh), Tuple(Dense(10, 10, tanh) for _ in 1:n_hidden_layers)..., Dense(10, 1, tanh)) | ||
nn = NeuralNetwork(c, backend) | ||
typeof(nn.params.L1.W) | ||
``` | ||
|
||
We can compare different evaluation times: | ||
```@example static_parameters | ||
nn_cpu = changebackend(CPU(), nn) | ||
second_dim = 200 | ||
x = rand(input_dim, second_dim) | ||
nn(x); # hide | ||
@time nn(x); | ||
nothing # hide | ||
``` | ||
|
||
```@example static_parameters | ||
nn_cpu(x); # hide | ||
@time nn_cpu(x); | ||
nothing # hide | ||
``` | ||
|
||
If we also make the *input* static, we get: | ||
|
||
```@example static_parameters | ||
using StaticArrays | ||
x = @SMatrix rand(input_dim, second_dim) | ||
nn(x); | ||
@time nn(x); | ||
nothing # hide | ||
``` | ||
|
||
```@example static_parameters | ||
nn_cpu(x); # hide | ||
@time nn_cpu(x); | ||
nothing # hide | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
|
||
""" | ||
Architecture | ||
""" | ||
abstract type Architecture end | ||
|
||
struct UnknownArchitecture <: Architecture end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,43 @@ | ||
""" | ||
Initializer | ||
abstract type AbstractInitializer end | ||
Determines how neural network weights are initialized. | ||
""" | ||
abstract type Initializer end | ||
|
||
const Initializer = Union{AbstractInitializer, Base.Callable} | ||
""" | ||
ZeroInitializer <: Initializer | ||
""" | ||
struct ZeroInitializer <: Initializer end | ||
|
||
struct ZeroInitializer <: AbstractInitializer end | ||
function (::ZeroInitializer)(_, x) | ||
x .= KernelAbstractions.zero(x) | ||
|
||
nothing | ||
end | ||
|
||
struct OneInitializer <: AbstractInitializer end | ||
""" | ||
OneInitializer <: Initializer | ||
""" | ||
struct OneInitializer <: Initializer end | ||
|
||
function (::OneInitializer)(_, x::AbstractArray{T}) where T | ||
backend = get_backend(x) | ||
backend = networkbackend(x) | ||
x .= KernelAbstractions.ones(backend, T, size(x)) | ||
|
||
nothing | ||
end | ||
|
||
default_initializer() = randn! | ||
""" | ||
GlorotUniform <: Initializer | ||
struct GlorotUniform <: AbstractNeuralNetworks.AbstractInitializer end | ||
Glorot uniform was introduced by [glorot2010understanding](@cite). | ||
""" | ||
struct GlorotUniform <: Initializer end | ||
|
||
function (::GlorotUniform)(rng, x::AbstractVecOrMat{T}) where T | ||
rand!(rng, x) | ||
x .= sqrt(T(24.0) / sum(size(x))) * (x .- T(0.5)) | ||
end | ||
end | ||
|
||
const DefaultInitializer = GlorotUniform |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.