Commit: EmptyFixedSizedArrays Implementation

avik-pal committed Jul 13, 2022
1 parent c900bfc commit cff0cf4
Showing 8 changed files with 187 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# v0.4

## v0.4.10

- Introduces a testing array -- `EmptyFixedSizedArray`
- Allows static size inference for `ReshapeLayer`.

## v0.4.8

- Deprecations
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.4.9"
version = "0.4.10"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2 changes: 2 additions & 0 deletions src/Lux.jl
@@ -46,6 +46,8 @@ include("autodiff.jl")
function __init__()
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("transform.jl")
end
# Contrib
include("contrib/emptyarrays.jl")

# Deprecations
include("deprecated.jl")
157 changes: 157 additions & 0 deletions src/contrib/emptyarrays.jl
@@ -0,0 +1,157 @@
import NNlib, Statistics

"""
EmptyFixedSizedArray{T, N, S} <: AbstractArray{T, N}
An EmptyFixedSizedArray to test out static size inference of a model. This has multiple
usecases:
- Test out that the neural network works without actually doing expensive computations.
- Statically infer sizes of intermediate arrays. Especially useful if we want to generate
an XLA computation which requires static shape inference.
Semantics of the Array:
- `getfield` always returns T(0).
- `setfield` is a no-op.
"""
struct EmptyFixedSizedArray{T, N, S} <: AbstractArray{T, N} end

function EmptyFixedSizedArray(x::AbstractArray)
return EmptyFixedSizedArray{eltype(x), ndims(x), size(x)}()
end
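
# Illustrative usage (hypothetical REPL session; shapes follow from the
# definitions above): the element type, rank, and size all live in the type
# parameters, so no storage is allocated and the size is known statically.
#
#   julia> x = EmptyFixedSizedArray(randn(Float32, 3, 4))
#   3x4 EmptyFixedSizedArray{Float32, 2}
#
#   julia> size(x)
#   (3, 4)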

function Base.show(io::IO, x::EmptyFixedSizedArray{T, N, S}) where {T, N, S}
print(io, "$(join(S, "x")) EmptyFixedSizedArray{$T, $N}")
return nothing
end
Base.show(io::IO, ::MIME, x::EmptyFixedSizedArray) = show(io, x)
function Base.display(x::EmptyFixedSizedArray)
show(stdout, x)
println()
return nothing
end

Base.size(::EmptyFixedSizedArray{T, N, S}) where {T, N, S} = S
Base.eltype(::EmptyFixedSizedArray{T}) where {T} = T
Base.getindex(::EmptyFixedSizedArray{T}, i...) where {T} = T(0)
Base.setindex!(::EmptyFixedSizedArray, v, i...) = nothing  # writes are no-ops

Base.similar(x::EmptyFixedSizedArray) = x
function Base.similar(::EmptyFixedSizedArray{T1, N, S}, ::Type{T}) where {T1, N, S, T}
return EmptyFixedSizedArray{T, N, S}()
end
function Base.similar(::EmptyFixedSizedArray, ::Type{T},
                      dims::Union{Integer, AbstractUnitRange}...) where {T}
    # Normalize each dimension to an Int so the size type parameter stays a tuple of Ints
    dims = map(d -> d isa Integer ? Int(d) : length(d), dims)
    return EmptyFixedSizedArray{T, length(dims), dims}()
end
function Base.similar(x::EmptyFixedSizedArray{T},
dims::Union{Integer, AbstractUnitRange}...) where {T}
return similar(x, T, dims...)
end

function Base.reshape(x::EmptyFixedSizedArray, ::Val{shape}) where {shape}
return reshape(x, shape...)
end

# NOTE(@avik-pal): Type inference is not possible here
function Base.reshape(x::EmptyFixedSizedArray{T, N, S},
                      dims::Union{Colon, Int, UnitRange}...) where {T, N, S}
    dims_ = filter(d -> !isa(d, Colon), dims)
    colons = length(dims) - length(dims_)
    @assert colons <= 1 "At most one Colon() is allowed in `dims`."
    if colons == 1
        cidx = findfirst(d -> isa(d, Colon), dims)
        dims = (dims[1:(cidx - 1)]..., div(prod(S), prod(dims_)), dims[(cidx + 1):end]...)
    end
    @assert prod(dims) == prod(S) "Array of size $S cannot be reshaped into size $dims."
    return EmptyFixedSizedArray{T, length(dims), dims}()
end
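
# Illustrative sketch: at most one Colon() is resolved against the remaining
# dimensions, mirroring `Base.reshape` semantics.
#
#   julia> reshape(EmptyFixedSizedArray{Float32, 2, (3, 4)}(), 6, :)
#   6x2 EmptyFixedSizedArray{Float32, 2}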

# NOTE(@avik-pal): Type inference is not possible here
function Base.view(x::EmptyFixedSizedArray{T},
                   dims::Union{Colon, Int, UnitRange}...) where {T}
    # Integer indices drop a dimension; ranges and colons keep theirs
    dims_ = map(length, filter(d -> !isa(d, Int), to_indices(x, dims)))
    return EmptyFixedSizedArray{T, length(dims_), dims_}()
end

function Base.:+(::EmptyFixedSizedArray{T1, N, S},
::EmptyFixedSizedArray{T2, N, S}) where {T1, T2, N, S}
T = promote_type(T1, T2)
return EmptyFixedSizedArray{T, N, S}()
end

function Base.:-(::EmptyFixedSizedArray{T1, N, S},
::EmptyFixedSizedArray{T2, N, S}) where {T1, T2, N, S}
T = promote_type(T1, T2)
return EmptyFixedSizedArray{T, N, S}()
end

function Base.:*(::EmptyFixedSizedArray{T1, 2, S1},
                 ::EmptyFixedSizedArray{T2, 2, S2}) where {T1, T2, S1, S2}
    @assert S1[2] == S2[1] "Sizes $S1 and $S2 are not compatible for matrix multiplication."
    T = promote_type(T1, T2)
    return EmptyFixedSizedArray{T, 2, (S1[1], S2[2])}()
end
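
# Illustrative sketch: matrix multiplication only propagates the static shapes.
#
#   julia> a = EmptyFixedSizedArray{Float32, 2, (3, 4)}();
#
#   julia> b = EmptyFixedSizedArray{Float32, 2, (4, 5)}();
#
#   julia> a * b
#   3x5 EmptyFixedSizedArray{Float32, 2}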

function Base.BroadcastStyle(::Type{<:EmptyFixedSizedArray})
return Broadcast.ArrayStyle{EmptyFixedSizedArray}()
end

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{EmptyFixedSizedArray}},
::Type{ElType}) where {ElType}
return EmptyFixedSizedArray{ElType, length(axes(bc)), length.(axes(bc))}()
end

Base.copyto!(dest::EmptyFixedSizedArray, bc::Base.Broadcast.Broadcasted) = dest
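
# Illustrative sketch: broadcasting materializes another EmptyFixedSizedArray of
# the broadcast shape, without touching any data.
#
#   julia> EmptyFixedSizedArray{Float32, 2, (3, 4)}() .+ 1.0f0
#   3x4 EmptyFixedSizedArray{Float32, 2}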

function NNlib.conv!(out::EmptyFixedSizedArray, in1::EmptyFixedSizedArray,
in2::EmptyFixedSizedArray, cdims::NNlib.DenseConvDims; kwargs...)
return out
end

function NNlib.maxpool!(out::EmptyFixedSizedArray, x::EmptyFixedSizedArray,
pdims::NNlib.PoolDims; kwargs...)
return out
end

function NNlib.meanpool!(out::EmptyFixedSizedArray, x::EmptyFixedSizedArray,
pdims::NNlib.PoolDims; kwargs...)
return out
end

@inline function _reshape_into_proper_shape(x::EmptyFixedSizedArray,
y::EmptyFixedSizedArray)
return reshape(x, _get_reshape_dims(size(y), length(x))...)
end

@generated function _compute_reduced_dimensions(::EmptyFixedSizedArray{T, N, shape},
::Val{dims}) where {T, N, dims, shape}
@assert minimum(dims) > 0 && maximum(dims) <= N
d = dims isa Int ? (dims,) : (dims isa Vector ? Tuple(dims) : dims)
res = ntuple(i -> i in d ? 1 : shape[i], N)
return :(return $res)
end
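
# For example, for shape (3, 4, 5) and dims (1, 3), the reduced shape is (1, 4, 1).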

function _compute_reduced_dimensions(x::EmptyFixedSizedArray, dims)
return _compute_reduced_dimensions(x, Val(dims))
end

function _generic_reduction(x::EmptyFixedSizedArray{T, N}, dims::Val) where {T, N}
return EmptyFixedSizedArray{T, N, _compute_reduced_dimensions(x, dims)}()
end

Base._sum(x::EmptyFixedSizedArray{T}, ::Colon) where {T} = T(0)
function Base._sum(x::EmptyFixedSizedArray{T, N}, dims) where {T, N}
return EmptyFixedSizedArray{T, N, _compute_reduced_dimensions(x, dims)}()
end
Base._sum(f::Function, x::EmptyFixedSizedArray{T}, dims::Colon) where {T} = T(0)
function Base._sum(f::Function, x::EmptyFixedSizedArray{T, N}, dims) where {T, N}
return EmptyFixedSizedArray{T, N, _compute_reduced_dimensions(x, dims)}()
end
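
# Illustrative sketch: reductions keep the reduced dimensions as singletons, and a
# full reduction returns the zero of the element type.
#
#   julia> sum(EmptyFixedSizedArray{Float32, 2, (3, 4)}(); dims=1)
#   1x4 EmptyFixedSizedArray{Float32, 2}
#
#   julia> sum(EmptyFixedSizedArray{Float32, 2, (3, 4)}())
#   0.0f0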

Statistics._mean(::Function, x::EmptyFixedSizedArray, dims) = Base._sum(x, dims)
Statistics._var(x::EmptyFixedSizedArray, corrected::Bool, mean, dims) = Base._sum(x, dims)
14 changes: 7 additions & 7 deletions src/layers/basic.jl
@@ -16,16 +16,16 @@ Reshapes the passed array to have a size of `(dims..., :)`
- AbstractArray of size `(dims..., size(x, ndims(x)))`
- Empty `NamedTuple()`
"""
struct ReshapeLayer{N} <: AbstractExplicitLayer
dims::NTuple{N, Int}
end
struct ReshapeLayer{dims} <: AbstractExplicitLayer end

ReshapeLayer(dims) = ReshapeLayer{dims}()

@inline function (r::ReshapeLayer)(x::AbstractArray, ps, st::NamedTuple)
return reshape(x, r.dims..., size(x, ndims(x))), st
@inline function (r::ReshapeLayer{dims})(x::AbstractArray, ps, st::NamedTuple) where {dims}
return reshape(x, dims..., size(x, ndims(x))), st
end

function Base.show(io::IO, r::ReshapeLayer)
return print(io, "ReshapeLayer(output_dims = (", join(r.dims, ", "), ", :))")
function Base.show(io::IO, r::ReshapeLayer{dims}) where {dims}
return print(io, "ReshapeLayer(output_dims = (", join(dims, ", "), ", :))")
end
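
# Illustrative sketch: `dims` now lives in the type, so for an input whose size is
# itself static (e.g. an EmptyFixedSizedArray) the output shape is inferable.
#
#   julia> ReshapeLayer((2, 3))
#   ReshapeLayer(output_dims = (2, 3, :))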

"""
8 changes: 6 additions & 2 deletions src/layers/normalize.jl
@@ -116,10 +116,12 @@ function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_stats}
return (track_stats ? 2 * l.chs : 0) + 1
end

_bn_reduce_dims(::Val{N}) where {N} = Val(filter(i -> i != N - 1, ntuple(identity, N)))
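# For example, _bn_reduce_dims(Val(4)) == Val((1, 2, 4)): reduce over every
# dimension except the channel dimension (N - 1).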

function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
x_normalized, xmean, xvar = normalization(x, st.running_mean, st.running_var, ps.scale,
ps.bias, BN.activation,
collect([1:(N - 2); N]), st.training,
_bn_reduce_dims(Val(N)), st.training,
BN.momentum, BN.epsilon)

st = merge(st, (running_mean=xmean, running_var=xvar))
@@ -306,12 +308,14 @@ function statelength(l::GroupNorm{affine, track_stats}) where {affine, track_stats}
return (track_stats ? 2 * l.groups : 0) + 1
end

# FIXME(@avik-pal): Static Shape Inference requires us to store the group count as a type
# parameter.
function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
sz = size(x)
x_ = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ GN.groups, GN.groups, sz[N])

x_normalized, xmean, xvar = normalization(x_, st.running_mean, st.running_var, ps.scale,
ps.bias, GN.activation, collect(1:(N - 1)),
ps.bias, GN.activation, Val(Tuple(1:(N - 1))),
st.training, GN.momentum, GN.epsilon)

st = merge(st, (running_mean=xmean, running_var=xvar))
16 changes: 8 additions & 8 deletions src/nnlib.jl
@@ -6,7 +6,7 @@
running_var::AbstractArray{T, N},
batchmean::AbstractArray{T, N},
batchvar::AbstractArray{T, N}, momentum::T,
reduce_dims) where {T, N}
::Val{reduce_dims}) where {T, N, reduce_dims}
sx = size(x)
m = T(prod((sx[i] for i in reduce_dims)))
if reduce_dims[end] != N
@@ -19,7 +19,7 @@
end

"""
normalization(x, running_mean, running_var, scale, bias, activation, reduce_dims,
normalization(x, running_mean, running_var, scale, bias, activation, ::Val{reduce_dims},
::Val{training}, momentum, epsilon)
Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration
@@ -50,10 +50,11 @@ end

@generated function normalization_forward(x::AbstractArray{T, N}, running_mean::RM,
running_var::RV, scale::S, bias::B, activation::A,
reduce_dims, ::Val{training},
r::Val{reduce_dims}, ::Val{training},
momentum::T=T(0.1f0),
epsilon::T=T(1.0f-5)) where {RM, RV, S, B, T, N,
A, training}
A, training,
reduce_dims}
calls = []
if !training
if RM == Nothing
@@ -73,21 +74,20 @@ end
push!(calls,
:((running_mean, running_var) = update_statistics(x, running_mean,
running_var, batchmean,
batchvar, momentum,
reduce_dims)))
batchvar, momentum, r)))
end
end

expr = if S != Nothing
if A == typeof(identity)
:(result = scale .* (x .- batchmean) ./ sqrt.(batchvar .+ epsilon) .+ bias)
:(result = (scale .* (x .- batchmean) ./ sqrt.(batchvar .+ epsilon) .+ bias))
else
:(result = activation.(scale .* (x .- batchmean) ./
sqrt.(batchvar .+ epsilon) .+ bias))
end
else
if A == typeof(identity)
:(result = (x .- batchmean) ./ sqrt.(batchvar .+ epsilon))
:(result = ((x .- batchmean) ./ sqrt.(batchvar .+ epsilon)))
else
:(result = activation.((x .- batchmean) ./ sqrt.(batchvar .+ epsilon)))
end
1 change: 1 addition & 0 deletions test/contrib/emptyarrays.jl
@@ -0,0 +1 @@
