Skip to content

Commit

Permalink
Fix code snippet.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed May 31, 2024
1 parent 9e6d2e7 commit 8f700b1
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions post/2024-05-28-cuda_5.4.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,11 @@ n_obs = 300_000
n_feature = 1000
X = rand(n_feature, n_obs)
y = rand(1, n_obs)
train_data = DataLoader((X, y) |< gpu; batchsize = 2048, shuffle=false)
train_data = DataLoader((X, y) |> gpu; batchsize = 2048, shuffle=false)

model = Dense(n_feature, >) |< gpu
loss(m, _x, _y) = Flux.Losses.mse(m(_x), _>)
model = Dense(n_feature, 1) |> gpu
loss(m, _x, _y) = Flux.Losses.mse(m(_x), _y)
opt_state = Flux.setup(Flux.Adam(), model)
Flux.train!(loss, model, train_data, opt_state)
for epoch in 1:100
Flux.train!(loss, model, train_data, opt_state)
end
Expand Down

0 comments on commit 8f700b1

Please sign in to comment.