forked from FluxML/Flux.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparallel.jl
72 lines (56 loc) · 2.17 KB
/
parallel.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
using Test, Random
using Flux
using Flux: @epochs
using Statistics: mean
using Base.Iterators: partition
using CuArrays
@testset "Parallel" begin
    # Seed the global RNG so the input data and the initial weights are reproducible.
    Random.seed!(0)

    # `partition` walks the 10×7 matrix in column-major order, giving 7 vectors of
    # length 10; each is moved to the GPU when one is available.
    data = gpu.(collect(partition(rand(10, 7), 10)))

    models = [
        # non-recurrent layers - the reduce function defaults to `Flux.concat`
        Chain(Parallel([Dense(10, 10), Dense(10, 10)]), Dense(20, 10)),
        # recurrent layers
        Parallel([LSTM(10, 10)]),
        Chain(Parallel([LSTM(10, 10)])),
        # for reduce see: `Base.sum`, `Statistics.mean`, `Flux.mul`, `Flux.concat`
        Parallel([LSTM(10, 5), LSTM(10, 5)]),
        Parallel([LSTM(10, 10), LSTM(10, 10)], reduce=sum),
        Chain(Parallel([LSTM(10, 10), LSTM(10, 10)], reduce=mean)),
        # reduce can be `Flux.concat`; here the reduction is effectively done by a
        # final Dense layer
        Chain(Parallel([LSTM(10, 10)]), Dense(10, 10)),
        Chain(Parallel([LSTM(10, 10), LSTM(10, 10)]), Dense(20, 10)),
        # bidirectional LSTM
        Parallel([LSTM(10, 10), LSTM(10, 10)],
                 map = Dict{Int64,Function}(2 => reverse),
                 inv = Dict{Int64,Function}(2 => reverse),
                 reduce = sum),
        # BiLSTM - a convenience layer, which makes use of `Parallel` and the
        # MapReduce approach
        Bi(LSTM(10, 10), sum),
        Chain(Bi(LSTM(10, 10), sum)),
    ]

    @testset "models using a `Parallel` layer" for (i, m) in enumerate(models)
        # `gpu` returns a converted copy instead of mutating its argument; the result
        # must be kept, otherwise a CPU model would be fed GPU data.
        model = gpu(m)
        @info "Parallel test model $i" model

        before = Flux.data(model(data[1]))
        @test length(before) == 10 || length(before) == 20
        # Start training (and the later comparison) from the initial recurrent state,
        # not from the state left behind by the forward pass above.
        Flux.reset!(model)

        function loss(x, y)
            l = Flux.mse(model(x), y)
            # Truncate the gradient of the recurrent hidden state after every step
            # (see `Flux.truncate!`); the state value itself is kept.
            Flux.truncate!(model)
            return l
        end

        function evalcb()
            err = mean(x -> loss(x, x), data)
            @show err
            return err
        end

        opt = ADAM()
        @epochs 3 Flux.train!(loss, params(model), zip(data, data), opt, cb = evalcb)

        # Re-run the same input from the same initial state: the output must keep its
        # shape but differ from the pre-training output, showing the weights moved.
        Flux.reset!(model)
        after = Flux.data(model(data[1]))
        @test size(after) == size(before)
        @test before != after
    end
end