-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathTestAll.java
182 lines (153 loc) · 7 KB
/
TestAll.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import selection.*;
import java.util.Set;
import static org.junit.Assert.assertTrue;
/**
 * Runs the four selection methods (SFS, SBS, SFFS and SFBS) against a given
 * dataset. This is a useful class for running all the code at once and
 * checking the printed output.
 *
 * <p>The only assertion these tests make is that the subsets returned by the
 * size-bounded {@code select(int)} calls contain no more than
 * {@code MAX_FEATURES} features; the unbounded tests only print results.
 */
public class TestAll {
    /** File of instances to use. */
    private static final String FILE_NAME = "res/musk.arff";
    /** Base file name of the musk dataset, which needs special handling (see generateSelector). */
    private static final String MUSK_FILE_NAME = "musk.arff";
    /** Only specify this if you have a testing file; otherwise leave null and FILE_NAME will be split. */
    private static final String TESTING_FILE = null;
    /** Maximum number of features to select in the size-bounded tests. */
    private static final int MAX_FEATURES = 50;
    /** Maximum iterations to keep trying with no progression in subset accuracy. */
    private static final int MAX_ITERATIONS_WITHOUT_PROGRESS = 10;
    /** Whether or not to run the size-bounded ("num features") tests. */
    private static final boolean RUN_NUM_FEATURE_TESTS = false;
    /** Separator printed before and after each test's output. */
    private static final String SEPARATOR = "-------------------";

    // ===============
    // SFS TESTS
    // ===============

    @org.junit.Test
    public void testSequentialForwardSelection() throws Exception {
        runSelection(Selection.SFS, "Sequential forward selection");
    }

    @org.junit.Test
    public void testSequentialForwardSelectionNumfeatures() throws Exception {
        runBoundedSelection(Selection.SFS,
                "Sequential forward selection for max " + MAX_FEATURES + " features");
    }

    // ===============
    // SBS TESTS
    // ===============

    @org.junit.Test
    public void testSequentialBackwardSelection() throws Exception {
        runSelection(Selection.SBS, "Sequential backward selection");
    }

    @org.junit.Test
    public void testSequentialBackwardSelectionNumfeatures() throws Exception {
        runBoundedSelection(Selection.SBS,
                "Sequential backward selection for max " + MAX_FEATURES + " Features");
    }

    // ===============
    // FLOATING TESTS
    // ===============

    @org.junit.Test
    public void testSequentialFloatingForwardSelection() throws Exception {
        runSelection(Selection.SFFS, "Sequential floating forward selection");
    }

    @org.junit.Test
    public void testSequentialFloatingForwardSelectionNumFeatures() throws Exception {
        runBoundedSelection(Selection.SFFS,
                "Sequential floating forward selection for " + MAX_FEATURES + " features");
    }

    @org.junit.Test
    public void testSequentialFloatingBackwardSelection() throws Exception {
        runSelection(Selection.SFBS, "Sequential backward floating selection");
    }

    @org.junit.Test
    public void testSequentialFloatingBackwardSelectionNumFeatures() throws Exception {
        runBoundedSelection(Selection.SFBS,
                "Sequential backward floating selection for " + MAX_FEATURES + " features");
    }

    // ===============
    // HELPER METHODS
    // ===============

    /**
     * Runs {@code method} with no cap on the subset size and prints the
     * testing-accuracy comparison between separators.
     *
     * @param method the selection algorithm to run
     * @param title  heading printed before the results
     */
    private void runSelection(Selection method, String title) throws Exception {
        System.out.println(SEPARATOR);
        System.out.println(title);
        FeatureSelection selector = generateSelector(method);
        Set<Integer> selectedIndices = selector.select();
        selector.compareTestingAccuracy(selectedIndices);
        System.out.println(SEPARATOR);
    }

    /**
     * Runs {@code method} capped at {@code MAX_FEATURES}, prints the
     * testing-accuracy comparison, and asserts that the cap was respected.
     * Does nothing unless {@code RUN_NUM_FEATURE_TESTS} is enabled.
     *
     * @param method the selection algorithm to run
     * @param title  heading printed before the results
     */
    private void runBoundedSelection(Selection method, String title) throws Exception {
        // NOTE: returning early makes JUnit report the test as passed, not skipped.
        if (!RUN_NUM_FEATURE_TESTS) {
            return;
        }
        System.out.println(SEPARATOR);
        System.out.println(title);
        FeatureSelection selector = generateSelector(method);
        Set<Integer> selectedIndices = selector.select(MAX_FEATURES);
        selector.compareTestingAccuracy(selectedIndices);
        System.out.println(SEPARATOR);
        assertTrue("Selected " + selectedIndices.size() + " features, expected at most " + MAX_FEATURES,
                selectedIndices.size() <= MAX_FEATURES);
    }

    /**
     * Builds the selector for {@code method} and applies any dataset-specific
     * preprocessing.
     *
     * @param method the selection algorithm to construct
     * @return a selector configured with FILE_NAME (and TESTING_FILE if set)
     */
    private FeatureSelection generateSelector(Selection method) throws Exception {
        FeatureSelection selector = createSelector(method);
        // Special case for musk. Compare the base name, because FILE_NAME
        // includes a directory prefix ("res/"), so an exact string match never fires.
        if (MUSK_FILE_NAME.equals(new java.io.File(FILE_NAME).getName())) {
            // There is a "giveaway" feature (molecule_name) which stores some class information
            selector.removeAttribute(0);
        }
        return selector;
    }

    /**
     * Constructs the selector for {@code method}. If TESTING_FILE is null, the
     * single-file constructor is used; otherwise the training/testing pair is passed.
     *
     * @throws IllegalArgumentException if {@code method} has no mapping here
     */
    private static FeatureSelection createSelector(Selection method) throws Exception {
        boolean singleFile = TESTING_FILE == null;
        switch (method) {
            case SBS:
                return singleFile
                        ? new SequentialBackwardSelection(FILE_NAME, MAX_ITERATIONS_WITHOUT_PROGRESS)
                        : new SequentialBackwardSelection(FILE_NAME, TESTING_FILE, MAX_ITERATIONS_WITHOUT_PROGRESS);
            case SFS:
                return singleFile
                        ? new SequentialForwardSelection(FILE_NAME, MAX_ITERATIONS_WITHOUT_PROGRESS)
                        : new SequentialForwardSelection(FILE_NAME, TESTING_FILE, MAX_ITERATIONS_WITHOUT_PROGRESS);
            case SFBS:
                return singleFile
                        ? new SequentialFloatingBackwardSelection(FILE_NAME, MAX_ITERATIONS_WITHOUT_PROGRESS)
                        : new SequentialFloatingBackwardSelection(FILE_NAME, TESTING_FILE, MAX_ITERATIONS_WITHOUT_PROGRESS);
            case SFFS:
                return singleFile
                        ? new SequentialFloatingForwardSelection(FILE_NAME, MAX_ITERATIONS_WITHOUT_PROGRESS)
                        : new SequentialFloatingForwardSelection(FILE_NAME, TESTING_FILE, MAX_ITERATIONS_WITHOUT_PROGRESS);
            default:
                // Guards against a new Selection constant being added without a mapping here.
                throw new IllegalArgumentException("Unknown selection method: " + method);
        }
    }

    /** The feature-selection algorithms exercised by this class. */
    private enum Selection {
        SFS,
        SBS,
        SFFS,
        SFBS
    }
}