% training_script.m
% Scraped from a GitHub repository page; the page chrome and the
% line-number gutter from that scrape have been removed.
% Original listing: 76 lines (61 loc), 2.7 KB.
% Training data. x2 supplies the network inputs and y2 the targets;
% neither is defined in this script, so both must already be present in
% the workspace when it runs.
% NOTE(review): the toolbox treats columns as samples by default --
% confirm x2 and y2 are oriented that way.
inputs  = x2;
targets = y2;

% Function-fitting network with a single hidden layer of hiddenLayerSize
% neurons, trained with Levenberg-Marquardt ('trainlm', fitnet's default).
hiddenLayerSize = 10;
net = fitnet(hiddenLayerSize, 'trainlm');
% ---- Pre/post-processing ------------------------------------------------
% removeconstantrows followed by mapminmax, applied both to network input 1
% and to the output of layer 2 (the output layer of this two-layer net).
% Run "help nnprocess" for the full list of processing functions.
net.inputs{1}.processFcns  = {'removeconstantrows', 'mapminmax'};
net.outputs{2}.processFcns = {'removeconstantrows', 'mapminmax'};

% ---- Train / validation / test split ------------------------------------
% Individual samples are assigned at random: 70% training, 15% validation,
% 15% test. Run "help nndivide" for the other division functions.
% divideFcn is deliberately assigned before the ratios.
% NOTE(review): assigning divideFcn appears to reset divideParam to that
% function's defaults -- confirm before reordering these lines.
net.divideFcn  = 'dividerand';
net.divideMode = 'sample';
net.divideParam.trainRatio = 0.70;
net.divideParam.valRatio   = 0.15;
net.divideParam.testRatio  = 0.15;

% ---- Training algorithm -------------------------------------------------
% Levenberg-Marquardt backpropagation ("help trainlm"). Alternatives
% ("help nntrain" lists them all):
%   trainbfg  BFGS quasi-Newton            trainbr   Bayesian regulation
%   traincgb  conj. grad., Powell-Beale    traincgf  conj. grad., Fletcher-Reeves
%   traincgp  conj. grad., Polak-Ribiere   traingd   gradient descent
%   traingda  GD, adaptive learning rate   traingdm  GD with momentum
%   traingdx  GD, momentum + adaptive learning rate
%   trainoss  one-step secant              trainrp   RPROP
%   trainscg  scaled conjugate gradient
net.trainFcn = 'trainlm';

% ---- Performance measure and plot list ----------------------------------
% Mean squared error ("help nnperformance" for other measures); the plot
% functions registered for the network ("help nnplot" for the catalogue).
net.performFcn = 'mse';
net.plotFcns = { ...
    'plotperform', 'plottrainstate', 'ploterrhist', ...
    'plotregression', 'plotfit'};
% ---- Train --------------------------------------------------------------
% tr is the training record; its trainMask / valMask / testMask fields are
% used below to score each data split on its own.
[net, tr] = train(net, inputs, targets);

% ---- Evaluate -----------------------------------------------------------
% The trained network is run on EVERY sample (all three splits, not only
% the test split); the overall performance is printed.
outputs = net(inputs);
errors  = gsubtract(targets, outputs);
performance = perform(net, targets, outputs)

% Per-split performance: each split's mask multiplies the targets so that
% samples outside that split drop out of the performance computation.
% NOTE(review): relies on the masks holding NaN for excluded samples and
% on perform ignoring NaN targets -- confirm against the toolbox docs.
trainTargets = tr.trainMask{1} .* targets;
valTargets   = tr.valMask{1}   .* targets;
testTargets  = tr.testMask{1}  .* targets;
trainPerformance = perform(net, trainTargets, outputs)
valPerformance   = perform(net, valTargets,   outputs)
testPerformance  = perform(net, testTargets,  outputs)
% View the Network
% Opens a diagram of the trained network's architecture.
% NOTE(review): this needs a graphical display; in a headless session
% (e.g. matlab -batch) the call may fail -- consider guarding it.
view(net)
% Plots
% Optional diagnostics: remove the leading '%' on any line below to enable
% that plot. Each opens a new figure and uses tr, outputs and errors from
% the training and evaluation steps earlier in this script.
%figure, plotperform(tr)
%figure, plottrainstate(tr)
%figure, plotfit(net,inputs,targets)
%figure, plotregression(targets,outputs)
%figure, ploterrhist(errors)