package de.edux.ml.randomforest;

import de.edux.ml.decisiontree.DecisionTree;
import de.edux.ml.decisiontree.IDecisionTree;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

/**
 * Random-forest classifier: trains an ensemble of {@link DecisionTree}s, each on a
 * bootstrap sample of the training data, and classifies by majority vote.
 *
 * <p>Not thread-safe: {@link #train} must complete before {@link #predict} is called,
 * and neither method should be invoked concurrently from multiple threads.
 */
public class RandomForest {
    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);

    /** Trained ensemble; populated by {@link #train}. */
    private final List<IDecisionTree> trees = new ArrayList<>();

    /**
     * Trains {@code numTrees} decision trees in parallel, each on an independent
     * bootstrap sample of the training data.
     *
     * @param numTrees         number of trees in the ensemble
     * @param features         training feature rows
     * @param labels           one-hot label rows, parallel to {@code features}
     * @param maxDepth         per-tree maximum depth
     * @param minSamplesSplit  per-tree minimum samples required to split a node
     * @param minSamplesLeaf   per-tree minimum samples required at a leaf
     * @param maxLeafNodes     per-tree maximum number of leaf nodes
     * @param numberOfFeatures number of rows drawn (with replacement) for each tree's
     *                         bootstrap sample; must be in [1, features.length]
     */
    public void train(int numTrees,
                      double[][] features,
                      double[][] labels,
                      int maxDepth,
                      int minSamplesSplit,
                      int minSamplesLeaf,
                      int maxLeafNodes,
                      int numberOfFeatures) {
        ExecutorService executor =
                Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        List<Future<IDecisionTree>> futures = new ArrayList<>(numTrees);

        for (int i = 0; i < numTrees; i++) {
            futures.add(executor.submit(() -> {
                IDecisionTree tree = new DecisionTree();
                Sample subset = getRandomSubset(numberOfFeatures, features, labels);
                tree.train(subset.featureSamples(), subset.labelSamples(),
                        maxDepth, minSamplesSplit, minSamplesLeaf, maxLeafNodes);
                return tree;
            }));
        }

        for (Future<IDecisionTree> future : futures) {
            try {
                trees.add(future.get());
            } catch (InterruptedException e) {
                // Restore the interrupt flag instead of swallowing it.
                Thread.currentThread().interrupt();
                LOG.error("Interrupted while training a decision tree.", e);
            } catch (ExecutionException e) {
                LOG.error("Failed to train a decision tree.", e);
            }
        }
        shutdownQuietly(executor);
    }

    /**
     * Draws a bootstrap sample: {@code sampleSize} rows chosen uniformly at random
     * WITH replacement from the whole training set.
     *
     * <p>Fix: the random index is drawn from {@code features.length} (all rows), not
     * from {@code sampleSize} — the original drew from {@code nextInt(sampleSize)},
     * so only the first {@code sampleSize} rows could ever be sampled and the rest of
     * the training data was invisible to every tree.
     *
     * @throws IllegalArgumentException if {@code sampleSize} is outside [1, features.length]
     */
    private Sample getRandomSubset(int sampleSize, double[][] features, double[][] labels) {
        if (sampleSize < 1 || sampleSize > features.length) {
            throw new IllegalArgumentException(
                    "Sample size must be between 1 and the number of training rows ("
                            + features.length + "), got: " + sampleSize);
        }
        double[][] subFeatures = new double[sampleSize][];
        double[][] subLabels = new double[sampleSize][];
        for (int i = 0; i < sampleSize; i++) {
            // ThreadLocalRandom.current() must be called on the executing thread;
            // caching the instance in a field (as before) hands one thread's
            // generator to the executor's worker threads, which its contract forbids.
            int randomIndex = ThreadLocalRandom.current().nextInt(features.length);
            subFeatures[i] = features[randomIndex];
            subLabels[i] = labels[randomIndex];
        }
        return new Sample(subFeatures, subLabels);
    }

    /**
     * Predicts the class label for a single feature vector by majority vote over
     * all trained trees.
     *
     * @param feature the feature vector to classify
     * @return the label receiving the most votes
     * @throws IllegalStateException if no vote could be collected (e.g. no trees trained)
     */
    public double predict(double[] feature) {
        ExecutorService executor =
                Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        List<Future<Double>> futures = new ArrayList<>(trees.size());

        for (IDecisionTree tree : trees) {
            futures.add(executor.submit(() -> tree.predict(feature)));
        }

        Map<Double, Long> voteMap = new HashMap<>();
        for (Future<Double> future : futures) {
            try {
                voteMap.merge(future.get(), 1L, Long::sum);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                LOG.error("Interrupted while collecting tree predictions.", e);
            } catch (ExecutionException e) {
                LOG.error("Failed to retrieve prediction from future task.", e);
            }
        }
        shutdownQuietly(executor);

        return voteMap.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalStateException(
                        "Failed to find the most common prediction"));
    }

    /**
     * Evaluates accuracy on a labeled test set.
     *
     * @return fraction of rows whose majority-vote prediction matches the index of
     *         the maximum entry in the corresponding one-hot label row
     */
    public double evaluate(double[][] features, double[][] labels) {
        int correctPredictions = 0;
        for (int i = 0; i < features.length; i++) {
            if (predict(features[i]) == getIndexOfHighestValue(labels[i])) {
                correctPredictions++;
            }
        }
        return (double) correctPredictions / features.length;
    }

    /** Index of the maximum entry — converts a one-hot label row to a class id. */
    private double getIndexOfHighestValue(double[] labels) {
        int maxIndex = 0;
        for (int i = 1; i < labels.length; i++) {
            if (labels[i] > labels[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Orderly executor shutdown, shared by {@link #train} and {@link #predict}
     * (previously duplicated verbatim in both).
     */
    private static void shutdownQuietly(ExecutorService executor) {
        executor.shutdown();
        try {
            if (!executor.awaitTermination(60, TimeUnit.SECONDS)) {
                executor.shutdownNow();
            }
        } catch (InterruptedException ex) {
            executor.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }
}
package de.edux.ml;

import de.edux.data.provider.Penguin;
import de.edux.data.provider.SeabornDataProcessor;
import de.edux.data.provider.SeabornProvider;
import de.edux.ml.randomforest.RandomForest;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.net.URL;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertTrue;

/** End-to-end check: a RandomForest trained on the penguins dataset beats 70% accuracy. */
class RandomForestTest {
    private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
    private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
    private static final boolean SHUFFLE = true;
    private static final boolean NORMALIZE = true;
    private static final boolean FILTER_INCOMPLETE_RECORDS = true;

    private static SeabornProvider seabornProvider;

    @BeforeAll
    static void setup() {
        // Resolve the CSV from the test classpath; fail fast if the resource is absent.
        URL url = RandomForestTest.class.getClassLoader().getResource(CSV_FILE_PATH);
        if (url == null) {
            throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
        }
        File csvFile = new File(url.getPath());

        var seabornDataProcessor = new SeabornDataProcessor();
        var dataset = seabornDataProcessor.loadTDataSet(
                csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
        List<List<Penguin>> trainTestSplittedList =
                seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
        seabornProvider = new SeabornProvider(
                dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1));
    }

    @Test
    void train() {
        double[][] features = seabornProvider.getTrainFeatures();
        double[][] labels = seabornProvider.getTrainLabels();
        double[][] testFeatures = seabornProvider.getTestFeatures();
        double[][] testLabels = seabornProvider.getTestLabels();

        // Guard against an empty split before spending time on training.
        assertTrue(features.length > 0);
        assertTrue(labels.length > 0);
        assertTrue(testFeatures.length > 0);
        assertTrue(testLabels.length > 0);

        int numberOfTrees = 100;
        int maxDepth = 24;
        int minSampleSize = 2;
        int minSamplesLeaf = 1;
        int maxLeafNodes = 12;
        // Bootstrap-sample size per tree, scaled from the training-set size.
        int numFeatures = (int) Math.sqrt(features.length) * 3;

        RandomForest randomForest = new RandomForest();
        randomForest.train(numberOfTrees, features, labels,
                maxDepth, minSampleSize, minSamplesLeaf, maxLeafNodes, numFeatures);

        double accuracy = randomForest.evaluate(testFeatures, testLabels);
        System.out.println(accuracy);
        assertTrue(accuracy > 0.7);
    }
}