Commit fc9ba3e

updated flux example reproducing the bug
1 parent 7678145 commit fc9ba3e

File tree: 1 file changed, +34 -48 lines

examples/Flux.jl (+34 -48)
@@ -1,10 +1,9 @@
 #For a Colab of this example, goto https://colab.research.google.com/drive/1xfUsBn9GEqbRjBF-UX_jnGjHZNtNsMae
 using TensorBoardLogger
-using Flux
+using Flux, ChainRulesCore
 using Logging
 using MLDatasets
 using Statistics
-using CuArrays
 
 #create tensorboard logger
 logdir = "content/log"
@@ -28,62 +27,49 @@ model = Chain(
     Dense(28^2, 32, relu),
     Dense(32, 10),
     softmax
-)
+) |> gpu
 
-loss(x, y) = Flux.crossentropy(model(x), y)
+loss_fn(pred, y) = Flux.crossentropy(pred, y)
 
-accuracy(x, y) = mean(Flux.onecold(model(x) |> cpu) .== Flux.onecold(y |> cpu))
+accuracy(pred, y) = mean(Flux.onecold(pred |> cpu) .== Flux.onecold(y |> cpu))
 
-opt = ADAM()
+opt = Flux.setup(ADAM(), model)
 
-traindata = permutedims(reshape(traindata, (28, 28, 60000, 1)), (1, 2, 4, 3));
-testdata = permutedims(reshape(testdata, (28, 28, 10000, 1)), (1, 2, 4, 3));
+traindata = reshape(traindata, (28, 28, 1, 60000));
+testdata = reshape(testdata, (28, 28, 1, 10000));
 trainlabels = Flux.onehotbatch(trainlabels, collect(0:9));
 testlabels = Flux.onehotbatch(testlabels, collect(0:9));
 
-#function to get dictionary of model parameters
-function fill_param_dict!(dict, m, prefix)
-    if m isa Chain
-        for (i, layer) in enumerate(m.layers)
-            fill_param_dict!(dict, layer, prefix*"layer_"*string(i)*"/"*string(layer)*"/")
-        end
-    else
-        for fieldname in fieldnames(typeof(m))
-            val = getfield(m, fieldname)
-            if val isa AbstractArray
-                val = vec(val)
-            end
-            dict[prefix*string(fieldname)] = val
-        end
-    end
+#functions to log information
+function log_train(pred, y)
+    @info "train" loss=loss_fn(pred, y) acc=accuracy(pred, y)
 end
-
-#function to log information after every epoch
-function TBCallback()
-    param_dict = Dict{String, Any}()
-    fill_param_dict!(param_dict, model, "")
-    with_logger(logger) do
-        @info "model" params=param_dict log_step_increment=0
-        @info "train" loss=loss(traindata, trainlabels) acc=accuracy(traindata, trainlabels) log_step_increment=0
-        @info "test" loss=loss(testdata, testlabels) acc=accuracy(testdata, testlabels)
-    end
+function log_val()
+    params_vec, _ = Flux.destructure(model)
+    @info "train" model=params_vec log_step_increment=0
+    @info "test" loss=loss_fn(model(testdata), testlabels) acc=accuracy(model(testdata), testlabels)
 end
 
-minibatches = []
-batchsize = 100
-for i in range(1, stop = trainsize÷batchsize)
-    lbound = (i-1)*batchsize+1
-    ubound = min(trainsize, i*batchsize)
-    push!(minibatches, (traindata[:, :, :, lbound:ubound], trainlabels[:, lbound:ubound]))
-end
+# trainloader = Flux.DataLoader((data=traindata, label=trainlabels), batchsize=100, shuffle=true, buffer=true, parallel=true) |> gpu ;
+# testloader = Flux.DataLoader((data=testdata, label=testlabels), batchsize=100, shuffle=false, buffer=true) |> gpu ;
 
-#Move data and model to gpu
-traindata = traindata |> gpu
-testdata = testdata |> gpu
-trainlabels = trainlabels |> gpu
-testlabels = testlabels |> gpu
-model = model |> gpu
-minibatches = minibatches |> gpu
+trainloader = Flux.DataLoader((data=traindata, label=trainlabels), batchsize=100) |> gpu ;
+testloader = Flux.DataLoader((data=testdata, label=testlabels), batchsize=100) |> gpu ;
 
 #Train
-@Flux.epochs 15 Flux.train!(loss, params(model), minibatches, opt, cb = Flux.throttle(TBCallback, 5))
+with_logger(logger) do
+    for epoch in 1:15
+        println("epoch $epoch")
+        for (x, y) in trainloader
+            loss, grads = Flux.withgradient(model) do m
+                pred = m(x)
+                ChainRulesCore.ignore_derivatives() do
+                    log_train(pred, y)
+                end
+                loss_fn(pred, y)
+            end
+            Flux.update!(opt, model, grads[1])
+        end
+        Flux.throttle(log_val(), 5)
+    end
+end
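For readers who want to try the new pattern outside the full example, below is a minimal, self-contained sketch of the same idea: TensorBoardLogger as the active logger, with per-batch metrics logged from inside the gradient closure via ChainRulesCore.ignore_derivatives. The random toy data, small layer sizes, the demo_log directory, and the Adam spelling of the optimiser are assumptions for illustration, not part of the commit; since the commit message says the example reproduces a bug, logging behaviour may depend on package versions.

# Sketch only: random data stands in for MNIST, and everything runs on CPU.
using TensorBoardLogger, Logging, Flux, ChainRulesCore, Statistics

lg = TBLogger("demo_log", tb_overwrite)          # hypothetical log directory

x = rand(Float32, 28^2, 256)                     # toy inputs
y = Flux.onehotbatch(rand(0:9, 256), 0:9)        # toy labels
loader = Flux.DataLoader((x, y), batchsize=32)

model = Chain(Dense(28^2 => 32, relu), Dense(32 => 10), softmax)
opt = Flux.setup(Adam(), model)                  # explicit optimiser state, as in the diff

loss_fn(pred, yb) = Flux.crossentropy(pred, yb)
accuracy(pred, yb) = mean(Flux.onecold(pred) .== Flux.onecold(yb))

with_logger(lg) do
    for epoch in 1:3
        for (xb, yb) in loader
            loss, grads = Flux.withgradient(model) do m
                pred = m(xb)
                # Run the logging calls but hide them from the AD pass,
                # so they do not contribute to the gradient.
                ChainRulesCore.ignore_derivatives() do
                    @info "train" loss=loss_fn(pred, yb) acc=accuracy(pred, yb)
                end
                loss_fn(pred, yb)
            end
            Flux.update!(opt, model, grads[1])
        end
    end
end

On the end-of-epoch call in the diff: Flux.throttle(f, s) returns a wrapper that invokes f at most once every s seconds; the wrapper is normally constructed once, outside the loop, and then called repeatedly.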
