-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathSequentialForwardSelection.java
79 lines (59 loc) · 2.67 KB
/
SequentialForwardSelection.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package selection;
import java.util.HashSet;
import java.util.Set;
/**
* Performs Sequential Forward Selection (SFS)
* <p>
* - Starts with empty set of features
* - Adds the "best" feature until stopping criteria is met
*/
public class SequentialForwardSelection extends FeatureSelection {

    /**
     * Creates a selector by delegating to the single-source {@code FeatureSelection} constructor.
     *
     * @param file                         passed through to the superclass (presumably the data set location; confirm there)
     * @param maxIterationsWithoutProgress passed through to the superclass
     * @throws Exception if the superclass constructor fails
     */
    public SequentialForwardSelection(String file, int maxIterationsWithoutProgress) throws Exception {
        super(file, maxIterationsWithoutProgress);
    }

    /**
     * Creates a selector by delegating to the two-source {@code FeatureSelection} constructor.
     *
     * @param training                     passed through to the superclass (presumably the training data; confirm there)
     * @param testing                      passed through to the superclass (presumably the testing data; confirm there)
     * @param maxIterationsWithoutProgress passed through to the superclass
     * @throws Exception if the superclass constructor fails
     */
    public SequentialForwardSelection(String training, String testing, int maxIterationsWithoutProgress) throws Exception {
        super(training, testing, maxIterationsWithoutProgress);
    }

    /**
     * Runs the search for as long as fewer than {@code maxNumFeatures} features have been added.
     *
     * @param maxNumFeatures upper bound on the size of the working subset
     * @return the best-scoring subset seen during the search
     * @throws Exception if evaluating a subset fails
     */
    public Set<Integer> select(int maxNumFeatures) throws Exception {
        // The first argument handed to the criteria is the stall counter; it is ignored here.
        return select((stalledRounds, subsetSize) -> subsetSize < maxNumFeatures);
    }

    /**
     * Runs the search for as long as the count of consecutive non-improving iterations stays below
     * {@code MAX_ITERATIONS_WITHOUT_PROGRESS}.
     *
     * @return the best-scoring subset seen during the search
     * @throws Exception if evaluating a subset fails
     */
    public Set<Integer> select() throws Exception {
        return select((stalledRounds, subsetSize) -> stalledRounds < MAX_ITERATIONS_WITHOUT_PROGRESS);
    }

    /**
     * Grows a feature subset one feature at a time while {@code criteria} allows it.
     * <p>
     * Each iteration asks {@code best} for the next feature, moves it from the candidate pool into
     * the working subset, scores the subset with {@code objectiveFunction} and reports the score via
     * {@code printAccuracy}. The search also ends early when {@code best} returns {@code -1}.
     *
     * @param criteria continuation test; receives the number of consecutive iterations whose score
     *                 did not beat the previous iteration's score, and the current subset size
     * @return a copy of the highest-scoring subset encountered; empty if no subset scored above 0
     * @throws Exception if evaluating a subset fails
     */
    public Set<Integer> select(Criteria criteria) throws Exception {
        // Nothing is chosen yet, so every feature index is still a candidate.
        Set<Integer> candidates = getAllFeatureIndices();
        Set<Integer> chosen = new HashSet<>();

        // Best subset seen so far; the working subset may get worse, the returned one never does.
        double topAccuracy = 0;
        Set<Integer> topSubset = new HashSet<>();

        // The score of the empty subset is reported and used as the first comparison point,
        // but it does not compete for the best-subset slot.
        double currentAccuracy = objectiveFunction(chosen);
        double previousAccuracy = currentAccuracy;
        printAccuracy(chosen.size(), currentAccuracy);

        // Consecutive iterations that failed to beat the preceding iteration's score.
        double stalledRounds = 0;

        while (criteria.evaluate(stalledRounds, chosen.size())) {
            int next = best(chosen, candidates);
            if (next == -1) {
                // No valid feature is left to add.
                break;
            }

            // Move the feature from the pool into the subset so it cannot be picked again.
            chosen.add(next);
            candidates.remove(next);

            currentAccuracy = objectiveFunction(chosen);
            if (greaterThan(currentAccuracy, topAccuracy)) {
                topAccuracy = currentAccuracy;
                // Snapshot the subset; later additions must not leak into the recorded best.
                topSubset = new HashSet<>(chosen);
            }

            printAccuracy(chosen.size(), currentAccuracy);

            stalledRounds = lessThanOrEqualTo(currentAccuracy, previousAccuracy) ? stalledRounds + 1 : 0;
            previousAccuracy = currentAccuracy;
        }

        return topSubset;
    }
}