-
-
Notifications
You must be signed in to change notification settings - Fork 611
/
normalise.jl
491 lines (394 loc) · 15.3 KB
/
normalise.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
# `istraining()` returns `false` when called normally. The `rrule` below makes
# the primal value seen by a ChainRules-aware AD system `true` instead, so a layer
# whose `active` field is `nothing` behaves as in training only while it is being
# differentiated (assumes the AD in use honours ChainRulesCore rules — e.g. Zygote).
istraining() = false
# Reverse rule: primal result `true`; the pullback returns no tangent for the
# argument-free function itself.
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
# Effective activity of layer `m`: an explicit `m.active::Bool` wins, while
# `nothing` ("automatic") defers to `istraining()`.
_isactive(m) = isnothing(m.active) ? istraining() : m.active
# Shape of the dropout mask for input `s`.
# With `dims = :` the mask has exactly the shape of `s` (independent noise per entry).
_dropout_shape(s, ::Colon) = size(s)
# Otherwise only the dimensions listed in `dims` keep their extent; every other
# dimension collapses to 1, so a single random draw is broadcast along it.
function _dropout_shape(s, dims)
    full = size(s)
    return ntuple(d -> d in dims ? full[d] : 1, length(full))
end
# Mask value for one uniform draw `y`: keep (scaled by `1/q`) when `y` exceeds the
# drop probability `p`, otherwise zero. The result has the same type as `y`.
function _dropout_kernel(y::T, p, q) where {T}
    keep = y > p
    return keep ? T(1 / q) : T(0)
end
"""
    dropout([rng = rng_from_array(x)], x, p; dims=:, active=true)

Apply dropout to `x`. When `active` is `true`, every entry of `x` is either set
to `0` (with probability `p`) or multiplied by `1 / (1 - p)`, which keeps the
expected value of each entry unchanged for `p < 1`. `dims` lists the dimensions
along which the random mask varies; the mask is broadcast along all others,
e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
When `active` is `false`, `x` itself is returned.

Dropout is a regularisation technique: it reduces overfitting during training.

Pass `rng` to use a custom random number generator instead of the default one.
Note that custom RNGs are only supported on the CPU.

!!! warning
    With this function you manage the activation state yourself. Dropout is
    usually enabled while training and disabled for inference; the
    [`Dropout`](@ref) layer handles that automatically and is what you should
    use in most scenarios.
"""
function dropout(rng, x, p; dims=:, active::Bool=true)
    if !active
        return x
    end
    mask = dropout_mask(rng, x, p, dims=dims)
    return x .* mask
end
# Convenience method: pick the default RNG appropriate for `x`.
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
# `dropout_mask` dispatches on the (rng, array) pairing: a `CuArray` must be
# paired with a `CUDA.RNG`; every other combination goes to `_dropout_mask`.
function dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...)
    return _dropout_mask(rng, x, p; kwargs...)
end
function dropout_mask(rng, x::CuArray, p; kwargs...)
    msg = "x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."
    throw(ArgumentError(msg))
end
function dropout_mask(rng, x, p; kwargs...)
    return _dropout_mask(rng, x, p; kwargs...)
end
# Build the scaled dropout mask for `x`: an array of `x`'s real floating-point
# element type, shaped by `_dropout_shape(x, dims)`, whose entries are either
# `1 / (1 - p)` (kept) or `0` (dropped).
function _dropout_mask(rng, x, p; dims=:)
    T = float(real(eltype(x)))
    mask = similar(x, T, _dropout_shape(x, dims))
    rand!(rng, mask)  # uniform draws, overwritten in place below
    mask .= _dropout_kernel.(mask, p, 1 - p)
    return mask
end
# The mask is built from random draws only, so no gradient should flow through
# it; mark `dropout_mask` as non-differentiable for ChainRules-based AD.
# TODO move this to NNlib
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
"""
    Dropout(p; dims=:, rng = rng_from_array())

Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.

To apply dropout along certain dimension(s), specify the `dims` keyword.
e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout).

Specify `rng` to use a custom RNG instead of the default.
Custom RNGs are only supported on the CPU.

Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
"""
mutable struct Dropout{F,D,R<:AbstractRNG}
p::F  # drop probability, validated to lie in [0, 1] by the keyword constructor
dims::D  # `:` or the dimension(s) along which the mask varies
active::Union{Bool, Nothing}  # `nothing` means automatic (decided by `_isactive`)
rng::R  # RNG used to draw the mask
end
# Positional constructor (drop probability, dims, activity flag) using the
# default RNG.
function Dropout(p, dims, active)
    rng = rng_from_array()
    return Dropout(p, dims, active, rng)
end
# Keyword constructor. The layer starts in automatic mode (`active = nothing`)
# until `testmode!` is called on it.
function Dropout(p; dims=:, rng = rng_from_array())
    @assert 0 ≤ p ≤ 1
    active = nothing
    return Dropout(p, dims, active, rng)
end
@functor Dropout
# Dropout has no trainable parameters.
trainable(::Dropout) = NamedTuple()
# Forward pass: identity unless the layer is active, otherwise apply `dropout`
# with the layer's probability, dims and RNG.
function (a::Dropout)(x)
    if _isactive(a)
        return dropout(a.rng, x, a.p; dims=a.dims, active=true)
    end
    return x
end
# `mode = true` forces inference behaviour, `false` forces training behaviour,
# and `nothing` / `:auto` restores automatic mode. Returns the layer.
function testmode!(m::Dropout, mode=true)
    m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode
    return m
end
# Compact display, e.g. `Dropout(0.5)` or `Dropout(0.5, dims = 3)`.
function Base.show(io::IO, d::Dropout)
    print(io, "Dropout(", d.p)
    if d.dims != (:)
        print(io, ", dims = $(repr(d.dims))")
    end
    print(io, ")")
end
"""
    AlphaDropout(p; rng = rng_from_array())

A dropout layer. Used in
[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
The AlphaDropout layer ensures that mean and variance of activations
remain the same as before.

Does nothing to the input once [`testmode!`](@ref) is true.
"""
mutable struct AlphaDropout{F,R<:AbstractRNG}
p::F  # drop probability in [0, 1]
active::Union{Bool, Nothing}  # `nothing` means automatic (decided by `_isactive`)
rng::R  # RNG used to draw the noise
# Inner constructor replaces the default one, so every construction path
# validates `p` here.
function AlphaDropout(p, active, rng)
@assert 0 ≤ p ≤ 1
new{typeof(p), typeof(rng)}(p, active, rng)
end
end
# Positional convenience constructor using the default RNG.
function AlphaDropout(p, active)
    return AlphaDropout(p, active, rng_from_array())
end
# Keyword constructor; the layer starts in automatic mode (`active = nothing`).
function AlphaDropout(p; rng = rng_from_array())
    return AlphaDropout(p, nothing, rng)
end
@functor AlphaDropout
# AlphaDropout has no trainable parameters.
trainable(::AlphaDropout) = NamedTuple()
# Forward pass. Dropped units are set to `α′` (the negative saturation value of
# SELU) and the result is passed through the affine map `scale * y + shift`,
# which the layer docstring describes as keeping mean and variance unchanged.
function (layer::AlphaDropout)(x::AbstractArray{T}) where T
    # An inactive layer, or a zero drop rate, leaves the input untouched.
    if !_isactive(layer) || iszero(layer.p)
        return x
    end
    p = layer.p
    # Everything dropped: zeros with the shape and element type of `x`.
    if isone(p)
        return sign.(x) .* T(0)
    end
    α′ = T(-1.7580993408473766) # selu(-Inf) == -λα
    scale = T(inv(sqrt((1 - p) * (1 + p * α′^2))))
    shift = T(-scale * α′ * p)
    u = rand!(layer.rng, similar(x))
    return scale .* ifelse.(u .> p, x, α′) .+ shift
end
# `mode = true` forces inference behaviour, `false` forces training behaviour,
# `nothing` / `:auto` restores automatic mode. Returns the layer.
function testmode!(m::AlphaDropout, mode=true)
    m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode
    return m
end
"""
    LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)

A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
used with recurrent hidden states.
The argument `size` should be an integer or a tuple of integers.
In the forward pass, the layer normalises the mean and standard
deviation of the input, then applies the elementwise activation `λ`.
The input is normalised along the first `length(size)` dimensions
for tuple `size`, along the first dimension for integer `size`.
The input is expected to have first dimensions' size equal to `size`.

If `affine=true`, it also applies a learnable shift and rescaling
using the [`Scale`](@ref) layer.

See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
"""
struct LayerNorm{F,D,T,N}
λ::F  # activation function
diag::D  # applied after normalising: a `Scale` when affine, else broadcast `λ` or `identity`
ϵ::T  # passed as `ϵ` to `normalise`
size::NTuple{N,Int}  # extents of the leading dimensions that are normalised
affine::Bool
end
# Main constructor. The post-normalisation map `diag` is a learnable `Scale`
# (which also applies `λ`) when `affine`, otherwise a broadcast of `λ`, or plain
# `identity` when there is no activation either.
function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
    if affine
        diag = Scale(size..., λ)
    elseif λ != identity
        diag = Base.Fix1(broadcast, λ)
    else
        diag = identity
    end
    return LayerNorm(λ, diag, ϵ, size, affine)
end
# `LayerNorm(3, 4)`: all-integer arguments form the size tuple.
LayerNorm(size::Integer...; kw...) = LayerNorm(map(Int, size); kw...)
# `LayerNorm(3, 4, relu)`: the trailing argument is the activation.
function LayerNorm(size_act...; kw...)
    dims = Int.(size_act[1:end-1])
    λ = size_act[end]
    return LayerNorm(dims, λ; kw...)
end
@functor LayerNorm
# Normalise over the leading `length(a.size)` dimensions, then apply `a.diag`.
function (a::LayerNorm)(x)
    normed = normalise(x, dims=1:length(a.size), ϵ=a.ϵ)
    return a.diag(normed)
end
# Compact display, e.g. `LayerNorm(3, 4, relu, affine=false)`.
function Base.show(io::IO, l::LayerNorm)
    print(io, "LayerNorm(", join(l.size, ", "))
    if l.λ !== identity
        print(io, ", ", l.λ)
    end
    if !hasaffine(l)
        print(io, ", affine=false")
    end
    print(io, ")")
end
# For InstanceNorm, GroupNorm, and BatchNorm.
# Compute the statistics on the slices specified by reduce_dims.
# reduce_dims=[1,...,N-2,N] for BatchNorm
# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
#
# Shared forward pass of the stateful normalisation layers.
# `l` must provide the fields read below: `active`, `track_stats`, `μ`, `σ²`,
# `ϵ`, `λ` and, when `hasaffine(l)`, `γ` and `β`.
# `affine_shape` is the shape `γ`/`β` are reshaped to so they broadcast against `x`.
# Returns `l.λ.(γ .* x̂ .+ β)` (or `l.λ.(x̂)` without affine parameters), where
# `x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)`.
function _norm_layer_forward(
l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
) where {T, N}
if !_isactive(l) && l.track_stats # testmode with tracked stats
# Running statistics are stored as vectors; reshape them to broadcast along
# dimension N-1 (the channel dimension, or the group dimension of GroupNorm's
# reshaped input).
stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
μ = reshape(l.μ, stats_shape)
σ² = reshape(l.σ², stats_shape)
else # trainmode or testmode without tracked stats
# Batch statistics over `reduce_dims`; σ² is the biased (population) variance.
μ = mean(x; dims=reduce_dims)
σ² = mean((x .- μ).^2; dims=reduce_dims)
if l.track_stats
_track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std
end
end
o = _norm_layer_forward(x, μ, σ², l.ϵ)
hasaffine(l) || return l.λ.(o)
γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* o .+ β)
end
# Core normalisation: centre `x` by `μ` and divide by `sqrt(σ² + ϵ)`, as one
# fused broadcast over all four arguments.
@inline function _norm_layer_forward(x, μ, σ², ϵ)
    return @. (x - μ) / sqrt(σ² + ϵ)
end
# Update the running statistics `bn.μ` / `bn.σ²` with the batch statistics `μ`,
# `σ²` (computed over `reduce_dims` of `x`) as exponential moving averages with
# weight `bn.momentum` on the new batch.
# - If the batch dimension `N` was not reduced (InstanceNorm/GroupNorm), the
#   per-sample statistics are first averaged over it.
# - The variance is Bessel-corrected by `m / (m - 1)`, where `m` is the number of
#   elements pooled per statistic. With `m == 1` that factor is infinite while
#   `σ²` is exactly zero, which would make the running variance `NaN` for good,
#   so in that case only the mean is updated.
# Mutates `bn`; returns `nothing`.
function _track_stats!(
    bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
) where {T, N}
    V = eltype(bn.σ²)
    mtm = bn.momentum
    res_mtm = one(V) - mtm
    m = prod(size(x, i) for i in reduce_dims)
    μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
    σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
    bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
    # A single value per statistic carries no variance information.
    if m > 1
        bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
    end
    return nothing
end
# Updating running statistics mutates the layer and must not be traced by AD.
ChainRulesCore.@non_differentiable _track_stats!(::Any...)
"""
    BatchNorm(channels::Integer, λ=identity;
              initβ=zeros32, initγ=ones32,
              affine = true, track_stats = true,
              ϵ=1f-5, momentum= 0.1f0)

[Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
`channels` should be the size of the channel dimension in your data (see below).

Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.

`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
input slice and normalises the input accordingly.

If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias β and scale γ parameters.

After normalisation, elementwise activation `λ` is applied.

If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.

Use [`testmode!`](@ref) during inference.

# Examples
```julia
m = Chain(
  Dense(28^2 => 64),
  BatchNorm(64, relu),
  Dense(64 => 10),
  BatchNorm(10),
  softmax)
```
"""
mutable struct BatchNorm{F,V,N,W}
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving var
ϵ::N # added to the variance before the square root; same type as `momentum`
momentum::N # weight of the newest batch in the running-statistics update
affine::Bool # whether β and γ exist and are applied
track_stats::Bool # whether μ and σ² are kept and updated
active::Union{Bool, Nothing} # `nothing` means automatic (decided by `_isactive`)
chs::Int # number of channels
end
# Keyword constructor for `BatchNorm`.
# - `chs`: channel count; any `Integer` is accepted (as documented) and stored as `Int`.
# - `initβ` / `initγ`: initialisers for the affine parameters (only used when `affine`).
# - `track_stats`: allocate running mean (zeros) and variance (ones).
# The layer starts in automatic mode (`active = nothing`).
function BatchNorm(chs::Integer, λ=identity;
                   initβ=zeros32, initγ=ones32,
                   affine=true, track_stats=true,
                   ϵ=1f-5, momentum=0.1f0)
    n = Int(chs)
    # `ϵ` and `momentum` share the struct's type parameter `N`; without a common
    # type, e.g. `ϵ=1e-5` next to the default `momentum=0.1f0` would hit a
    # MethodError in the default constructor.
    ϵ, momentum = promote(ϵ, momentum)
    β = affine ? initβ(n) : nothing
    γ = affine ? initγ(n) : nothing
    μ = track_stats ? zeros32(n) : nothing
    σ² = track_stats ? ones32(n) : nothing
    return BatchNorm(λ, β, γ,
                     μ, σ², ϵ, momentum,
                     affine, track_stats,
                     nothing, n)
end
@functor BatchNorm
# Only the affine parameters are trained; the running statistics are updated
# through `_track_stats!` during the forward pass instead.
function trainable(bn::BatchNorm)
    if hasaffine(bn)
        return (β = bn.β, γ = bn.γ)
    else
        return NamedTuple()
    end
end
# Forward pass: statistics are pooled over every dimension except the channel
# dimension `ndims(x) - 1`.
function (BN::BatchNorm)(x)
    @assert size(x, ndims(x)-1) == BN.chs
    nd = ndims(x)
    dims = [1:nd-2; nd]
    shape = ntuple(i -> i == nd - 1 ? size(x, nd - 1) : 1, nd)
    return _norm_layer_forward(BN, x; reduce_dims=dims, affine_shape=shape)
end
# `mode = true` forces inference behaviour, `false` forces training behaviour,
# `nothing` / `:auto` restores automatic mode. Returns the layer.
function testmode!(m::BatchNorm, mode=true)
    m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode
    return m
end
# Compact display, e.g. `BatchNorm(64, relu)`.
function Base.show(io::IO, l::BatchNorm)
    print(io, "BatchNorm($(l.chs)")
    if l.λ != identity
        print(io, ", $(l.λ)")
    end
    if !hasaffine(l)
        print(io, ", affine=false")
    end
    print(io, ")")
end
"""
    InstanceNorm(channels::Integer, λ=identity;
                 initβ=zeros32, initγ=ones32,
                 affine=false, track_stats=false,
                 ϵ=1f-5, momentum=0.1f0)

[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
`channels` should be the size of the channel dimension in your data (see below).

Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.

`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
input slice and normalises the input accordingly.

If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.

If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.

**Warning**: the defaults for `affine` and `track_stats` used to be `true`
in previous Flux versions (< v0.12).
"""
mutable struct InstanceNorm{F,V,N,W}
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving var
ϵ::N # added to the variance before the square root; same type as `momentum`
momentum::N # weight of the newest batch in the running-statistics update
affine::Bool # whether β and γ exist and are applied
track_stats::Bool # whether μ and σ² are kept and updated
active::Union{Bool, Nothing} # `nothing` means automatic (decided by `_isactive`)
chs::Int # number of channels
end
# Keyword constructor for `InstanceNorm`.
# - `chs`: channel count; any `Integer` is accepted (as documented) and stored as `Int`.
# - `initβ` / `initγ`: initialisers for the affine parameters (only used when `affine`).
# - `track_stats`: allocate running mean (zeros) and variance (ones).
# The layer starts in automatic mode (`active = nothing`).
function InstanceNorm(chs::Integer, λ=identity;
                      initβ=zeros32, initγ=ones32,
                      affine=false, track_stats=false,
                      ϵ=1f-5, momentum=0.1f0)
    n = Int(chs)
    # `ϵ` and `momentum` share the struct's type parameter `N`; promote them so
    # mixed inputs such as `ϵ=1e-5` with the Float32 default momentum work.
    ϵ, momentum = promote(ϵ, momentum)
    β = affine ? initβ(n) : nothing
    γ = affine ? initγ(n) : nothing
    μ = track_stats ? zeros32(n) : nothing
    σ² = track_stats ? ones32(n) : nothing
    return InstanceNorm(λ, β, γ,
                        μ, σ², ϵ, momentum,
                        affine, track_stats,
                        nothing, n)
end
@functor InstanceNorm
# Only the affine parameters are trainable.
function trainable(layer::InstanceNorm)
    if hasaffine(layer)
        return (β = layer.β, γ = layer.γ)
    end
    return NamedTuple()
end
# Forward pass: statistics per sample and per channel, over the spatial
# dimensions `1:ndims(x)-2` only.
function (l::InstanceNorm)(x)
    @assert ndims(x) > 2
    @assert size(x, ndims(x)-1) == l.chs
    nd = ndims(x)
    shape = ntuple(i -> i == nd - 1 ? size(x, nd - 1) : 1, nd)
    return _norm_layer_forward(l, x; reduce_dims=1:nd-2, affine_shape=shape)
end
# `mode = true` forces inference behaviour, `false` forces training behaviour,
# `nothing` / `:auto` restores automatic mode. Returns the layer.
function testmode!(m::InstanceNorm, mode=true)
    m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode
    return m
end
# Compact display, e.g. `InstanceNorm(3, affine=false)`.
function Base.show(io::IO, l::InstanceNorm)
    print(io, "InstanceNorm($(l.chs)")
    if l.λ != identity
        print(io, ", $(l.λ)")
    end
    if !hasaffine(l)
        print(io, ", affine=false")
    end
    print(io, ")")
end
"""
    GroupNorm(channels::Integer, G::Integer, λ=identity;
              initβ=zeros32, initγ=ones32,
              affine=true, track_stats=false,
              ϵ=1f-5, momentum=0.1f0)

[Group Normalization](https://arxiv.org/abs/1803.08494) layer.

`channels` is the number of channels, i.e. the size of the channel dimension of
your input. Given an array with `N > 2` dimensions, call the `N-1`th the channel
dimension; for `WHCN` images it's the usual channel dimension.

`G` is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups.

If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.

If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.
"""
mutable struct GroupNorm{F,V,N,W}
G::Int # number of groups
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean (one entry per group)
σ²::W # moving var (one entry per group)
ϵ::N # added to the variance before the square root; same type as `momentum`
momentum::N # weight of the newest batch in the running-statistics update
affine::Bool # whether β and γ exist and are applied
track_stats::Bool # whether μ and σ² are kept and updated
active::Union{Bool, Nothing} # `nothing` means automatic (decided by `_isactive`)
chs::Int # number of channels
end
@functor GroupNorm
# Only β and γ are trainable, and only when the layer is affine.
function trainable(gn::GroupNorm)
    hasaffine(gn) || return NamedTuple()
    return (β = gn.β, γ = gn.γ)
end
# Keyword constructor for `GroupNorm`.
# - `chs`: channel count, `G`: number of groups; any `Integer`s are accepted (as
#   documented) and stored as `Int`. `G` must divide `chs`.
# - The affine parameters are per channel (length `chs`); the running statistics
#   are per group (length `G`).
# The layer starts in automatic mode (`active = nothing`).
function GroupNorm(chs::Integer, G::Integer, λ=identity;
                   initβ=zeros32, initγ=ones32,
                   affine=true, track_stats=false,
                   ϵ=1f-5, momentum=0.1f0)
    chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)")
    nch, ng = Int(chs), Int(G)
    # `ϵ` and `momentum` share the struct's type parameter `N`; promote them so
    # mixed inputs such as `ϵ=1e-5` with the Float32 default momentum work.
    ϵ, momentum = promote(ϵ, momentum)
    β = affine ? initβ(nch) : nothing
    γ = affine ? initγ(nch) : nothing
    μ = track_stats ? zeros32(ng) : nothing
    σ² = track_stats ? ones32(ng) : nothing
    return GroupNorm(ng, λ,
                     β, γ,
                     μ, σ²,
                     ϵ, momentum,
                     affine, track_stats,
                     nothing, nch)
end
# Forward pass. The channel dimension is split into `(chs ÷ G, G)` so that each
# group of each sample gets its own statistics; the result is reshaped back.
function (gn::GroupNorm)(x)
    @assert ndims(x) > 2
    @assert size(x, ndims(x)-1) == gn.chs
    sz = size(x)
    nd = ndims(x)
    grouped = reshape(x, sz[1:nd-2]..., sz[nd-1] ÷ gn.G, gn.G, sz[nd])
    gnd = ndims(grouped)
    # Reduce over spatial dims plus the within-group channel dim.
    reduce_dims = 1:gnd-2
    # γ/β are per channel, i.e. laid out over the (chs ÷ G, G) pair of dims.
    affine_shape = ntuple(i -> i ∈ (gnd - 1, gnd - 2) ? size(grouped, i) : 1, gnd)
    y = _norm_layer_forward(gn, grouped; reduce_dims, affine_shape)
    return reshape(y, sz)
end
# `mode = true` forces inference behaviour, `false` forces training behaviour,
# `nothing` / `:auto` restores automatic mode. Returns the layer.
function testmode!(m::GroupNorm, mode = true)
    m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode
    return m
end
# Compact display, e.g. `GroupNorm(32, 4, relu)`: channels, groups, then the
# activation and `affine=false` when they differ from the defaults.
function Base.show(io::IO, l::GroupNorm)
    print(io, "GroupNorm($(l.chs), $(l.G)")
    l.λ == identity || print(io, ", ", l.λ)
    hasaffine(l) || print(io, ", affine=false")
    print(io, ")")
end
"""
    hasaffine(l)

Return `true` if the normalisation layer `l` was built with learnable shift and
scale parameters (`affine=true`), and `false` otherwise.

See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
"""
function hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm})
    return l.affine
end