1
1
# For a Colab notebook of this example, go to https://colab.research.google.com/drive/1xfUsBn9GEqbRjBF-UX_jnGjHZNtNsMae
2
2
using TensorBoardLogger
3
- using Flux
3
+ using Flux, ChainRulesCore
4
4
using Logging
5
5
using MLDatasets
6
6
using Statistics
7
- using CuArrays
8
7
9
8
# create tensorboard logger
10
9
logdir = " content/log"
@@ -28,62 +27,49 @@ model = Chain(
28
27
Dense (28 ^ 2 , 32 , relu),
29
28
Dense (32 , 10 ),
30
29
softmax
31
- )
30
+ ) |> gpu
32
31
33
- loss (x , y) = Flux. crossentropy (model (x) , y)
32
# Cross-entropy (via `Flux.crossentropy`) between predictions `pred` and targets `y`.
function loss_fn(pred, y)
    return Flux.crossentropy(pred, y)
end
34
33
35
- accuracy (x , y) = mean (Flux. onecold (model (x) |> cpu) .== Flux. onecold (y |> cpu))
34
# Mean agreement between the `Flux.onecold` decoding of `pred` and of `y`.
# Both inputs are piped through `cpu` before decoding.
function accuracy(pred, y)
    predicted = Flux.onecold(pred |> cpu)
    expected = Flux.onecold(y |> cpu)
    return mean(predicted .== expected)
end
36
35
37
- opt = ADAM ()
36
+ opt = Flux . setup ( ADAM (), model )
38
37
39
- traindata = permutedims ( reshape (traindata, (28 , 28 , 60000 , 1 )), ( 1 , 2 , 4 , 3 ));
40
- testdata = permutedims ( reshape (testdata, (28 , 28 , 10000 , 1 )), ( 1 , 2 , 4 , 3 ));
38
+ traindata = reshape (traindata, (28 , 28 , 1 , 60000 ));
39
+ testdata = reshape (testdata, (28 , 28 , 1 , 10000 ));
41
40
trainlabels = Flux. onehotbatch (trainlabels, collect (0 : 9 ));
42
41
testlabels = Flux. onehotbatch (testlabels, collect (0 : 9 ));
43
42
44
- # function to get dictionary of model parameters
45
- function fill_param_dict! (dict, m, prefix)
46
- if m isa Chain
47
- for (i, layer) in enumerate (m. layers)
48
- fill_param_dict! (dict, layer, prefix* " layer_" * string (i)* " /" * string (layer)* " /" )
49
- end
50
- else
51
- for fieldname in fieldnames (typeof (m))
52
- val = getfield (m, fieldname)
53
- if val isa AbstractArray
54
- val = vec (val)
55
- end
56
- dict[prefix* string (fieldname)] = val
57
- end
58
- end
43
# functions to log information

"""
    log_train(pred, y)

Emit one `@info " train"` record with `loss = loss_fn(pred, y)` and
`acc = accuracy(pred, y)` to the logger that is current at the call site.
"""
function log_train(pred, y)
    batch_loss = loss_fn(pred, y)
    batch_acc = accuracy(pred, y)
    @info " train" loss=batch_loss acc=batch_acc
    return nothing
end
60
-
61
- # function to log information after every epoch
62
- function TBCallback ()
63
- param_dict = Dict {String, Any} ()
64
- fill_param_dict! (param_dict, model, " " )
65
- with_logger (logger) do
66
- @info " model" params= param_dict log_step_increment= 0
67
- @info " train" loss= loss (traindata, trainlabels) acc= accuracy (traindata, trainlabels) log_step_increment= 0
68
- @info " test" loss= loss (testdata, testlabels) acc= accuracy (testdata, testlabels)
69
- end
47
"""
    log_val()

Log a snapshot of the model parameters and the test-set metrics to the
current logger.

Reads the globals `model`, `testdata` and `testlabels`. Two records are
emitted: the flattened parameter vector (with `log_step_increment=0`, so it
does not advance the step) followed by the test loss and accuracy.
"""
function log_val()
    params_vec, _ = Flux.destructure(model)
    # Hand the logger a host-side array; `cpu` is a no-op when already on the CPU.
    @info " train" model=cpu(params_vec) log_step_increment=0

    # `model` is piped to `gpu` when it is built, while `testdata` is not, so
    # move the inputs to the model's device for the forward pass and bring the
    # predictions back to compare against the CPU-resident `testlabels`.
    # The forward pass is run once and shared by both metrics.
    pred = model(testdata |> gpu) |> cpu
    @info " test" loss=loss_fn(pred, testlabels) acc=accuracy(pred, testlabels)
    return nothing
end
71
52
72
- minibatches = []
73
- batchsize = 100
74
- for i in range (1 , stop = trainsize÷ batchsize)
75
- lbound = (i- 1 )* batchsize+ 1
76
- ubound = min (trainsize, i* batchsize)
77
- push! (minibatches, (traindata[:, :, :, lbound: ubound], trainlabels[:, lbound: ubound]))
78
- end
53
+ # trainloader = Flux.DataLoader((data=traindata, label=trainlabels), batchsize=100, shuffle=true, buffer=true, parallel=true) |> gpu ;
54
+ # testloader = Flux.DataLoader((data=testdata, label=testlabels), batchsize=100, shuffle=false, buffer=true) |> gpu ;
79
55
80
- Move data and model to gpu
81
- traindata = traindata |> gpu
82
- testdata = testdata |> gpu
83
- trainlabels = trainlabels |> gpu
84
- testlabels = testlabels |> gpu
85
- model = model |> gpu
86
- minibatches = minibatches |> gpu
56
+ trainloader = Flux. DataLoader ((data= traindata, label= trainlabels), batchsize= 100 ) |> gpu ;
57
+ testloader = Flux. DataLoader ((data= testdata, label= testlabels), batchsize= 100 ) |> gpu ;
87
58
88
59
# Train
89
- @Flux . epochs 15 Flux. train! (loss, params (model), minibatches, opt, cb = Flux. throttle (TBCallback, 5 ))
60
# Train for 15 epochs, routing every `@info` record to the TensorBoard logger.
with_logger(logger) do
    # Wrap the function itself (not its result) and build the wrapper once:
    # `Flux.throttle(log_val(), 5)` would run `log_val` immediately and then
    # discard a throttled wrapper around its return value, so nothing would
    # actually be rate-limited. With this wrapper, validation logging runs at
    # most once every 5 seconds.
    throttled_log_val = Flux.throttle(log_val, 5)

    for epoch in 1:15
        println(" epoch $epoch ")
        for (x, y) in trainloader
            _, grads = Flux.withgradient(model) do m
                pred = m(x)
                # Logging must not contribute to the gradient.
                ChainRulesCore.ignore_derivatives() do
                    log_train(pred, y)
                end
                loss_fn(pred, y)
            end
            Flux.update!(opt, model, grads[1])
        end
        throttled_log_val()
    end
end
0 commit comments