-
Notifications
You must be signed in to change notification settings - Fork 7
/
cnntrain_combined.m
88 lines (80 loc) · 6.47 KB
/
cnntrain_combined.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
function [losses1, losses2, modelparas1, modelparas2] = cnntrain_combined(opts, arch, modelparas1, modelparas2, train_x, train_y, test_x, test_y)
%CNNTRAIN_COMBINED Train a mean-pooling and a max-pooling CNN on the same mini-batches.
%
%   [losses1, losses2, modelparas1, modelparas2] = cnntrain_combined(opts, ...
%       arch, modelparas1, modelparas2, train_x, train_y, test_x, test_y)
%
%   Inputs
%     opts         struct; fields read here: batchsize, numepochs,
%                  learningrate, momentum (the last two are only logged here;
%                  opts is also forwarded to cnnapplygrads).
%     arch         struct; fields read here: numfilters1, numfilters2, hiddim.
%                  arch.poolstyle is overwritten with 'mean' / 'max' before
%                  each forward pass.
%     modelparas1  parameters of the network trained with arch.poolstyle = 'mean'.
%     modelparas2  parameters of the network trained with arch.poolstyle = 'max'.
%                  Both must provide hid2out_weights, featvec2hid_weights,
%                  filters1 and the cell array filters2.
%     train_x      training inputs, samples indexed along the 3rd dimension.
%     train_y      training targets, samples indexed along the 2nd dimension.
%     test_x/y     test inputs / targets, same layout as the training data.
%
%   Outputs
%     losses1/2    column vectors (numbatches*numepochs x 1) holding the loss
%                  returned by cnnff for every mini-batch of each network.
%     modelparas1/2  the updated parameters.
%
%   Every 10th epoch both models are evaluated with cnntest_combined on the
%   complete test and training sets (in chunks) and a summary is appended to
%   ./results/combined_result.txt.
%
%   size(train_x,3) must be an integer multiple of opts.batchsize.

% Samples handed to cnntest_combined per call during evaluation.
eval_chunksize = 5000;
results_dir  = './results';

training_size = size(train_x,3);
assert(mod(training_size, opts.batchsize)==0, 'numbatches not integer');
numbatches = training_size/opts.batchsize;

% Zero-initialised "previous increment" buffers handed to cnnapplygrads.
% NOTE(review): these buffers are never updated inside this function. Unless
% cnnapplygrads keeps its own state, the momentum term it sees is always
% zero -- confirm against cnnapplygrads.
weights_inc_last1.hid2out_weights_inc = zeros(size(modelparas1.hid2out_weights));
weights_inc_last1.featvec2hid_weights_inc = zeros(size(modelparas1.featvec2hid_weights));
weights_inc_last1.filters2_inc = cell(arch.numfilters1, 1);
for i = 1:arch.numfilters1
    weights_inc_last1.filters2_inc{i} = zeros(size(modelparas1.filters2{i}));
end
weights_inc_last1.filters1_inc = zeros(size(modelparas1.filters1));

weights_inc_last2.hid2out_weights_inc = zeros(size(modelparas2.hid2out_weights));
weights_inc_last2.featvec2hid_weights_inc = zeros(size(modelparas2.featvec2hid_weights));
weights_inc_last2.filters2_inc = cell(arch.numfilters1, 1);
for i = 1:arch.numfilters1
    weights_inc_last2.filters2_inc{i} = zeros(size(modelparas2.filters2{i}));
end
weights_inc_last2.filters1_inc = zeros(size(modelparas2.filters1));

losses1 = zeros(numbatches*opts.numepochs, 1);
losses2 = zeros(numbatches*opts.numepochs, 1);

for i = 1:opts.numepochs
    tic;
    % Fresh random order each epoch; both networks see identical batches.
    randinds = randperm(training_size);
    for j = 1:numbatches
        batch_inds = randinds((j-1)*opts.batchsize+1:j*opts.batchsize);
        batch_x = train_x(:,:,batch_inds);
        batch_y = train_y(:, batch_inds);

        % Network 1: mean pooling.
        arch.poolstyle = 'mean';
        [netstates1, CE_loss1] = cnnff(modelparas1, arch, batch_x, batch_y);
        losses1((i-1)*numbatches + j) = CE_loss1;
        fprintf('Mean pooling, Epoch %d/%d, batch %d, ce loss %f\n',i, opts.numepochs, j, CE_loss1);
        grads1 = cnnbp(arch, modelparas1, netstates1, batch_x, batch_y);
        modelparas1 = cnnapplygrads(opts, arch, modelparas1, grads1, weights_inc_last1);

        % Network 2: max pooling.
        arch.poolstyle = 'max';
        [netstates2, CE_loss2] = cnnff(modelparas2, arch, batch_x, batch_y);
        losses2((i-1)*numbatches + j) = CE_loss2;
        fprintf('Max pooling, Epoch %d/%d, batch %d, ce loss %f\n',i, opts.numepochs, j, CE_loss2);
        grads2 = cnnbp(arch, modelparas2, netstates2, batch_x, batch_y);
        modelparas2 = cnnapplygrads(opts, arch, modelparas2, grads2, weights_inc_last2);
    end
    toc;

    if mod(i,10)==0
        % Evaluate on the full test / training sets in non-overlapping chunks.
        % (The previous hard-coded slices evaluated test sample 5000 twice and
        % only worked for exactly 10000 test / 60000 training samples.)
        [mean_err, max_err, combined_err] = cnntrain_combined_eval( ...
            arch, modelparas1, modelparas2, test_x, test_y, eval_chunksize);
        [train_mean_err, train_max_err, train_combined_err] = cnntrain_combined_eval( ...
            arch, modelparas1, modelparas2, train_x, train_y, eval_chunksize);

        % Make sure the output directory exists before appending to the log.
        if exist(results_dir, 'dir') ~= 7
            mkdir(results_dir);
        end
        [fileID, errmsg] = fopen('./results/combined_result.txt','a');
        if fileID == -1
            % Do not abort training (and lose the parameters) over a log file.
            warning('cnntrain_combined:fopenFailed', ...
                'Could not open ./results/combined_result.txt (%s); skipping log entry.', errmsg);
        else
            fprintf(fileID, '======================================================\n');
            fprintf(fileID, 'numfilters1: %d\n',arch.numfilters1);
            fprintf(fileID, 'numfilters2: %d\n',arch.numfilters2);
            fprintf(fileID, 'hiddim: %d\n',arch.hiddim);
            fprintf(fileID, 'number of epochs: %d\n',i);
            fprintf(fileID, 'learning rate: %f\n',opts.learningrate);
            fprintf(fileID, 'momentum: %f\n',opts.momentum);
            fprintf(fileID, 'mean pooling test error: %f\n',mean_err);
            fprintf(fileID, 'max pooling test error: %f\n',max_err);
            fprintf(fileID, 'combined pooling test error: %f\n',combined_err);
            fprintf(fileID, 'mean pooling train error: %f\n',train_mean_err);
            fprintf(fileID, 'max pooling train error: %f\n',train_max_err);
            fprintf(fileID, 'combined pooling train error: %f\n',train_combined_err);
            % fprintf(fileID, 'mean pooling training time: %f\n',tElapsed1);
            % fprintf(fileID, 'max pooling training time: %f\n',tElapsed2);
            fclose(fileID);
        end
    end
end

function [mean_err, max_err, combined_err] = cnntrain_combined_eval(arch, modelparas1, modelparas2, data_x, data_y, chunksize)
%CNNTRAIN_COMBINED_EVAL Run cnntest_combined over a data set chunk by chunk.
%   Splits the samples (3rd dim of data_x, 2nd dim of data_y) into consecutive,
%   non-overlapping chunks of at most chunksize samples and returns the
%   chunk-size-weighted average of the three values cnntest_combined reports.
%   Assumes those values are per-sample error rates, so that a size-weighted
%   average equals the error over the whole set -- TODO confirm against
%   cnntest_combined. Returns NaN for an empty data set.
num_samples = size(data_x, 3);
if num_samples == 0
    mean_err = NaN; max_err = NaN; combined_err = NaN;
    return;
end
mean_err = 0;
max_err = 0;
combined_err = 0;
for first = 1:chunksize:num_samples
    last = min(first + chunksize - 1, num_samples);
    idx = first:last;
    [chunk_mean, chunk_max, chunk_combined] = cnntest_combined( ...
        arch, modelparas1, modelparas2, data_x(:,:,idx), data_y(:,idx));
    % Weight by chunk size so a shorter final chunk does not bias the mean.
    w = numel(idx) / num_samples;
    mean_err = mean_err + w*chunk_mean;
    max_err = max_err + w*chunk_max;
    combined_err = combined_err + w*chunk_combined;
end