I tried the following code, using vanilla CuArrays and Flux.
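# Flux provides Chain, Dense, ADAM, params and gpu; Statistics provides mean on Julia 0.7+
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle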
using Base.Iterators: repeated
using CuArrays
# Classify MNIST digits with a simple multi-layer-perceptron
imgs = MNIST.images()
# Stack images into one large batch
X = hcat(float.(reshape.(imgs, :))...) |> gpu
labels = MNIST.labels()
# One-hot-encode the labels
Y = onehotbatch(labels, 0:9) |> gpu
m = Chain(
  Dense(28^2, 32, relu),
  Dense(32, 10),
  softmax) |> gpu
loss(x, y) = crossentropy(m(x), y)
function accuracy(x, y)
  a = onecold(m(x))
  b = onecold(y) |> gpu  # Without `|> gpu`, this becomes a plain Julia array
  return mean(a .== b)
end
dataset = repeated((X, Y), 200)
evalcb = () -> @show(loss(X, Y))
opt = ADAM()
Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 10))
accuracy(X, Y)