diff --git a/.gitignore b/.gitignore
index f615e84..ff4d4f3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@ datasets/
 
 # Ignore log folders
 log/
+logs/
 
 # Ignore emacs saved files
 *.*~
diff --git a/train-a-digit-classifier/train-on-mnist-cuda.lua b/train-a-digit-classifier/train-on-mnist-cuda.lua
new file mode 100644
index 0000000..0200455
--- /dev/null
+++ b/train-a-digit-classifier/train-on-mnist-cuda.lua
@@ -0,0 +1,363 @@
+----------------------------------------------------------------------
+-- This script shows how to train different models on the MNIST
+-- dataset, using multiple optimization techniques (SGD, LBFGS)
+--
+-- This script demonstrates a classical example of training
+-- well-known models (convnet, MLP, logistic regression)
+-- on a 10-class classification problem.
+--
+-- It illustrates several points:
+-- 1/ description of the model
+-- 2/ choice of a loss function (criterion) to minimize
+-- 3/ creation of a dataset as a simple Lua table
+-- 4/ description of training and test procedures
+--
+-- Clement Farabet
+----------------------------------------------------------------------
+
+require 'torch'
+require 'cutorch'
+require 'cunn'
+require 'nnx'
+require 'optim'
+require 'image'
+require 'dataset-mnist'
+require 'pl'
+require 'paths'
+
+----------------------------------------------------------------------
+-- parse command-line options
+--
+local opt = lapp[[
+   -s,--save          (default "logs")      subdirectory to save logs
+   -n,--network       (default "")          reload pretrained network
+   -m,--model         (default "convnet")   type of model to train: convnet | mlp | linear
+   -f,--full                                use the full dataset
+   -p,--plot                                plot while training
+   -o,--optimization  (default "SGD")       optimization: SGD | LBFGS
+   -r,--learningRate  (default 0.05)        learning rate, for SGD only
+   -b,--batchSize     (default 10)          batch size
+   --momentum         (default 0)           momentum, for SGD only
+   -i,--maxIter       (default 3)           maximum nb of iterations per batch, for LBFGS
+   --coefL1           (default 0)           L1 penalty on the weights
+   --coefL2           (default 0)           L2 penalty on the weights
+   -t,--threads       (default 4)           number of threads
+]]
+
+-- fix seed
+torch.manualSeed(1)
+
+-- threads
+torch.setnumthreads(opt.threads)
+print('<mnist> set nb of threads to ' .. torch.getnumthreads())
+
+-- use floats, for SGD
+if opt.optimization == 'SGD' then
+   torch.setdefaulttensortype('torch.FloatTensor')
+end
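+
+-- GPU device (a sketch, not part of the original demo): cutorch defaults
+-- to device 1; on a multi-GPU machine you could pin a specific card, and
+-- cutorch.manualSeed would additionally seed the GPU RNG:
+-- cutorch.setDevice(1)
+-- cutorch.manualSeed(1)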
+
+-- batch size?
+if opt.optimization == 'LBFGS' and opt.batchSize < 100 then
+   error('LBFGS should not be used with small mini-batches; 1000 is recommended')
+end
+
+----------------------------------------------------------------------
+-- define model to train
+-- on the 10-class classification problem
+--
+classes = {'1','2','3','4','5','6','7','8','9','10'}
+
+-- geometry: width and height of input images
+geometry = {32,32}
+
+if opt.network == '' then
+   -- define model to train
+   model = nn.Sequential()
+
+   if opt.model == 'convnet' then
+      ------------------------------------------------------------
+      -- convolutional network
+      ------------------------------------------------------------
+      -- stage 1 : mean suppression -> filter bank -> squashing -> max pooling
+      model:add(nn.SpatialConvolution(1, 32, 5, 5))
+      model:add(nn.Tanh())
+      model:add(nn.SpatialMaxPooling(3, 3, 3, 3))
+      -- stage 2 : mean suppression -> filter bank -> squashing -> max pooling
+      model:add(nn.SpatialConvolution(32, 64, 5, 5))
+      model:add(nn.Tanh())
+      model:add(nn.SpatialMaxPooling(2, 2, 2, 2))
+      -- stage 3 : standard 2-layer MLP:
+      model:add(nn.Reshape(64*2*2))
+      model:add(nn.Linear(64*2*2, 200))
+      model:add(nn.Tanh())
+      model:add(nn.Linear(200, #classes))
+      ------------------------------------------------------------
+
+   elseif opt.model == 'mlp' then
+      ------------------------------------------------------------
+      -- regular 2-layer MLP
+      ------------------------------------------------------------
+      model:add(nn.Reshape(1024))
+      model:add(nn.Linear(1024, 2048))
+      model:add(nn.Tanh())
+      model:add(nn.Linear(2048,#classes))
+      ------------------------------------------------------------
+
+   elseif opt.model == 'linear' then
+      ------------------------------------------------------------
+      -- simple linear model: logistic regression
+      ------------------------------------------------------------
+      model:add(nn.Reshape(1024))
+      model:add(nn.Linear(1024,#classes))
+      ------------------------------------------------------------
+
+   else
+      print('Unknown model type')
+      error()
+   end
+else
+   print('<trainer> reloading previously trained network')
+   model = torch.load(opt.network)
+end
+
+-- verbose
+print('<mnist> using model:')
+print(model)
+
+----------------------------------------------------------------------
+-- loss function: negative log-likelihood
+--
+model:add(nn.LogSoftMax())
+
+-- move the network to the GPU, then wrap it so that callers keep feeding
+-- (and receiving) ordinary float tensors: inputs are copied host -> device
+-- before the CUDA module, and outputs are copied back afterwards
+model2 = model:cuda()
+model = nn.Sequential()
+model:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor'))
+model:add(model2)
+model:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor'))
+
+criterion = nn.ClassNLLCriterion()
+
+-- retrieve parameters and gradients
+parameters,gradParameters = model:getParameters()
+
+----------------------------------------------------------------------
+-- get/create dataset
+--
+if opt.full then
+   nbTrainingPatches = 60000
+   nbTestingPatches = 10000
+else
+   nbTrainingPatches = 2000
+   nbTestingPatches = 1000
+   print('<mnist> only using 2000 samples to train quickly (use flag -full to use 60000 samples)')
+end
+
+-- create training set and normalize
+trainData = mnist.loadTrainSet(nbTrainingPatches, geometry)
+trainData:normalizeGlobal(mean, std)
+
+-- create test set and normalize
+testData = mnist.loadTestSet(nbTestingPatches, geometry)
+testData:normalizeGlobal(mean, std)
+
+----------------------------------------------------------------------
+-- define training and testing functions
+--
+
+-- this matrix records the current confusion across classes
+confusion = optim.ConfusionMatrix(classes)
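+
+-- optional sanity check (a sketch, not in the original script): push one
+-- dummy float batch through the wrapped model to confirm the CPU -> GPU
+-- -> CPU round trip and the output shape before training starts
+-- local probe = model:forward(torch.FloatTensor(2,1,geometry[1],geometry[2]):zero())
+-- assert(probe:size(2) == #classes)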
+
+-- log results to files
+trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
+testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
+
+-- training function
+function train(dataset)
+   -- epoch tracker
+   epoch = epoch or 1
+
+   -- local vars
+   local time = sys.clock()
+
+   -- do one epoch
+   print('<trainer> on training set:')
+   print("<trainer> online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']')
+   for t = 1,dataset:size(),opt.batchSize do
+      -- create mini batch
+      local inputs = torch.Tensor(opt.batchSize,1,geometry[1],geometry[2])
+      local targets = torch.Tensor(opt.batchSize)
+      local k = 1
+      for i = t,math.min(t+opt.batchSize-1,dataset:size()) do
+         -- load new sample
+         local sample = dataset[i]
+         local input = sample[1]:clone()
+         local _,target = sample[2]:clone():max(1)
+         target = target:squeeze()
+         inputs[k] = input
+         targets[k] = target
+         k = k + 1
+      end
+
+      -- create closure to evaluate f(X) and df/dX
+      local feval = function(x)
+         -- just in case:
+         collectgarbage()
+
+         -- get new parameters
+         if x ~= parameters then
+            parameters:copy(x)
+         end
+
+         -- reset gradients
+         gradParameters:zero()
+
+         -- evaluate function for complete mini batch
+         local outputs = model:forward(inputs)
+         local f = criterion:forward(outputs, targets)
+
+         -- estimate df/dW
+         local df_do = criterion:backward(outputs, targets)
+         model:backward(inputs, df_do)
+
+         -- penalties (L1 and L2):
+         if opt.coefL1 ~= 0 or opt.coefL2 ~= 0 then
+            -- locals:
+            local norm,sign = torch.norm,torch.sign
+
+            -- Loss:
+            f = f + opt.coefL1 * norm(parameters,1)
+            f = f + opt.coefL2 * norm(parameters,2)^2/2
+
+            -- Gradients:
+            gradParameters:add( sign(parameters):mul(opt.coefL1) + parameters:clone():mul(opt.coefL2) )
+         end
+
+         -- update confusion
+         for i = 1,opt.batchSize do
+            confusion:add(outputs[i], targets[i])
+         end
+
+         -- return f and df/dX
+         return f,gradParameters
+      end
+
+      -- optimize on current mini-batch
+      if opt.optimization == 'LBFGS' then
+
+         -- Perform LBFGS step:
+         lbfgsState = lbfgsState or {
+            maxIter = opt.maxIter,
+            lineSearch = optim.lswolfe
+         }
+         optim.lbfgs(feval, parameters, lbfgsState)
+
+         -- disp report:
+         print('LBFGS step')
+         print(' - progress in batch: ' .. t .. '/' .. dataset:size())
+         print(' - nb of iterations: ' .. lbfgsState.nIter)
+         print(' - nb of function evaluations: ' .. lbfgsState.funcEval)
+
+      elseif opt.optimization == 'SGD' then
+
+         -- Perform SGD step:
+         sgdState = sgdState or {
+            learningRate = opt.learningRate,
+            momentum = opt.momentum,
+            learningRateDecay = 5e-7
+         }
+         optim.sgd(feval, parameters, sgdState)
+
+         -- disp progress
+         xlua.progress(t, dataset:size())
+
+      else
+         error('unknown optimization method')
+      end
+   end
+
+   -- time taken
+   time = sys.clock() - time
+   time = time / dataset:size()
+   print("<trainer> time to learn 1 sample = " .. (time*1000) .. 'ms')
+
+   -- print confusion matrix
+   print(confusion)
+   trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100}
+   confusion:zero()
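+
+   -- note: the torch.save call below is left commented out; serializing
+   -- this wrapped model would embed CUDA tensors, so the saved file could
+   -- only be reloaded on a machine with a working cutorch setup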
+
+   -- save/log current net
+   local filename = paths.concat(opt.save, 'mnist.net')
+   os.execute('mkdir -p ' .. sys.dirname(filename))
+   if paths.filep(filename) then
+      os.execute('mv ' .. filename .. ' ' .. filename .. '.old')
+   end
+   print('<trainer> saving network to '..filename)
+   -- torch.save(filename, model)
+
+   -- next epoch
+   epoch = epoch + 1
+end
+
+-- test function
+function test(dataset)
+   -- local vars
+   local time = sys.clock()
+
+   -- test over given dataset
+   print('<trainer> on testing set:')
+   for t = 1,dataset:size(),opt.batchSize do
+      -- disp progress
+      xlua.progress(t, dataset:size())
+
+      -- create mini batch
+      local inputs = torch.Tensor(opt.batchSize,1,geometry[1],geometry[2])
+      local targets = torch.Tensor(opt.batchSize)
+      local k = 1
+      for i = t,math.min(t+opt.batchSize-1,dataset:size()) do
+         -- load new sample
+         local sample = dataset[i]
+         local input = sample[1]:clone()
+         local _,target = sample[2]:clone():max(1)
+         target = target:squeeze()
+         inputs[k] = input
+         targets[k] = target
+         k = k + 1
+      end
+
+      -- test samples
+      local preds = model:forward(inputs)
+
+      -- confusion:
+      for i = 1,opt.batchSize do
+         confusion:add(preds[i], targets[i])
+      end
+   end
+
+   -- timing
+   time = sys.clock() - time
+   time = time / dataset:size()
+   print("<trainer> time to test 1 sample = " .. (time*1000) .. 'ms')
+
+   -- print confusion matrix
+   print(confusion)
+   testLogger:add{['% mean class accuracy (test set)'] = confusion.totalValid * 100}
+   confusion:zero()
+end
+
+----------------------------------------------------------------------
+-- and train!
+--
+while true do
+   -- train/test
+   train(trainData)
+   test(testData)
+
+   -- plot errors
+   if opt.plot then
+      trainLogger:style{['% mean class accuracy (train set)'] = '-'}
+      testLogger:style{['% mean class accuracy (test set)'] = '-'}
+      trainLogger:plot()
+      testLogger:plot()
+   end
+end
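+
+----------------------------------------------------------------------
+-- usage (a sketch; flags as defined by the lapp options above):
+--   th train-on-mnist-cuda.lua                  -- quick run: convnet, SGD
+--   th train-on-mnist-cuda.lua -f -p            -- full dataset, live plots
+--   th train-on-mnist-cuda.lua -o LBFGS -b 1000 -- LBFGS with large batches
+----------------------------------------------------------------------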