Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(#11): Implementation of Random Forest #11 #30

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,10 @@ public double predict(double[] feature) {
}

private double predict(double[] feature, Node node) {
// If we are at a leaf node, return the most common label
if (node.isLeaf) {
return getMostCommonLabel(node.data);
}

// Else move to the next node
if (feature[node.splitFeature] < node.value) {
return predict(feature, node.left);
} else {
Expand Down Expand Up @@ -183,12 +181,9 @@ public double evaluate(double[][] features, double[][] labels) {
}
}

// Calculate accuracy: ratio of correct predictions to total predictions
double accuracy = (double) correctPredictions / features.length;

// Log the accuracy value (optional)
LOG.info("Model Accuracy: {}%", accuracy * 100);

return accuracy;
}

Expand Down
138 changes: 138 additions & 0 deletions lib/src/main/java/de/edux/ml/randomforest/RandomForest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package de.edux.ml.randomforest;

import de.edux.ml.decisiontree.DecisionTree;
import de.edux.ml.decisiontree.IDecisionTree;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;

public class RandomForest {
    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);

    /** Seconds to wait for executor termination before forcing shutdown. */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 60;

    /** Ensemble members; filled by {@link #train} and read by {@link #predict}. */
    private final List<IDecisionTree> trees = new ArrayList<>();

    /**
     * Trains {@code numTrees} decision trees in parallel, each on a bootstrap
     * sample of {@code numberOfFeatures} rows drawn with replacement from the
     * full training set.
     *
     * @param numTrees         number of trees in the ensemble
     * @param features         training feature rows
     * @param labels           label rows, parallel to {@code features}
     * @param maxDepth         maximum depth of each tree
     * @param minSamplesSplit  minimum samples required to split a node
     * @param minSamplesLeaf   minimum samples required at a leaf
     * @param maxLeafNodes     maximum number of leaf nodes per tree
     * @param numberOfFeatures number of rows in each bootstrap sample
     */
    public void train(int numTrees,
                      double[][] features,
                      double[][] labels,
                      int maxDepth,
                      int minSamplesSplit,
                      int minSamplesLeaf,
                      int maxLeafNodes,
                      int numberOfFeatures) {

        ExecutorService executor =
                Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        List<Future<IDecisionTree>> futures = new ArrayList<>(numTrees);

        for (int i = 0; i < numTrees; i++) {
            futures.add(executor.submit(() -> {
                IDecisionTree tree = new DecisionTree();
                Sample subsetSample = getRandomSubset(numberOfFeatures, features, labels);
                tree.train(subsetSample.featureSamples(), subsetSample.labelSamples(),
                        maxDepth, minSamplesSplit, minSamplesLeaf, maxLeafNodes);
                return tree;
            }));
        }

        for (Future<IDecisionTree> future : futures) {
            try {
                trees.add(future.get());
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers can observe the interruption.
                Thread.currentThread().interrupt();
                LOG.error("Interrupted while training a decision tree.", e);
            } catch (ExecutionException e) {
                LOG.error("Failed to train a decision tree.", e);
            }
        }
        shutdownGracefully(executor);
    }

    /**
     * Draws a bootstrap sample of {@code sampleSize} rows chosen uniformly at
     * random, with replacement, from the whole training set.
     *
     * Fixes vs. the original: the random index is drawn from
     * {@code [0, features.length)} instead of {@code [0, sampleSize)}, so every
     * training row is eligible for sampling, and {@code ThreadLocalRandom.current()}
     * is obtained on the calling worker thread instead of being cached in a field
     * on the thread that constructed this object (which is incorrect usage of
     * ThreadLocalRandom).
     *
     * @throws IllegalArgumentException if {@code sampleSize} is not in [1, features.length]
     */
    private Sample getRandomSubset(int sampleSize, double[][] features, double[][] labels) {
        if (sampleSize < 1 || sampleSize > features.length) {
            throw new IllegalArgumentException("Number of feature must be between 1 and amount of features");
        }
        double[][] subFeatures = new double[sampleSize][];
        double[][] subLabels = new double[sampleSize][];
        for (int i = 0; i < sampleSize; i++) {
            int randomIndex = ThreadLocalRandom.current().nextInt(features.length);
            subFeatures[i] = features[randomIndex];
            subLabels[i] = labels[randomIndex];
        }
        return new Sample(subFeatures, subLabels);
    }

    /**
     * Predicts the label for a single feature vector by majority vote over all
     * trained trees. Tree predictions are evaluated in parallel.
     *
     * @param feature feature vector to classify
     * @return the label predicted by the most trees
     * @throws IllegalStateException if called before {@link #train}
     * @throws RuntimeException      if no vote could be collected
     */
    public double predict(double[] feature) {
        if (trees.isEmpty()) {
            throw new IllegalStateException("predict() called before train(): no trees available");
        }
        ExecutorService executor =
                Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        List<Future<Double>> futures = new ArrayList<>(trees.size());

        for (IDecisionTree tree : trees) {
            futures.add(executor.submit(() -> tree.predict(feature)));
        }

        Map<Double, Long> voteMap = new HashMap<>();
        for (Future<Double> future : futures) {
            try {
                voteMap.merge(future.get(), 1L, Long::sum);
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers can observe the interruption.
                Thread.currentThread().interrupt();
                LOG.error("Interrupted while collecting a tree prediction.", e);
            } catch (ExecutionException e) {
                LOG.error("Failed to retrieve prediction from future task.", e);
            }
        }
        shutdownGracefully(executor);

        return voteMap.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new RuntimeException("Failed to find the most common prediction"));
    }

    /**
     * Evaluates ensemble accuracy on a labelled test set.
     *
     * @return fraction of rows whose predicted class equals the argmax of the
     *         corresponding label row; 0.0 for an empty test set
     */
    public double evaluate(double[][] features, double[][] labels) {
        if (features.length == 0) {
            return 0.0; // avoid 0/0 -> NaN on an empty test set
        }
        int correctPredictions = 0;
        for (int i = 0; i < features.length; i++) {
            if (predict(features[i]) == getIndexOfHighestValue(labels[i])) {
                correctPredictions++;
            }
        }
        return (double) correctPredictions / features.length;
    }

    /** Returns the index of the largest value in {@code labels} (argmax). */
    private double getIndexOfHighestValue(double[] labels) {
        int maxIndex = 0;
        double maxValue = labels[0];
        for (int i = 1; i < labels.length; i++) {
            if (labels[i] > maxValue) {
                maxValue = labels[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /** Shuts the executor down, waiting briefly before forcing termination. */
    private void shutdownGracefully(ExecutorService executor) {
        executor.shutdown();
        try {
            if (!executor.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                executor.shutdownNow();
            }
        } catch (InterruptedException ex) {
            executor.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

}
5 changes: 5 additions & 0 deletions lib/src/main/java/de/edux/ml/randomforest/Sample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package de.edux.ml.randomforest;

public record Sample(double[][] featureSamples, double[][] labelSamples) {

}
61 changes: 61 additions & 0 deletions lib/src/test/java/de/edux/ml/RandomForestTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package de.edux.ml;

import de.edux.data.provider.Penguin;
import de.edux.data.provider.SeabornDataProcessor;
import de.edux.data.provider.SeabornProvider;
import de.edux.ml.randomforest.RandomForest;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.net.URL;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertTrue;

class RandomForestTest {
    private static final boolean SHUFFLE = true;
    private static final boolean NORMALIZE = true;
    private static final boolean FILTER_INCOMPLETE_RECORDS = true;
    private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
    private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";

    private static SeabornProvider seabornProvider;

    /** Loads the penguins CSV from test resources and splits it into train/test sets. */
    @BeforeAll
    static void setup() {
        URL datasetUrl = RandomForestTest.class.getClassLoader().getResource(CSV_FILE_PATH);
        if (datasetUrl == null) {
            throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
        }
        var processor = new SeabornDataProcessor();
        var dataset = processor.loadTDataSet(
                new File(datasetUrl.getPath()), ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
        List<List<Penguin>> trainAndTest = processor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
        seabornProvider = new SeabornProvider(dataset, trainAndTest.get(0), trainAndTest.get(1));
    }

    /** Trains a forest on the penguin data and expects better than 70% test accuracy. */
    @Test
    void train() {
        var trainFeatures = seabornProvider.getTrainFeatures();
        var trainLabels = seabornProvider.getTrainLabels();
        var testFeatures = seabornProvider.getTestFeatures();
        var testLabels = seabornProvider.getTestLabels();

        // Sanity-check that the provider produced non-empty splits.
        assertTrue(trainFeatures.length > 0);
        assertTrue(trainLabels.length > 0);
        assertTrue(testFeatures.length > 0);
        assertTrue(testLabels.length > 0);

        int numberOfTrees = 100;
        int maxDepth = 24;
        int minSampleSize = 2;
        int minSamplesLeaf = 1;
        int maxLeafNodes = 12;
        // Bootstrap-sample size per tree: 3 * sqrt(number of training rows).
        int numFeatures = (int) Math.sqrt(trainFeatures.length) * 3;

        RandomForest randomForest = new RandomForest();
        randomForest.train(numberOfTrees, trainFeatures, trainLabels,
                maxDepth, minSampleSize, minSamplesLeaf, maxLeafNodes, numFeatures);

        double accuracy = randomForest.evaluate(testFeatures, testLabels);
        System.out.println(accuracy);
        assertTrue(accuracy > 0.7);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ void train() {
assertTrue(testLabels.length > 0);

IDecisionTree decisionTree = new DecisionTree();
decisionTree.train(features, labels, 10, 2, 1, 8);
int maxDepth = 10;
int minSampleSplit = 2;
int minSampleLeaf = 1;
int maxLeafNodes = 8;
decisionTree.train(features, labels, maxDepth, minSampleSplit, minSampleLeaf, maxLeafNodes);
double accuracy = decisionTree.evaluate(testFeatures, testLabels);
assertTrue(accuracy>0.7);
}
Expand Down
Loading