diff --git a/docs/src/ML/ml.md b/docs/src/ML/ml.md index 00a807e..8b5d400 100644 --- a/docs/src/ML/ml.md +++ b/docs/src/ML/ml.md @@ -3,32 +3,29 @@ [From the model zoo](https://github.com/FluxML/model-zoo/blob/master/mnist/mnist.jl) ```Julia -using Flux, MNIST, CuArrays -using Flux: onehotbatch, argmax, mse, throttle +using Flux, MLDatasets, CuArrays +using Flux: onehotbatch, argmax, crossentropy, throttle, onecold using Base.Iterators: repeated -x, y = traindata() -y = onehotbatch(y, 0:9) +x, y = x, y = MNIST.traindata() +x = reshape(x,28*28,:) |> gpu +y = onehotbatch(y, 0:9) |> gpu m = Chain( Dense(28^2, 32, σ), Dense(32, 10), softmax -) +) |> gpu -using CuArrays -# or CLArrays (you then need to use cl -x, y = cu(x), cu(y) -m = mapparams(cu, m) -loss(x, y) = mse(m(x), y) +loss(x, y) = crossentropy(m(x), y) dataset = repeated((x, y), 500) evalcb = () -> @show(loss(x, y)) -opt = SGD(params(m), 1) +opt = Descent() -Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10)) +Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 10)) # Check the prediction for the first digit -argmax(m(x[:,1]), 0:9) == argmax(y[:,1], 0:9) +onecold(m(x[:,1]), 0:9) == onecold(y[:,1], 0:9) ```