-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathSequentialFloatingBackwardSelection.java
133 lines (96 loc) · 5.16 KB
/
SequentialFloatingBackwardSelection.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package selection;
import java.util.HashSet;
import java.util.Set;
/**
 * Performs Sequential Floating Backward Selection (SFBS).
 *
 * <ol>
 * <li>Starts with the full set of features.</li>
 * <li>Removes the "worst" feature, as chosen by {@code worst}.</li>
 * <li>Re-adds the "best" previously removed features, as chosen by {@code best}, one at a time
 * (SFS-style) for as long as the objective function does not decrease and the resulting subset
 * has not been visited before.</li>
 * <li>Goes back to step 2 until the stopping criteria is met.</li>
 * </ol>
 *
 * <p>The result is the highest-scoring subset encountered (the starting full set included) whose
 * size does not exceed the requested maximum; ties in score are broken in favour of fewer features.
 */
public class SequentialFloatingBackwardSelection extends FeatureSelection {

    /**
     * Creates a selector from a single data file.
     *
     * @param file data file, handed unchanged to {@link FeatureSelection}
     * @param maxIterationsWithoutProgress handed unchanged to {@link FeatureSelection}; presumably
     *     becomes {@code MAX_ITERATIONS_WITHOUT_PROGRESS} used by the default stopping criteria
     * @throws Exception if the superclass constructor fails
     */
    public SequentialFloatingBackwardSelection(String file, int maxIterationsWithoutProgress) throws Exception {
        super(file, maxIterationsWithoutProgress);
    }

    /**
     * Creates a selector from separate training and testing data files.
     *
     * @param training training data file, handed unchanged to {@link FeatureSelection}
     * @param testing testing data file, handed unchanged to {@link FeatureSelection}
     * @param maxIterationsWithoutProgress handed unchanged to {@link FeatureSelection}; presumably
     *     becomes {@code MAX_ITERATIONS_WITHOUT_PROGRESS} used by the default stopping criteria
     * @throws Exception if the superclass constructor fails
     */
    public SequentialFloatingBackwardSelection(String training, String testing, int maxIterationsWithoutProgress) throws Exception {
        super(training, testing, maxIterationsWithoutProgress);
    }

    /**
     * Runs SFBS while the current subset is larger than {@code maxNumFeatures}, or while fewer than
     * {@code MAX_ITERATIONS_WITHOUT_PROGRESS} consecutive iterations have failed to improve.
     *
     * @param maxNumFeatures largest subset size that may be returned
     * @return the best subset of at most {@code maxNumFeatures} features found (empty if none)
     * @throws Exception propagated from the objective function or feature helpers
     */
    public Set<Integer> select(int maxNumFeatures) throws Exception {
        return select((noImprovement, size) -> size > maxNumFeatures || noImprovement < MAX_ITERATIONS_WITHOUT_PROGRESS, maxNumFeatures);
    }

    /**
     * Runs SFBS until {@code MAX_ITERATIONS_WITHOUT_PROGRESS} consecutive iterations fail to improve.
     * Subsets of any size may be returned.
     *
     * @return the best subset found
     * @throws Exception propagated from the objective function or feature helpers
     */
    public Set<Integer> select() throws Exception {
        return select((noImprovement, size) -> noImprovement < MAX_ITERATIONS_WITHOUT_PROGRESS);
    }

    /**
     * Runs SFBS for as long as {@code criteria} evaluates to true. Subsets of any size may be returned.
     *
     * @param criteria stopping criteria, given the current count of non-improving iterations and the
     *     current subset size
     * @return the best subset found
     * @throws Exception propagated from the objective function or feature helpers
     */
    public Set<Integer> select(Criteria criteria) throws Exception {
        // No size limit: the maximum is the total number of features
        return select(criteria, getNumFeatures());
    }

    /**
     * Core SFBS loop.
     *
     * @param criteria stopping criteria checked before every iteration
     * @param maxNumFeatures largest subset size eligible to be returned
     * @return the best eligible subset seen, or an empty set if none was eligible
     * @throws Exception propagated from the objective function or feature helpers
     */
    private Set<Integer> select(Criteria criteria, int maxNumFeatures) throws Exception {
        // To begin with all features are selected and none have been removed
        Set<Integer> selectedFeatures = getAllFeatureIndices();
        // Features removed from the selection that are candidates for being re-added
        Set<Integer> removedFeatures = new HashSet<>();

        double accuracy = objectiveFunction(selectedFeatures);
        double lastAccuracy = accuracy;

        // Every subset we have been in, so the floating step can never cycle back into one
        Set<Set<Integer>> visitedSubsets = new HashSet<>();
        visitedSubsets.add(new HashSet<>(selectedFeatures));

        // Best eligible (size <= maxNumFeatures) subset seen so far; null until one is found.
        // The starting full set is a candidate as well, otherwise it could never be returned.
        Set<Integer> bestSoFar = null;
        double highestAccuracy = 0;
        if (selectedFeatures.size() <= maxNumFeatures) {
            bestSoFar = new HashSet<>(selectedFeatures);
            highestAccuracy = accuracy;
        }

        // Number of consecutive iterations whose accuracy did not beat the previous iteration's
        int iterationsWithoutImprovement = 0;
        printAccuracy(selectedFeatures.size(), accuracy);

        while (criteria.evaluate(iterationsWithoutImprovement, selectedFeatures.size())) {
            /* EXCLUDE THE WORST FEATURE */
            int worstFeature = worst(selectedFeatures);
            if (worstFeature == -1) {
                break; // no more valid features to remove
            }
            selectedFeatures.remove(worstFeature);
            removedFeatures.add(worstFeature);
            visitedSubsets.add(new HashSet<>(selectedFeatures));

            double accuracyAfterRemoval = objectiveFunction(selectedFeatures);
            printAccuracy(selectedFeatures.size(), accuracyAfterRemoval);

            /* INCLUDE THE BEST FEATURES */
            // The helper returns the score of exactly the subset it leaves in selectedFeatures,
            // so there is no need to evaluate the objective function again here.
            accuracy = includeBestFeatures(selectedFeatures, removedFeatures, visitedSubsets, accuracyAfterRemoval);

            // Record a new best if it scores higher, or scores the same with fewer features,
            // and respects the size limit
            int size = selectedFeatures.size();
            boolean better = bestSoFar == null
                    || greaterThan(accuracy, highestAccuracy)
                    || (equalTo(accuracy, highestAccuracy) && size < bestSoFar.size());
            if (better && size <= maxNumFeatures) {
                highestAccuracy = accuracy;
                bestSoFar = new HashSet<>(selectedFeatures);
            }

            if (lessThanOrEqualTo(accuracy, lastAccuracy)) {
                iterationsWithoutImprovement++;
            } else {
                iterationsWithoutImprovement = 0;
            }
            lastAccuracy = accuracy;
        }
        return bestSoFar == null ? new HashSet<>() : bestSoFar;
    }

    /**
     * Floating step: greedily moves features from {@code removed} back into {@code selected}.
     * It continues while the score does not drop (ties are accepted) and the resulting subset is unvisited.
     * Both sets are mutated in place and every accepted subset is added to {@code visited}.
     *
     * @param selected currently selected features (mutated)
     * @param removed candidates for re-adding (mutated)
     * @param visited subsets already explored (appended to)
     * @param accuracy objective value of {@code selected} on entry
     * @return objective value of {@code selected} on exit
     * @throws Exception propagated from the objective function or feature helpers
     */
    private double includeBestFeatures(Set<Integer> selected, Set<Integer> removed,
            Set<Set<Integer>> visited, double accuracy) throws Exception {
        double current = accuracy;
        while (true) {
            int bestFeature = best(selected, removed);
            if (bestFeature == -1) {
                return current; // no more valid features to add
            }
            selected.add(bestFeature);
            removed.remove(bestFeature);

            double candidate = objectiveFunction(selected);
            printAccuracy(selected.size(), candidate);

            // If the accuracy dropped or we have been to this state, undo this step and stop adding
            if (lessThan(candidate, current) || visited.contains(selected)) {
                selected.remove(bestFeature);
                removed.add(bestFeature);
                return current;
            }
            visited.add(new HashSet<>(selected));
            // This becomes the point of comparison for the next addition
            current = candidate;
        }
    }
}