-
Notifications
You must be signed in to change notification settings - Fork 0
/
Demo_cifar10.m
124 lines (95 loc) · 4.16 KB
/
Demo_cifar10.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
% Title : "Semi-NMF network for image classification",2019 Chinese Control Conference (CCC).
% Authors : Huang, H., Yang, Z., Liang, N., & Li, Z
% Affl. : Guangdong University of Technology, Guangzhou, China
% Email : libertyhhn@foxmail.com
%
% Please note that this code was modified from the Deep Semi-NMF implementation provided by the authors of the following work:
% "PCANet: A simple deep learning baseline for image classification?"
% T.-H. Chan, K. Jia, S. Gao, J. Lu, Z. Zeng, and Y. Ma,
% IEEE Trans. Image Processing, vol. 24, no. 12, pp. 5017-5032, Dec. 2015.
% PCANet code URL: http://mx.nthu.edu.tw/~tsunghan/Source%20codes.html
% ---- Session reset -------------------------------------------------------
% Wipe the workspace, close every figure and clear the command window so
% the demo always starts from the same state.
clear all;
close all;
clc;

% ---- Dependencies --------------------------------------------------------
% A single addpath call prepends its arguments in the order given, so
% './Liblinear' ends up ahead of './Utils' on the search path, exactly as
% with two consecutive addpath calls (Utils first, Liblinear second).
addpath('./Liblinear', './Utils');
% NOTE(review): presumably (re)builds the liblinear MEX files - confirm
% against the make.m found on the path.
make();

% ---- Global constants ----------------------------------------------------
ImgSize = 32;     % image height and width handed to mat2imgcell
TrnSize = 10000;  % not referenced again in this script
%% Loading data from CIFAR10 (50000 training, 10000 testing)
% The CIFAR10 dataset (MATLAB version) can be downloaded at
% http://www.cs.toronto.edu/~kriz/cifar.html
%
% Produces:
%   TrnData    - training images, one image per column
%   TrnLabels  - training labels as a double column vector
%   TestData   - test images, one image per column
%   TestLabels - test labels as double
%   ImgFormat  - 'color' (use 'gray' for single-channel data)
DataPath = './cifar-10-batches-mat';
if exist(DataPath, 'dir') ~= 7
    % Fail early with an actionable message instead of a generic load error.
    error('Demo_cifar10:missingData', ...
        'CIFAR10 folder "%s" was not found. Download the MATLAB version from http://www.cs.toronto.edu/~kriz/cifar.html', ...
        DataPath);
end

% Load every batch into a struct so that no variables are injected into the
% workspace, and gather the pieces in cells to concatenate them only once.
NumTrnBatches = 5;
BatchData   = cell(1, NumTrnBatches);
BatchLabels = cell(NumTrnBatches, 1);
for BatchIdx = 1:NumTrnBatches
    Batch = load(fullfile(DataPath, ['data_batch_' num2str(BatchIdx) '.mat']), 'data', 'labels');
    BatchData{BatchIdx}   = Batch.data';   % transpose: one image per column
    BatchLabels{BatchIdx} = Batch.labels;
end
TrnData   = [BatchData{:}];
TrnLabels = double(vertcat(BatchLabels{:}));

Batch = load(fullfile(DataPath, 'test_batch.mat'), 'data', 'labels');
TestData   = Batch.data';
TestLabels = double(Batch.labels);
clear Batch BatchData BatchLabels BatchIdx NumTrnBatches; % temporaries only

ImgFormat = 'color'; %'gray'
%% Demo mode: work on a subsample of the training and testing sets
% Comment out the four subsampling statements below to run the complete
% experiment. Before doing so, make sure the machine has more than 64GB of
% memory: training a linear SVM classifier on a large amount of
% high-dimensional data requires a lot of memory.
TrnData    = TrnData(:, 1:50:size(TrnData, 2));    % every 50th training image (around 1000 samples)
TrnLabels  = TrnLabels(1:50:numel(TrnLabels));     % labels of the retained training images
TestData   = TestData(:, 1:10:size(TestData, 2));  % every 10th test image (around 1000 samples)
TestLabels = TestLabels(1:10:numel(TestLabels));   % labels of the retained test images
%%%%%%%%%%%%%%%%%%%%%%%%
nTestImg = numel(TestLabels); % number of test images classified below
%% SNnet parameters (they should be tuned on a validation set; i.e., ValData & ValLabel)
% The whole configuration is built in one struct() call; the field order is
% the same as with one assignment per field. The struct is handed to
% SNnet_train and SNnet_FeaExt further down.
% NOTE(review): the meaning of each field is defined by those functions -
% confirm there before changing values.
SNnet = struct( ...
    'NumStages',       2, ...
    'PatchSize',       [5 5], ...
    'NumFilters',      [40 8], ...
    'HistBlockSize',   [8 8], ...
    'BlkOverLapRatio', 0.5, ...
    'Pyramid',         [4 2 1]);
fprintf('\n ====== SNnet Parameters ======= \n')
% Deliberately left without a semicolon so the settings are echoed.
SNnet
%% SNnet training on the current training set
% With the demo subsampling active this is about 1000 images (every 50th of
% the 50000 CIFAR10 training images); without it, the full set is used.
fprintf('\n ====== SNnet Training ======= \n')
TrnData_ImgCell = mat2imgcell(double(TrnData),ImgSize,ImgSize,ImgFormat); % convert columns in TrnData to cells
% Time with an explicit handle: a bare tic/toc pair measures from the most
% recent tic, so any tic issued inside SNnet_train would corrupt the result.
SNnet_Timer = tic;
% ftrain : training features, fed to the linear SVM
% V      : learned model, reused for test-time feature extraction
% BlkIdx : serves the purpose of learning a block-wise DR projection
%          matrix; e.g., WPCA
[ftrain, V, BlkIdx] = SNnet_train(TrnData_ImgCell,SNnet,1,TrnLabels);
SNnet_TrnTime = toc(SNnet_Timer); % elapsed SNnet training time in seconds
%% Linear SVM training on the SNnet features
c = 10; % value passed to liblinear through the "-c" option
fprintf('\n ====== Training Linear SVM Classifier ======= \n')
display(sprintf('now testing c = %s...', num2str(c)))
% Option string for the liblinear call: "-s 1", the chosen "-c" value, and
% "-q" for quiet mode.
SvmOptions = sprintf('-s 1 -c %s -q', num2str(c));
tic;
models = train(TrnLabels, ftrain', SvmOptions); % linear SVM classifier (C = 10) from the liblinear library
LinearSVM_TrnTime = toc; % elapsed SVM training time in seconds
%% SNnet feature extraction and testing
% Classifies the test images one at a time with the trained SNnet model (V)
% and the linear SVM (models), and derives Accuracy / ErRate.
%
% Cast to double, as is done for the training images, so that both sets
% reach the feature extractor with the same numeric class.
TestData_ImgCell = mat2imgcell(double(TestData),ImgSize,ImgSize,ImgFormat); % convert columns in TestData to cells
clear TestData; % only the cell copy is used from here on
fprintf('\n ====== SNnet Testing ======= \n')
nCorrRecog = 0;                  % running count of correct predictions
RecHistory = zeros(nTestImg,1);  % RecHistory(idx) is 1 when sample idx was classified correctly

% Progress is printed every ReportEvery samples. Forcing a positive integer
% keeps the mod() test meaningful when nTestImg is smaller than, or not a
% multiple of, 1000.
ReportEvery = max(1, floor(nTestImg/1000));

% Explicit timer handle: immune to any tic issued inside the called functions.
TestTimer = tic;
for idx = 1:nTestImg
    ftest = SNnet_FeaExt(TestData_ImgCell(idx),V,SNnet); % feature of one test image from the trained SNnet model

    % Three outputs are requested, as in the original call; only the
    % predicted label is used.
    [xLabel_est, ~, ~] = predict(TestLabels(idx),...
        sparse(ftest'), models, '-q');

    if xLabel_est == TestLabels(idx)
        RecHistory(idx) = 1;
        nCorrRecog = nCorrRecog + 1;
    end

    if mod(idx, ReportEvery) == 0
        fprintf('Accuracy up to %d tests is %.2f%%; taking %.2f secs per testing sample on average. \n',...
            idx, 100*nCorrRecog/idx, toc(TestTimer)/idx);
    end

    TestData_ImgCell{idx} = []; % release the image once it has been processed
end
Averaged_TimeperTest = toc(TestTimer)/nTestImg; % seconds per test sample
Accuracy = nCorrRecog/nTestImg;
ErRate = 1 - Accuracy;
%% Results display
% One fprintf call: the format pieces are concatenated in the same order as
% the individual messages, and the values are consumed left to right.
ReportFormat = ['\n ===== Results of SNnet, followed by a linear SVM classifier =====', ...
                '\n SNnet training time: %.2f secs.', ...
                '\n Linear SVM training time: %.2f secs.', ...
                '\n Testing error rate: %.2f%%', ...
                '\n Average testing time %.2f secs per test sample. \n\n'];
fprintf(ReportFormat, SNnet_TrnTime, LinearSVM_TrnTime, 100*ErRate, Averaged_TimeperTest);