-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGeneticProgramming.java
108 lines (101 loc) · 4.11 KB
/
GeneticProgramming.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
/**
 * Evolves a population of {@link GPNode} trees for binary classification using tournament
 * selection, crossover and mutation.
 *
 * <p>Fitness is the value returned by {@code GPNode.evaluateFitness} on the training data; lower
 * values are treated as better. Each instance is a {@code double[]} whose last element is the class
 * label (exactly 1 = positive class), as read by {@link #printMetrics(GPNode)}.
 *
 * <p>Instances hold mutable state (population, RNG) without synchronization, so drive a single
 * instance from one thread. {@link #evolve()} uses a worker pool internally only to evaluate
 * fitness.
 */
public class GeneticProgramming {
    private static final int MAX_TREE_DEPTH = 4;
    private static final int POPULATION_SIZE = 100;
    private static final int TOURNAMENT_SIZE = 5;
    private static final int NUM_GENERATIONS = 50;
    private static final double CROSSOVER_RATE = 0.7;
    private static final double MUTATION_RATE = 0.1;

    private List<GPNode> population;
    private final List<double[]> trainingData;
    private final List<double[]> testData;
    private final Random random;
    private GPNode bestIndividual;

    /**
     * Creates a solver with a freshly generated random initial population.
     *
     * @param trainingData instances used to compute fitness during evolution
     * @param testData instances used by {@link #printMetrics(GPNode)}
     * @throws NullPointerException if either argument is null
     */
    public GeneticProgramming(List<double[]> trainingData, List<double[]> testData) {
        this.trainingData = Objects.requireNonNull(trainingData, "trainingData");
        this.testData = Objects.requireNonNull(testData, "testData");
        // Fixed seed so the selection/mutation decisions made by this class are reproducible.
        this.random = new Random(42);
        this.population = new ArrayList<>(POPULATION_SIZE);
        for (int i = 0; i < POPULATION_SIZE; i++) {
            this.population.add(GPNode.generateRandomTree(MAX_TREE_DEPTH));
        }
    }

    /**
     * Runs {@code NUM_GENERATIONS} generations of evolution, replacing the population each
     * generation and printing the best individual's training accuracy.
     *
     * <p>Each generation's fitness is computed once, in parallel, and reused for both selection and
     * reporting. Selection, crossover and mutation run on the calling thread so the seeded
     * {@link Random} is consumed in a deterministic order. If the calling thread is interrupted,
     * the run stops early with the most recent population kept and the interrupt flag restored.
     *
     * @throws IllegalStateException if evaluating any individual's fitness throws
     */
    public void evolve() {
        ExecutorService executor =
                Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            double[] fitness = evaluatePopulation(executor);
            for (int generation = 0; generation < NUM_GENERATIONS; generation++) {
                this.population = breedNextGeneration(fitness);
                fitness = evaluatePopulation(executor);
                int bestIndex = indexOfBest(fitness);
                this.bestIndividual = population.get(bestIndex);
                // NOTE(review): presumably fitness counts training errors, making this an
                // accuracy — TODO confirm against GPNode.evaluateFitness.
                System.out.println("Generation " + generation + " Accuracy: "
                        + (1 - fitness[bestIndex] / trainingData.size()));
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            // Always release the worker threads, including when evaluation fails.
            executor.shutdownNow();
        }
    }

    /** Builds the next generation from the current population using precomputed fitness values. */
    private List<GPNode> breedNextGeneration(double[] fitness) {
        List<GPNode> offspring = new ArrayList<>(POPULATION_SIZE);
        for (int i = 0; i < POPULATION_SIZE; i++) {
            GPNode parent1 = selectByTournament(fitness);
            GPNode parent2 = selectByTournament(fitness);
            GPNode child = parent1.crossover(parent2, CROSSOVER_RATE);
            // NOTE(review): the cached fitness stays valid only if crossover/mutate never modify a
            // tree that is still in the current population (e.g. by returning parent1 itself and
            // then mutating it in place) — TODO confirm against GPNode.
            if (random.nextDouble() < MUTATION_RATE) {
                child.mutate(MAX_TREE_DEPTH);
            }
            offspring.add(child);
        }
        return offspring;
    }

    /**
     * Evaluates the fitness of every individual in the current population on the worker pool.
     *
     * <p>NOTE(review): assumes {@code GPNode.evaluateFitness} is safe to call concurrently on
     * different trees — TODO confirm.
     *
     * @param executor pool that runs the evaluations
     * @return fitness values indexed like {@code population}
     * @throws InterruptedException if interrupted while waiting for a result
     * @throws IllegalStateException if an evaluation throws; the original failure is the cause
     */
    private double[] evaluatePopulation(ExecutorService executor) throws InterruptedException {
        List<GPNode> individuals = population;
        List<Future<Double>> futures = new ArrayList<>(individuals.size());
        for (GPNode individual : individuals) {
            futures.add(executor.submit(() -> (double) individual.evaluateFitness(trainingData)));
        }
        double[] fitness = new double[futures.size()];
        try {
            for (int i = 0; i < fitness.length; i++) {
                fitness[i] = futures.get(i).get();
            }
        } catch (ExecutionException e) {
            throw new IllegalStateException("Fitness evaluation failed", e.getCause());
        }
        return fitness;
    }

    /**
     * Picks {@code TOURNAMENT_SIZE} individuals uniformly at random (with replacement) and returns
     * the one with the lowest fitness; ties go to the earliest pick.
     *
     * @return the tournament winner
     */
    public GPNode tournamentSelection() {
        GPNode best = null;
        double bestFitness = 0;
        for (int i = 0; i < TOURNAMENT_SIZE; i++) {
            GPNode individual = population.get(random.nextInt(population.size()));
            // Evaluate each candidate once instead of re-evaluating the incumbent per comparison.
            double fitness = individual.evaluateFitness(trainingData);
            if (best == null || fitness < bestFitness) {
                best = individual;
                bestFitness = fitness;
            }
        }
        return best;
    }

    /** Tournament selection over the current population using precomputed fitness values. */
    private GPNode selectByTournament(double[] fitness) {
        int winner = -1;
        for (int i = 0; i < TOURNAMENT_SIZE; i++) {
            int candidate = random.nextInt(population.size());
            if (winner < 0 || fitness[candidate] < fitness[winner]) {
                winner = candidate;
            }
        }
        return population.get(winner);
    }

    /**
     * Finds the individual with the lowest fitness in the current population, records it as the
     * best individual, and returns it.
     *
     * @return the fittest individual (first one on ties), or null if the population is empty
     */
    public GPNode getBestIndividual() {
        GPNode best = null;
        double bestFitness = 0;
        for (GPNode individual : population) {
            double fitness = individual.evaluateFitness(trainingData);
            if (best == null || fitness < bestFitness) {
                best = individual;
                bestFitness = fitness;
            }
        }
        this.bestIndividual = best;
        return best;
    }

    /** Returns the index of the lowest value in {@code fitness}; ties go to the lowest index. */
    private static int indexOfBest(double[] fitness) {
        int best = 0;
        for (int i = 1; i < fitness.length; i++) {
            if (fitness[i] < fitness[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Classifies every test instance with {@code individual} and prints accuracy, specificity,
     * sensitivity and F-measure.
     *
     * <p>The label is the last element of each instance; a label of exactly 1 is the positive
     * class. An output greater than 0.5 is predicted positive. A ratio whose denominator is zero
     * (e.g. sensitivity when the test set has no positives) is reported as 0 instead of NaN.
     *
     * @param individual the tree to evaluate on the test data
     */
    public void printMetrics(GPNode individual) {
        double tp = 0, tn = 0, fp = 0, fn = 0;
        for (double[] instance : testData) {
            boolean actualPositive = instance[instance.length - 1] == 1;
            boolean predictedPositive = individual.evaluate(instance) > 0.5;
            if (actualPositive) {
                if (predictedPositive) {
                    tp++;
                } else {
                    fn++;
                }
            } else {
                if (predictedPositive) {
                    fp++;
                } else {
                    tn++;
                }
            }
        }
        double accuracy = ratio(tp + tn, testData.size());
        double specificity = ratio(tn, tn + fp);
        double sensitivity = ratio(tp, tp + fn);
        double precision = ratio(tp, tp + fp);
        double fMeasure = ratio(2 * (precision * sensitivity), precision + sensitivity);
        System.out.println("Test Accuracy: " + accuracy*100 + "%");
        System.out.println("Test Specificity: " + specificity);
        System.out.println("Test Sensitivity: " + sensitivity);
        System.out.println("Test F-measure: " + fMeasure);
    }

    /** Returns {@code numerator / denominator}, or 0 when the denominator is 0. */
    private static double ratio(double numerator, double denominator) {
        return denominator == 0 ? 0 : numerator / denominator;
    }
}