Add documentation how to use the trained network #78

Open
MagicMuscleMan opened this issue May 28, 2022 · 7 comments

@MagicMuscleMan

All the examples in the documentation stop after training. How to apply the trained model to new input data is either missing or should be made much more obvious.

@wardjm

wardjm commented Jul 31, 2022

Did you figure it out?

@MagicMuscleMan
Author

> Did you figure it out?

Unfortunately not.

@wardjm

wardjm commented Aug 1, 2022

Thanks! I'll have to figure it out today out of necessity. I'll report back and maybe open a pull request.

@chriselrod
Contributor

With the loss attached, sc(x, p) would give the loss, and

sc_noloss = SimpleChains.remove_loss(sc)
sc_noloss(x, p)

should give the predicted values.
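For a classifier like the LeNet script below, those predicted values are raw logits, one column per sample. A minimal sketch of turning them into labels, assuming x is a batch and the labels were shifted by +1 for the loss (the argmax/offset handling here is my own, not a SimpleChains API):

logits = sc_noloss(x, p)                # one column of raw scores per sample
classes = map(argmax, eachcol(logits))  # 1-based index of the largest logit
digits = classes .- 1                   # undo the +1 label shift (MNIST digits 0-9)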

@wardjm

wardjm commented Aug 1, 2022

Thanks! Here's my entire script, heavily based on the examples. It's written for the REPL, so you'll need to add some prints to run it as a script.

using MLDatasets
using SimpleChains

xtrain3, ytrain0 = MLDatasets.MNIST.traindata(Float32);
xtest3, ytest0 = MLDatasets.MNIST.testdata(Float32);

size(xtest3)
# (28, 28, 10000)
extrema(ytrain0) # digits, 0,...,9
# (0, 9)

xtrain4 = reshape(xtrain3, 28, 28, 1, :);
xtest4 = reshape(xtest3, 28, 28, 1, :);
ytrain1 = UInt32.(ytrain0 .+ 1);
ytest1 = UInt32.(ytest0 .+ 1);

lenet = SimpleChain(
  (static(28), static(28), static(1)),
  SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
  SimpleChains.MaxPool(2, 2),
  SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
  SimpleChains.MaxPool(2, 2),
  Flatten(3),
  TurboDense(SimpleChains.relu, 120),
  TurboDense(SimpleChains.relu, 84),
  TurboDense(identity, 10),
)

lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1));

p = SimpleChains.init_params(lenet);

G = SimpleChains.alloc_threaded_grad(lenetloss);

SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);

sc_noloss = SimpleChains.remove_loss(lenet)  # not strictly needed: the loss was added to lenetloss, so lenet itself has none

# Do all:
pred = lenet(xtest4, p)  # matrix of logits, one column per test image

pred[:,1]   # logits for the first test image

ytest1[1]   # its true label (digit + 1)

# Do single:
m = xtest4[:, :, 1, 1]
m1 = reshape(m, 28, 28, 1, :)  # the chain expects a 4-d (H, W, C, N) array
pred1 = lenet(m1, p)
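For completeness, here's a rough way to score the whole test set with the trained parameters. The accuracy computation below is my own plain-Julia sketch built on the variables above (the docs example also has SimpleChains.accuracy_and_loss for this, if I remember correctly):

predicted = map(argmax, eachcol(lenet(xtest4, p)));  # predicted class index (1..10) per test image
accuracy = sum(predicted .== ytest1) / length(ytest1)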

@mrazomej

mrazomej commented Feb 4, 2023

I agree that this is incredibly important! I spent the last two hours trying to figure out the simplest and most important thing: how to use the trained network. I should have come to the issues from the start, since it was obvious I would be one of many with this question. This really has to be one of the first things in the documentation.

@Vilin97

Vilin97 commented May 12, 2023

Technically, there is an example in the docs, but it's well hidden and I did not find it until I knew exactly what to look for.

@chriselrod, would it make sense to add @wardjm's example to the docs?
