From df62e30d49937427b3430596cf742a8eba84afc7 Mon Sep 17 00:00:00 2001 From: Samuel Abramov Date: Wed, 1 Nov 2023 12:17:18 +0100 Subject: [PATCH] chore():Reformat Codebase with Google Formatter #78 --- .../java/de/example/benchmark/Benchmark.java | 245 +++++---- .../DecisionTreeExampleOnIrisDataset.java | 46 +- .../java/de/example/knn/KnnIrisExample.java | 59 +- ...ayerNeuralNetworkExampleOnIrisDataset.java | 91 +-- .../RandomForestExampleOnIrisDataset.java | 56 +- .../example/svm/SVMExampleOnIrisDataset.java | 48 +- lib/build.gradle | 3 +- lib/src/main/java/de/edux/api/Classifier.java | 107 ++-- .../de/edux/data/provider/DataNormalizer.java | 90 ++- .../de/edux/data/provider/DataloaderV2.java | 4 +- .../de/edux/data/provider/Normalizer.java | 2 +- .../de/edux/data/reader/CSVIDataReader.java | 24 +- .../java/de/edux/data/reader/IDataReader.java | 2 +- .../activation/ActivationFunction.java | 195 +++---- .../functions/activation/package-info.java | 29 +- .../initialization/Initialization.java | 42 +- .../de/edux/functions/loss/LossFunction.java | 124 +++-- lib/src/main/java/de/edux/math/Entity.java | 9 +- lib/src/main/java/de/edux/math/MathUtil.java | 19 +- .../main/java/de/edux/math/Validations.java | 17 +- .../main/java/de/edux/math/entity/Matrix.java | 230 ++++---- .../main/java/de/edux/math/entity/Vector.java | 210 ++++--- .../de/edux/ml/decisiontree/DecisionTree.java | 516 +++++++++--------- .../java/de/edux/ml/decisiontree/Node.java | 22 +- .../de/edux/ml/decisiontree/package-info.java | 4 +- .../java/de/edux/ml/knn/KnnClassifier.java | 169 +++--- .../ml/nn/config/NetworkConfiguration.java | 16 +- .../ml/nn/network/MultilayerPerceptron.java | 400 +++++++------- .../java/de/edux/ml/nn/network/Neuron.java | 92 ++-- .../de/edux/ml/nn/network/api/INeuron.java | 9 +- .../edux/ml/nn/network/api/IPerceptron.java | 8 +- .../de/edux/ml/randomforest/RandomForest.java | 310 +++++------ .../de/edux/ml/randomforest/package-info.java | 4 +- .../de/edux/ml/svm/ISupportVectorMachine.java | 6 +- .../main/java/de/edux/ml/svm/SVMKernel.java | 2 +- .../main/java/de/edux/ml/svm/SVMModel.java | 68 +-- .../de/edux/ml/svm/SupportVectorMachine.java | 211 +++---- .../java/de/edux/ml/svm/package-info.java | 4 +- .../de/edux/util/LabelDimensionConverter.java | 18 +- .../math/ConcurrentMatrixMultiplication.java | 23 +- .../math/IncompatibleDimensionsException.java | 8 +- .../java/de/edux/util/math/MathMatrix.java | 97 ++-- .../de/edux/util/math/MatrixOperations.java | 82 +-- .../edux/data/provider/DataProcessorTest.java | 324 ++++++----- .../activation/ActivationFunctionTest.java | 143 ++--- .../de/edux/functions/InitializationTest.java | 43 +- .../edux/functions/loss/LossFunctionTest.java | 126 +++-- .../java/de/edux/math/entity/MatrixTest.java | 115 ++-- .../java/de/edux/math/entity/VectorTest.java | 73 ++- .../de/edux/ml/nn/network/NeuronTest.java | 67 +-- .../de/edux/util/math/MathMatrixTest.java | 188 +++---- 51 files changed, 2497 insertions(+), 2303 deletions(-) diff --git a/example/src/main/java/de/example/benchmark/Benchmark.java b/example/src/main/java/de/example/benchmark/Benchmark.java index b072aac..5548557 100644 --- a/example/src/main/java/de/example/benchmark/Benchmark.java +++ b/example/src/main/java/de/example/benchmark/Benchmark.java @@ -13,7 +13,6 @@ import de.edux.ml.randomforest.RandomForest; import de.edux.ml.svm.SVMKernel; import de.edux.ml.svm.SupportVectorMachine; - import java.io.File; import java.util.ArrayList; import java.util.List; @@ -21,119 +20,137 @@ import 
java.util.concurrent.ConcurrentHashMap; import java.util.stream.IntStream; -/** - * Compare the performance of different classifiers - */ +/** Compare the performance of different classifiers */ public class Benchmark { - private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; - private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); - private static final boolean SKIP_HEAD = true; - - private double[][] trainFeatures; - private double[][] trainLabels; - private double[][] testFeatures; - private double[][] testLabels; - private MultilayerPerceptron multilayerPerceptron; - private NetworkConfiguration networkConfiguration; - - public static void main(String[] args) { - new Benchmark().run(); - } - - private void run() { - initFeaturesAndLabels(); - - Classifier knn = new KnnClassifier(2); - Classifier decisionTree = new DecisionTree(2, 2, 3, 12); - Classifier randomForest = new RandomForest(500, 10, 2, 3, 3, 60); - Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 1); - - networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128, 256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); - multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); - Map classifiers = Map.of( - "KNN", knn, - "DecisionTree", decisionTree, - "RandomForest", randomForest, - "SVM", svm, - "MLP", multilayerPerceptron - ); - - Map> results = new ConcurrentHashMap<>(); - results.put("KNN", new ArrayList<>()); - results.put("DecisionTree", new ArrayList<>()); - results.put("RandomForest", new ArrayList<>()); - results.put("SVM", new ArrayList<>()); - results.put("MLP", new ArrayList<>()); - - - IntStream.range(0, 5).forEach(i -> { - knn.train(trainFeatures, trainLabels); - decisionTree.train(trainFeatures, trainLabels); - randomForest.train(trainFeatures, trainLabels); - svm.train(trainFeatures, trainLabels); - multilayerPerceptron.train(trainFeatures, trainLabels); - - double knnAccuracy = knn.evaluate(testFeatures, testLabels); - double decisionTreeAccuracy = decisionTree.evaluate(testFeatures, testLabels); - double randomForestAccuracy = randomForest.evaluate(testFeatures, testLabels); - double svmAccuracy = svm.evaluate(testFeatures, testLabels); - double multilayerPerceptronAccuracy = multilayerPerceptron.evaluate(testFeatures, testLabels); - - results.get("KNN").add(knnAccuracy); - results.get("DecisionTree").add(decisionTreeAccuracy); - results.get("RandomForest").add(randomForestAccuracy); - results.get("SVM").add(svmAccuracy); - results.get("MLP").add(multilayerPerceptronAccuracy); - initFeaturesAndLabels(); - updateMLP(testFeatures, testLabels); - }); - - System.out.println("Classifier performances (sorted by average accuracy):"); - results.entrySet().stream() - .map(entry -> { - double avgAccuracy = entry.getValue().stream() - .mapToDouble(Double::doubleValue) - .average() - .orElse(0.0); - return Map.entry(entry.getKey(), avgAccuracy); - }) - .sorted(Map.Entry.comparingByValue().reversed()) - .forEachOrdered(entry -> { - System.out.printf("%s: %.2f%%\n", entry.getKey(), entry.getValue() * 100); - }); - - System.out.println("\nClassifier best and worst performances:"); - results.forEach((classifierName, accuracies) -> { - double maxAccuracy = accuracies.stream() - .mapToDouble(Double::doubleValue) - .max() - .orElse(0.0); - double 
minAccuracy = accuracies.stream() - .mapToDouble(Double::doubleValue) - .min() - .orElse(0.0); - System.out.printf("%s: Best: %.2f%%, Worst: %.2f%%\n", classifierName, maxAccuracy * 100, minAccuracy * 100); + private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; + private static final File CSV_FILE = + new File( + "example" + + File.separator + + "datasets" + + File.separator + + "iris" + + File.separator + + "iris.csv"); + private static final boolean SKIP_HEAD = true; + + private double[][] trainFeatures; + private double[][] trainLabels; + private double[][] testFeatures; + private double[][] testLabels; + private MultilayerPerceptron multilayerPerceptron; + private NetworkConfiguration networkConfiguration; + + public static void main(String[] args) { + new Benchmark().run(); + } + + private void run() { + initFeaturesAndLabels(); + + Classifier knn = new KnnClassifier(2); + Classifier decisionTree = new DecisionTree(2, 2, 3, 12); + Classifier randomForest = new RandomForest(500, 10, 2, 3, 3, 60); + Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 1); + + networkConfiguration = + new NetworkConfiguration( + trainFeatures[0].length, + List.of(128, 256, 512), + 3, + 0.01, + 300, + ActivationFunction.LEAKY_RELU, + ActivationFunction.SOFTMAX, + LossFunction.CATEGORICAL_CROSS_ENTROPY, + Initialization.XAVIER, + Initialization.XAVIER); + multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); + Map classifiers = + Map.of( + "KNN", knn, + "DecisionTree", decisionTree, + "RandomForest", randomForest, + "SVM", svm, + "MLP", multilayerPerceptron); + + Map> results = new ConcurrentHashMap<>(); + results.put("KNN", new ArrayList<>()); + results.put("DecisionTree", new ArrayList<>()); + results.put("RandomForest", new ArrayList<>()); + results.put("SVM", new ArrayList<>()); + results.put("MLP", new ArrayList<>()); + + IntStream.range(0, 5) + .forEach( + i -> { + knn.train(trainFeatures, trainLabels); + decisionTree.train(trainFeatures, trainLabels); + randomForest.train(trainFeatures, trainLabels); + svm.train(trainFeatures, trainLabels); + multilayerPerceptron.train(trainFeatures, trainLabels); + + double knnAccuracy = knn.evaluate(testFeatures, testLabels); + double decisionTreeAccuracy = decisionTree.evaluate(testFeatures, testLabels); + double randomForestAccuracy = randomForest.evaluate(testFeatures, testLabels); + double svmAccuracy = svm.evaluate(testFeatures, testLabels); + double multilayerPerceptronAccuracy = + multilayerPerceptron.evaluate(testFeatures, testLabels); + + results.get("KNN").add(knnAccuracy); + results.get("DecisionTree").add(decisionTreeAccuracy); + results.get("RandomForest").add(randomForestAccuracy); + results.get("SVM").add(svmAccuracy); + results.get("MLP").add(multilayerPerceptronAccuracy); + initFeaturesAndLabels(); + updateMLP(testFeatures, testLabels); + }); + + System.out.println("Classifier performances (sorted by average accuracy):"); + results.entrySet().stream() + .map( + entry -> { + double avgAccuracy = + entry.getValue().stream().mapToDouble(Double::doubleValue).average().orElse(0.0); + return Map.entry(entry.getKey(), avgAccuracy); + }) + .sorted(Map.Entry.comparingByValue().reversed()) + .forEachOrdered( + entry -> { + System.out.printf("%s: %.2f%%\n", entry.getKey(), entry.getValue() * 100); + }); + + System.out.println("\nClassifier best and worst performances:"); + results.forEach( + (classifierName, accuracies) -> { + double maxAccuracy = + 
accuracies.stream().mapToDouble(Double::doubleValue).max().orElse(0.0); + double minAccuracy = + accuracies.stream().mapToDouble(Double::doubleValue).min().orElse(0.0); + System.out.printf( + "%s: Best: %.2f%%, Worst: %.2f%%\n", + classifierName, maxAccuracy * 100, minAccuracy * 100); }); - } - - private void updateMLP(double[][] testFeatures, double[][] testLabels) { - multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); - } - - private void initFeaturesAndLabels() { - var featureColumnIndices = new int[]{0, 1, 2, 3}; - var targetColumnIndex = 4; - - var dataProcessor = new DataProcessor(new CSVIDataReader()) - .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex) - .normalize() - .shuffle() - .split(TRAIN_TEST_SPLIT_RATIO); - - trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices); - trainLabels = dataProcessor.getTrainLabels(targetColumnIndex); - testFeatures = dataProcessor.getTestFeatures(featureColumnIndices); - testLabels = dataProcessor.getTestLabels(targetColumnIndex); - - } + } + + private void updateMLP(double[][] testFeatures, double[][] testLabels) { + multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); + } + + private void initFeaturesAndLabels() { + var featureColumnIndices = new int[] {0, 1, 2, 3}; + var targetColumnIndex = 4; + + var dataProcessor = + new DataProcessor(new CSVIDataReader()) + .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex) + .normalize() + .shuffle() + .split(TRAIN_TEST_SPLIT_RATIO); + + trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices); + trainLabels = dataProcessor.getTrainLabels(targetColumnIndex); + testFeatures = dataProcessor.getTestFeatures(featureColumnIndices); + testLabels = dataProcessor.getTestLabels(targetColumnIndex); + } } diff --git a/example/src/main/java/de/example/decisiontree/DecisionTreeExampleOnIrisDataset.java b/example/src/main/java/de/example/decisiontree/DecisionTreeExampleOnIrisDataset.java index 5254a6c..65e0a11 100644 --- a/example/src/main/java/de/example/decisiontree/DecisionTreeExampleOnIrisDataset.java +++ b/example/src/main/java/de/example/decisiontree/DecisionTreeExampleOnIrisDataset.java @@ -4,15 +4,22 @@ import de.edux.data.provider.DataProcessor; import de.edux.data.reader.CSVIDataReader; import de.edux.ml.decisiontree.DecisionTree; -import de.edux.ml.randomforest.RandomForest; - import java.io.File; public class DecisionTreeExampleOnIrisDataset { - private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; - private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); - private static final boolean SKIP_HEAD = true; - public static void main(String[] args) { + private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; + private static final File CSV_FILE = + new File( + "example" + + File.separator + + "datasets" + + File.separator + + "iris" + + File.separator + + "iris.csv"); + private static final boolean SKIP_HEAD = true; + + public static void main(String[] args) { /* IRIS Dataset... 
+-------------+------------+-------------+------------+---------+ | sepal.length| sepal.width| petal.length| petal.width| variety | @@ -20,18 +27,23 @@ public static void main(String[] args) { | 5.1 | 3.5 | 1.4 | .2 | Setosa | +-------------+------------+-------------+------------+---------+ */ - var featureColumnIndices = new int[]{0, 1, 2, 3}; // First 4 columns are features - var targetColumnIndex = 4; // Last column is the target + var featureColumnIndices = new int[] {0, 1, 2, 3}; // First 4 columns are features + var targetColumnIndex = 4; // Last column is the target - var irisDataProcessor = new DataProcessor(new CSVIDataReader()).loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex).normalize().shuffle().split(TRAIN_TEST_SPLIT_RATIO); - Classifier classifier = new DecisionTree(2, 2, 3, 12); + var irisDataProcessor = + new DataProcessor(new CSVIDataReader()) + .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex) + .normalize() + .shuffle() + .split(TRAIN_TEST_SPLIT_RATIO); + Classifier classifier = new DecisionTree(2, 2, 3, 12); - var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); - var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); - var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); - var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); + var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); + var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); + var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); + var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); - classifier.train(trainFeatures, trainLabels); - classifier.evaluate(trainTestFeatures, trainTestLabels); - } + classifier.train(trainFeatures, trainLabels); + classifier.evaluate(trainTestFeatures, trainTestLabels); + } } diff --git a/example/src/main/java/de/example/knn/KnnIrisExample.java b/example/src/main/java/de/example/knn/KnnIrisExample.java index de31f91..4f86c60 100644 --- a/example/src/main/java/de/example/knn/KnnIrisExample.java +++ b/example/src/main/java/de/example/knn/KnnIrisExample.java @@ -4,33 +4,40 @@ import de.edux.data.provider.DataProcessor; import de.edux.data.reader.CSVIDataReader; import de.edux.ml.knn.KnnClassifier; - import java.io.File; public class KnnIrisExample { - private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; - private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); - private static final boolean SKIP_HEAD = true; - - public static void main(String[] args) { - var featureColumnIndices = new int[]{0, 1, 2, 3}; - var targetColumnIndex = 4; - - var irisDataProcessor = new DataProcessor(new CSVIDataReader()) - .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex) - .normalize() - .shuffle() - .split(TRAIN_TEST_SPLIT_RATIO); - - - Classifier knn = new KnnClassifier(2); - - var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); - var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); - var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); - var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); - - knn.train(trainFeatures, trainLabels); - knn.evaluate(trainTestFeatures, trainTestLabels); - } + private static final double 
TRAIN_TEST_SPLIT_RATIO = 0.70; + private static final File CSV_FILE = + new File( + "example" + + File.separator + + "datasets" + + File.separator + + "iris" + + File.separator + + "iris.csv"); + private static final boolean SKIP_HEAD = true; + + public static void main(String[] args) { + var featureColumnIndices = new int[] {0, 1, 2, 3}; + var targetColumnIndex = 4; + + var irisDataProcessor = + new DataProcessor(new CSVIDataReader()) + .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex) + .normalize() + .shuffle() + .split(TRAIN_TEST_SPLIT_RATIO); + + Classifier knn = new KnnClassifier(2); + + var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); + var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); + var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); + var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); + + knn.train(trainFeatures, trainLabels); + knn.evaluate(trainTestFeatures, trainTestLabels); + } } diff --git a/example/src/main/java/de/example/nn/MultilayerNeuralNetworkExampleOnIrisDataset.java b/example/src/main/java/de/example/nn/MultilayerNeuralNetworkExampleOnIrisDataset.java index 450864a..6384f1c 100644 --- a/example/src/main/java/de/example/nn/MultilayerNeuralNetworkExampleOnIrisDataset.java +++ b/example/src/main/java/de/example/nn/MultilayerNeuralNetworkExampleOnIrisDataset.java @@ -7,52 +7,71 @@ import de.edux.functions.loss.LossFunction; import de.edux.ml.nn.config.NetworkConfiguration; import de.edux.ml.nn.network.MultilayerPerceptron; - import java.io.File; import java.util.List; public class MultilayerNeuralNetworkExampleOnIrisDataset { - private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; - private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); - private static final boolean SKIP_HEAD = true; - - public static void main(String[] args) { - var featureColumnIndices = new int[]{0, 1, 2, 3}; - var targetColumnIndex = 4; - - var dataProcessor = new DataProcessor(new CSVIDataReader()); - var dataset = dataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex); - dataset.shuffle(); - dataset.normalize(); - dataProcessor.split(TRAIN_TEST_SPLIT_RATIO); + private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; + private static final File CSV_FILE = + new File( + "example" + + File.separator + + "datasets" + + File.separator + + "iris" + + File.separator + + "iris.csv"); + private static final boolean SKIP_HEAD = true; + public static void main(String[] args) { + var featureColumnIndices = new int[] {0, 1, 2, 3}; + var targetColumnIndex = 4; + var dataProcessor = new DataProcessor(new CSVIDataReader()); + var dataset = + dataProcessor.loadDataSetFromCSV( + CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex); + dataset.shuffle(); + dataset.normalize(); + dataProcessor.split(TRAIN_TEST_SPLIT_RATIO); - var trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices); - var trainLabels = dataProcessor.getTrainLabels(targetColumnIndex); - var testFeatures = dataProcessor.getTestFeatures(featureColumnIndices); - var testLabels = dataProcessor.getTestLabels(targetColumnIndex); + var trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices); + var trainLabels = dataProcessor.getTrainLabels(targetColumnIndex); + var testFeatures = dataProcessor.getTestFeatures(featureColumnIndices); + var 
testLabels = dataProcessor.getTestLabels(targetColumnIndex); - var classMap = dataProcessor.getClassMap(); + var classMap = dataProcessor.getClassMap(); - System.out.println("Class Map: " + classMap); + System.out.println("Class Map: " + classMap); - //Configure Network with: - // - 4 Input Neurons - // - 2 Hidden Layer with 12 and 6 Neurons - // - 3 Output Neurons - // - Learning Rate of 0.1 - // - 1000 Epochs - // - Leaky ReLU as Activation Function for Hidden Layers - // - Softmax as Activation Function for Output Layer - // - Categorical Cross Entropy as Loss Function - // - Xavier as Weight Initialization for Hidden Layers - // - Xavier as Weight Initialization for Output Layer - var networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128, 256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); + // Configure Network with: + // - 4 Input Neurons + // - 2 Hidden Layer with 12 and 6 Neurons + // - 3 Output Neurons + // - Learning Rate of 0.1 + // - 1000 Epochs + // - Leaky ReLU as Activation Function for Hidden Layers + // - Softmax as Activation Function for Output Layer + // - Categorical Cross Entropy as Loss Function + // - Xavier as Weight Initialization for Hidden Layers + // - Xavier as Weight Initialization for Output Layer + var networkConfiguration = + new NetworkConfiguration( + trainFeatures[0].length, + List.of(128, 256, 512), + 3, + 0.01, + 300, + ActivationFunction.LEAKY_RELU, + ActivationFunction.SOFTMAX, + LossFunction.CATEGORICAL_CROSS_ENTROPY, + Initialization.XAVIER, + Initialization.XAVIER); - MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); - multilayerPerceptron.train(trainFeatures, trainLabels); - multilayerPerceptron.evaluate(testFeatures, testLabels); - } + MultilayerPerceptron multilayerPerceptron = + new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); + multilayerPerceptron.train(trainFeatures, trainLabels); + multilayerPerceptron.evaluate(testFeatures, testLabels); + } } diff --git a/example/src/main/java/de/example/randomforest/RandomForestExampleOnIrisDataset.java b/example/src/main/java/de/example/randomforest/RandomForestExampleOnIrisDataset.java index 0394f8b..99aea75 100644 --- a/example/src/main/java/de/example/randomforest/RandomForestExampleOnIrisDataset.java +++ b/example/src/main/java/de/example/randomforest/RandomForestExampleOnIrisDataset.java @@ -4,18 +4,23 @@ import de.edux.data.provider.DataProcessor; import de.edux.data.reader.CSVIDataReader; import de.edux.ml.randomforest.RandomForest; -import de.edux.ml.svm.SVMKernel; -import de.edux.ml.svm.SupportVectorMachine; - import java.io.File; public class RandomForestExampleOnIrisDataset { - private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; - private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); - private static final boolean SKIP_HEAD = true; - - public static void main(String[] args) { + private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; + private static final File CSV_FILE = + new File( + "example" + + File.separator + + "datasets" + + File.separator + + "iris" + + File.separator + + "iris.csv"); + private static final boolean SKIP_HEAD = true; + + public static void main(String[] args) { /* IRIS Dataset... 
+-------------+------------+-------------+------------+---------+ | sepal.length| sepal.width| petal.length| petal.width| variety | @@ -23,19 +28,24 @@ public static void main(String[] args) { | 5.1 | 3.5 | 1.4 | .2 | Setosa | +-------------+------------+-------------+------------+---------+ */ - var featureColumnIndices = new int[]{0, 1, 2, 3}; // First 4 columns are features - var targetColumnIndex = 4; // Last column is the target - - var irisDataProcessor = new DataProcessor(new CSVIDataReader()).loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex).normalize().shuffle().split(TRAIN_TEST_SPLIT_RATIO); - /* train 100 decision trees with max depth of 64, min samples split of 2, min samples leaf of 1, max features of 3 and 50 of samples */ - Classifier randomForest = new RandomForest(1000, 64, 2, 2, 3, 50); - - var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); - var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); - var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); - var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); - - randomForest.train(trainFeatures, trainLabels); - randomForest.evaluate(trainTestFeatures, trainTestLabels); - } + var featureColumnIndices = new int[] {0, 1, 2, 3}; // First 4 columns are features + var targetColumnIndex = 4; // Last column is the target + + var irisDataProcessor = + new DataProcessor(new CSVIDataReader()) + .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex) + .normalize() + .shuffle() + .split(TRAIN_TEST_SPLIT_RATIO); + /* train 100 decision trees with max depth of 64, min samples split of 2, min samples leaf of 1, max features of 3 and 50 of samples */ + Classifier randomForest = new RandomForest(1000, 64, 2, 2, 3, 50); + + var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); + var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); + var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); + var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); + + randomForest.train(trainFeatures, trainLabels); + randomForest.evaluate(trainTestFeatures, trainTestLabels); + } } diff --git a/example/src/main/java/de/example/svm/SVMExampleOnIrisDataset.java b/example/src/main/java/de/example/svm/SVMExampleOnIrisDataset.java index 7e8b8c4..85d1b94 100644 --- a/example/src/main/java/de/example/svm/SVMExampleOnIrisDataset.java +++ b/example/src/main/java/de/example/svm/SVMExampleOnIrisDataset.java @@ -5,15 +5,22 @@ import de.edux.data.reader.CSVIDataReader; import de.edux.ml.svm.SVMKernel; import de.edux.ml.svm.SupportVectorMachine; - import java.io.File; public class SVMExampleOnIrisDataset { - private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; - private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); - private static final boolean SKIP_HEAD = true; - - public static void main(String[] args) { + private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; + private static final File CSV_FILE = + new File( + "example" + + File.separator + + "datasets" + + File.separator + + "iris" + + File.separator + + "iris.csv"); + private static final boolean SKIP_HEAD = true; + + public static void main(String[] args) { /* IRIS Dataset...*/ /* +-------------+------------+-------------+------------+---------+ @@ -23,21 +30,24 @@ public static 
void main(String[] args) { +-------------+------------+-------------+------------+---------+ */ - var featureColumnIndices = new int[]{0, 1, 2, 3}; // First 4 columns are features - var targetColumnIndex = 4; // Last column is the target - - var irisDataProcessor = new DataProcessor(new CSVIDataReader()).loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex).normalize().shuffle().split(TRAIN_TEST_SPLIT_RATIO); - + var featureColumnIndices = new int[] {0, 1, 2, 3}; // First 4 columns are features + var targetColumnIndex = 4; // Last column is the target - Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 2); + var irisDataProcessor = + new DataProcessor(new CSVIDataReader()) + .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex) + .normalize() + .shuffle() + .split(TRAIN_TEST_SPLIT_RATIO); - var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); - var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); - var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); - var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); + Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 2); - svm.train(trainFeatures, trainLabels); - svm.evaluate(trainTestFeatures, trainTestLabels); - } + var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); + var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); + var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); + var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); + svm.train(trainFeatures, trainLabels); + svm.evaluate(trainTestFeatures, trainTestLabels); + } } diff --git a/lib/build.gradle b/lib/build.gradle index d6e684e..efd23df 100644 --- a/lib/build.gradle +++ b/lib/build.gradle @@ -48,7 +48,8 @@ task sourceJar(type: Jar) { task javadocJar(type: Jar, dependsOn: javadoc) { archiveClassifier.set('javadoc') - javadoc.destinationDir = file("${rootProject.projectDir}/docs/javadocs") // Definiert das Ausgabeverzeichnis für die generierten JavaDocs. + javadoc.destinationDir = file("${rootProject.projectDir}/docs/javadocs") + // Definiert das Ausgabeverzeichnis für die generierten JavaDocs. from javadoc.destinationDir } diff --git a/lib/src/main/java/de/edux/api/Classifier.java b/lib/src/main/java/de/edux/api/Classifier.java index 8d4f8f2..2c869cf 100644 --- a/lib/src/main/java/de/edux/api/Classifier.java +++ b/lib/src/main/java/de/edux/api/Classifier.java @@ -3,69 +3,68 @@ /** * Provides a common interface for machine learning classifiers within the Edux API. * - *

<p>The {@code Classifier} interface is designed to encapsulate a variety of machine learning models, offering
- * a consistent approach to training, evaluating, and utilizing classifiers. Implementations are expected to handle
- * specifics related to different types of classification algorithms, such as neural networks, decision trees,
- * support vector machines, etc.</p>
+ *
+ * <p>The {@code Classifier} interface is designed to encapsulate a variety of machine learning
+ * models, offering a consistent approach to training, evaluating, and utilizing classifiers.
+ * Implementations are expected to handle specifics related to different types of classification
+ * algorithms, such as neural networks, decision trees, support vector machines, etc.
 *
- * <p>Each classifier must implement methods for training the model on a dataset, evaluating its performance,
- * and making predictions on new, unseen data. This design allows for interchangeability of models and promotes
- * a clean separation of concerns between the data processing and model training phases.</p>
+ *
+ * <p>Each classifier must implement methods for training the model on a dataset, evaluating its
+ * performance, and making predictions on new, unseen data. This design allows for
+ * interchangeability of models and promotes a clean separation of concerns between the data
+ * processing and model training phases.
+ *
+ * <p>Typical usage involves:
 *
- * <p>Typical usage involves:</p>
- *
- * <ul>
- *   <li>Creating an instance of a class that implements {@code Classifier}.</li>
- *   <li>Training the classifier with known data via the {@code train} method.</li>
- *   <li>Evaluating the classifier's performance with test data via the {@code evaluate} method.</li>
- *   <li>Applying the trained classifier to new data to make predictions via the {@code predict} method.</li>
+ * <ul>
+ *   <li>Creating an instance of a class that implements {@code Classifier}.
+ *   <li>Training the classifier with known data via the {@code train} method.
+ *   <li>Evaluating the classifier's performance with test data via the {@code evaluate} method.
+ *   <li>Applying the trained classifier to new data to make predictions via the {@code predict}
+ *       method.
 * </ul>
 *
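+ * <p>A minimal usage sketch (illustrative only; it uses the k-nearest-neighbors implementation
+ * from this repository, but any {@code Classifier} works the same way):
+ *
+ * <pre>{@code
+ * Classifier classifier = new KnnClassifier(2);
+ * classifier.train(trainFeatures, trainLabels);
+ * double accuracy = classifier.evaluate(testFeatures, testLabels);
+ * double[] prediction = classifier.predict(testFeatures[0]);
+ * }</pre>
+ *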
- * <p>Implementing classes should ensure that proper validation is performed on the input data and
- * any necessary pre-processing or feature scaling is applied consistent with the model's requirements.</p>
- * + * any necessary pre-processing or feature scaling is applied consistent with the model's + * requirements. */ public interface Classifier { - /** - * Trains the model using the provided training inputs and targets. - * @param features 2D array of double, where each inner array represents - * @param labels 2D array of double, where each inner array represents - * @return true if the model was successfully trained, false otherwise. - */ - boolean train(double[][] features, double[][] labels); - /** - * Evaluates the model's performance against the provided test inputs and targets. - * - * This method takes a set of test inputs and their corresponding expected targets, - * applies the model to predict the outputs for the inputs, and then compares - * the predicted outputs to the expected targets to evaluate the performance - * of the model. The nature and metric of the evaluation (e.g., accuracy, MSE, etc.) - * are dependent on the specific implementation within the method. - * - * @param testInputs 2D array of double, where each inner array represents - * a single set of input values to be evaluated by the model. - * @param testTargets 2D array of double, where each inner array represents - * the expected output or target for the corresponding set - * of inputs in {@code testInputs}. - * @return a double value representing the performance of the model when evaluated - * against the provided test inputs and targets. The interpretation of this - * value (e.g., higher is better, lower is better, etc.) depends on the - * specific evaluation metric being used. - * @throws IllegalArgumentException if the lengths of {@code testInputs} and - * {@code testTargets} do not match, or if - * they are empty. - */ - double evaluate(double[][] testInputs, double[][] testTargets); - + /** + * Trains the model using the provided training inputs and targets. + * + * @param features 2D array of double, where each inner array represents + * @param labels 2D array of double, where each inner array represents + * @return true if the model was successfully trained, false otherwise. + */ + boolean train(double[][] features, double[][] labels); - /** - * Predicts the output for a single set of input values. - * - * @param feature a single set of input values to be evaluated by the model. - * @return a double array representing the predicted output values for the - * provided input values. - * @throws IllegalArgumentException if {@code feature} is empty. - */ - public double[] predict(double[] feature); + /** + * Evaluates the model's performance against the provided test inputs and targets. + * + *

This method takes a set of test inputs and their corresponding expected targets, applies the + * model to predict the outputs for the inputs, and then compares the predicted outputs to the + * expected targets to evaluate the performance of the model. The nature and metric of the + * evaluation (e.g., accuracy, MSE, etc.) are dependent on the specific implementation within the + * method. + * + * @param testInputs 2D array of double, where each inner array represents a single set of input + * values to be evaluated by the model. + * @param testTargets 2D array of double, where each inner array represents the expected output or + * target for the corresponding set of inputs in {@code testInputs}. + * @return a double value representing the performance of the model when evaluated against the + * provided test inputs and targets. The interpretation of this value (e.g., higher is better, + * lower is better, etc.) depends on the specific evaluation metric being used. + * @throws IllegalArgumentException if the lengths of {@code testInputs} and {@code testTargets} + * do not match, or if they are empty. + */ + double evaluate(double[][] testInputs, double[][] testTargets); + /** + * Predicts the output for a single set of input values. + * + * @param feature a single set of input values to be evaluated by the model. + * @return a double array representing the predicted output values for the provided input values. + * @throws IllegalArgumentException if {@code feature} is empty. + */ + public double[] predict(double[] feature); } diff --git a/lib/src/main/java/de/edux/data/provider/DataNormalizer.java b/lib/src/main/java/de/edux/data/provider/DataNormalizer.java index 0df8972..0078a81 100644 --- a/lib/src/main/java/de/edux/data/provider/DataNormalizer.java +++ b/lib/src/main/java/de/edux/data/provider/DataNormalizer.java @@ -1,62 +1,60 @@ package de.edux.data.provider; import java.util.List; -import java.util.ArrayList; public class DataNormalizer implements Normalizer { - @Override - public List normalize(List dataset) { - if (dataset == null || dataset.isEmpty()) { - return dataset; - } + @Override + public List normalize(List dataset) { + if (dataset == null || dataset.isEmpty()) { + return dataset; + } - int columnCount = dataset.get(0).length; + int columnCount = dataset.get(0).length; - double[] minValues = new double[columnCount]; - double[] maxValues = new double[columnCount]; - boolean[] isNumericColumn = new boolean[columnCount]; + double[] minValues = new double[columnCount]; + double[] maxValues = new double[columnCount]; + boolean[] isNumericColumn = new boolean[columnCount]; - for (int i = 0; i < columnCount; i++) { - minValues[i] = Double.MAX_VALUE; - maxValues[i] = -Double.MAX_VALUE; - isNumericColumn[i] = true; - } + for (int i = 0; i < columnCount; i++) { + minValues[i] = Double.MAX_VALUE; + maxValues[i] = -Double.MAX_VALUE; + isNumericColumn[i] = true; + } - for (String[] row : dataset) { - for (int colIndex = 0; colIndex < columnCount; colIndex++) { - try { - double numValue = Double.parseDouble(row[colIndex]); - - if (numValue < minValues[colIndex]) { - minValues[colIndex] = numValue; - } - if (numValue > maxValues[colIndex]) { - maxValues[colIndex] = numValue; - } - } catch (NumberFormatException e) { - isNumericColumn[colIndex] = false; - } - } + for (String[] row : dataset) { + for (int colIndex = 0; colIndex < columnCount; colIndex++) { + try { + double numValue = Double.parseDouble(row[colIndex]); + + if (numValue < minValues[colIndex]) { + minValues[colIndex] = 
numValue; + } + if (numValue > maxValues[colIndex]) { + maxValues[colIndex] = numValue; + } + } catch (NumberFormatException e) { + isNumericColumn[colIndex] = false; } + } + } - for (String[] row : dataset) { - for (int colIndex = 0; colIndex < columnCount; colIndex++) { - if (isNumericColumn[colIndex]) { - double numValue = Double.parseDouble(row[colIndex]); - double range = maxValues[colIndex] - minValues[colIndex]; - - if (range != 0.0) { - double normalized = (numValue - minValues[colIndex]) / range; - row[colIndex] = String.valueOf(normalized); - } else { - row[colIndex] = "0"; - } - } - } + for (String[] row : dataset) { + for (int colIndex = 0; colIndex < columnCount; colIndex++) { + if (isNumericColumn[colIndex]) { + double numValue = Double.parseDouble(row[colIndex]); + double range = maxValues[colIndex] - minValues[colIndex]; + + if (range != 0.0) { + double normalized = (numValue - minValues[colIndex]) / range; + row[colIndex] = String.valueOf(normalized); + } else { + row[colIndex] = "0"; + } } - - return dataset; + } } + return dataset; + } } diff --git a/lib/src/main/java/de/edux/data/provider/DataloaderV2.java b/lib/src/main/java/de/edux/data/provider/DataloaderV2.java index bfd0377..f7ff2d1 100644 --- a/lib/src/main/java/de/edux/data/provider/DataloaderV2.java +++ b/lib/src/main/java/de/edux/data/provider/DataloaderV2.java @@ -3,6 +3,6 @@ import java.io.File; public interface DataloaderV2 { - DataProcessor loadDataSetFromCSV(File csvFile, char csvSeparator, boolean skipHead, int[] inputColumns, int targetColumn); + DataProcessor loadDataSetFromCSV( + File csvFile, char csvSeparator, boolean skipHead, int[] inputColumns, int targetColumn); } - diff --git a/lib/src/main/java/de/edux/data/provider/Normalizer.java b/lib/src/main/java/de/edux/data/provider/Normalizer.java index 0f0b3be..543766a 100644 --- a/lib/src/main/java/de/edux/data/provider/Normalizer.java +++ b/lib/src/main/java/de/edux/data/provider/Normalizer.java @@ -3,5 +3,5 @@ import java.util.List; public interface Normalizer { - List normalize(List dataset); + List normalize(List dataset); } diff --git a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java index 2c16dc7..c1c0016 100644 --- a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java +++ b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java @@ -5,7 +5,6 @@ import com.opencsv.CSVReader; import com.opencsv.CSVReaderBuilder; import com.opencsv.exceptions.CsvException; - import java.io.File; import java.io.FileReader; import java.io.IOException; @@ -13,18 +12,15 @@ public class CSVIDataReader implements IDataReader { - public List readFile(File file, char separator ) { - CSVParser customCSVParser = new CSVParserBuilder().withSeparator(separator).build(); - List result; - try(CSVReader reader = new CSVReaderBuilder( - new FileReader(file)) - .withCSVParser(customCSVParser) - .build()){ - result = reader.readAll(); - } catch (CsvException | IOException e) { - throw new RuntimeException(e); - } - return result; + public List readFile(File file, char separator) { + CSVParser customCSVParser = new CSVParserBuilder().withSeparator(separator).build(); + List result; + try (CSVReader reader = + new CSVReaderBuilder(new FileReader(file)).withCSVParser(customCSVParser).build()) { + result = reader.readAll(); + } catch (CsvException | IOException e) { + throw new RuntimeException(e); } - + return result; + } } diff --git a/lib/src/main/java/de/edux/data/reader/IDataReader.java 
b/lib/src/main/java/de/edux/data/reader/IDataReader.java index 7a01c92..c873fa0 100644 --- a/lib/src/main/java/de/edux/data/reader/IDataReader.java +++ b/lib/src/main/java/de/edux/data/reader/IDataReader.java @@ -4,5 +4,5 @@ import java.util.List; public interface IDataReader { - List readFile(File file, char separator); + List readFile(File file, char separator); } diff --git a/lib/src/main/java/de/edux/functions/activation/ActivationFunction.java b/lib/src/main/java/de/edux/functions/activation/ActivationFunction.java index 4dac09c..3706614 100644 --- a/lib/src/main/java/de/edux/functions/activation/ActivationFunction.java +++ b/lib/src/main/java/de/edux/functions/activation/ActivationFunction.java @@ -1,120 +1,123 @@ package de.edux.functions.activation; + /** - * Enumerates common activation functions used in neural networks and similar machine learning architectures. + * Enumerates common activation functions used in neural networks and similar machine learning + * architectures. + * + *

<p>Each member of this enum represents a distinct type of activation function, a critical
+ * component in neural networks. Activation functions determine the output of a neural network layer
+ * for a given set of input, and they help normalize the output of each neuron to a specific range,
+ * usually between 1 and -1 or between 1 and 0.
 *
- * <p>Each member of this enum represents a distinct type of activation function, a critical component in
- * neural networks. Activation functions determine the output of a neural network layer for a given set of
- * input, and they help normalize the output of each neuron to a specific range, usually between 1 and -1 or
- * between 1 and 0.</p>
+ *
+ * <p>This enum simplifies the process of selecting and utilizing an activation function. It
+ * provides an abstraction where the user can easily switch between different functions, making it
+ * easier to experiment with neural network design. Additionally, each function includes a method
+ * for calculating its derivative, which is essential for backpropagation in neural network
+ * training.
 *
- * <p>This enum simplifies the process of selecting and utilizing an activation function. It provides an
- * abstraction where the user can easily switch between different functions, making it easier to experiment
- * with neural network design. Additionally, each function includes a method for calculating its derivative,
- * which is essential for backpropagation in neural network training.</p>
+ *
+ * <p>Available functions include:
 *
- * <p>Available functions include:</p>
- *
- * <ul>
- *   <li>SIGMOID: Normalizes inputs between 0 and 1, crucial for binary classification.</li>
- *   <li>RELU: Addresses the vanishing gradient problem, allowing for faster and more effective training.</li>
- *   <li>LEAKY_RELU: Variation of RELU, prevents "dying neurons" by allowing a small gradient when the unit is not active.</li>
- *   <li>TANH: Normalizes inputs between -1 and 1, a scaled version of the sigmoid function.</li>
- *   <li>SOFTMAX: Converts a vector of raw scores to a probability distribution, typically used in multi-class classification.</li>
+ * <ul>
+ *   <li>SIGMOID: Normalizes inputs between 0 and 1, crucial for binary classification.
+ *   <li>RELU: Addresses the vanishing gradient problem, allowing for faster and more
+ *       effective training.
+ *   <li>LEAKY_RELU: Variation of RELU, prevents "dying neurons" by allowing a small gradient
+ *       when the unit is not active.
+ *   <li>TANH: Normalizes inputs between -1 and 1, a scaled version of the sigmoid function.
+ *   <li>SOFTMAX: Converts a vector of raw scores to a probability distribution, typically
+ *       used in multi-class classification.
 * </ul>
 *
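+ * <p>A quick, hand-checked illustration of the contract (the values follow directly from the
+ * implementations below):
+ *
+ * <pre>{@code
+ * double y = ActivationFunction.SIGMOID.calculateActivation(0.0);  // 0.5
+ * double dy = ActivationFunction.SIGMOID.calculateDerivative(0.0); // 0.25 = 0.5 * (1 - 0.5)
+ * double[] p = ActivationFunction.SOFTMAX.calculateActivation(new double[] {1.0, 2.0, 3.0});
+ * // p is a probability distribution: all entries positive and summing to 1.0
+ * }</pre>
+ *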
- * <p>Each function overrides the {@code calculateActivation} and {@code calculateDerivative} methods, providing the
- * specific implementation for the activation and its derivative based on input. These are essential for the forward
- * and backward passes through the network, respectively.</p>
- *
- * <p>Note: The {@code SOFTMAX} function additionally overrides {@code calculateActivation} for an array input,
- * facilitating its common use in output layers of neural networks for classification tasks.</p>
+ *
+ * <p>Each function overrides the {@code calculateActivation} and {@code calculateDerivative}
+ * methods, providing the specific implementation for the activation and its derivative based on
+ * input. These are essential for the forward and backward passes through the network, respectively.
+ *
+ * <p>
Note: The {@code SOFTMAX} function additionally overrides {@code calculateActivation} + * for an array input, facilitating its common use in output layers of neural networks for + * classification tasks. */ public enum ActivationFunction { + SIGMOID { + @Override + public double calculateActivation(double x) { + return 1 / (1 + Math.exp(-x)); + } - SIGMOID { - @Override - public double calculateActivation(double x) { - return 1 / (1 + Math.exp(-x)); - } - - @Override - public double calculateDerivative(double x) { - return calculateActivation(x) * (1 - calculateActivation(x)); - } - }, - RELU { - @Override - public double calculateActivation(double x) { - return Math.max(0, x); - } - - @Override - public double calculateDerivative(double x) { - return x > 0 ? 1 : 0; - } - }, - LEAKY_RELU { - @Override - public double calculateActivation(double x) { - return Math.max(0.01 * x, x); - } + @Override + public double calculateDerivative(double x) { + return calculateActivation(x) * (1 - calculateActivation(x)); + } + }, + RELU { + @Override + public double calculateActivation(double x) { + return Math.max(0, x); + } - @Override - public double calculateDerivative(double x) { - if (x > 0) { - return 1.0; - } else { - return 0.01; - } - } - }, - TANH { - @Override - public double calculateActivation(double x) { - return Math.tanh(x); - } + @Override + public double calculateDerivative(double x) { + return x > 0 ? 1 : 0; + } + }, + LEAKY_RELU { + @Override + public double calculateActivation(double x) { + return Math.max(0.01 * x, x); + } - @Override - public double calculateDerivative(double x) { - return 1 - Math.pow(calculateActivation(x), 2); - } - }, SOFTMAX { - @Override - public double calculateActivation(double x) { - return Math.exp(x); - } + @Override + public double calculateDerivative(double x) { + if (x > 0) { + return 1.0; + } else { + return 0.01; + } + } + }, + TANH { + @Override + public double calculateActivation(double x) { + return Math.tanh(x); + } - @Override - public double calculateDerivative(double x) { - return calculateActivation(x) * (1 - calculateActivation(x)); - } + @Override + public double calculateDerivative(double x) { + return 1 - Math.pow(calculateActivation(x), 2); + } + }, + SOFTMAX { + @Override + public double calculateActivation(double x) { + return Math.exp(x); + } - @Override - public double[] calculateActivation(double[] x) { - double max = Double.NEGATIVE_INFINITY; - for (double value : x) - if (value > max) - max = value; + @Override + public double calculateDerivative(double x) { + return calculateActivation(x) * (1 - calculateActivation(x)); + } - double sum = 0.0; - for (int i = 0; i < x.length; i++) { - x[i] = Math.exp(x[i] - max); - sum += x[i]; - } + @Override + public double[] calculateActivation(double[] x) { + double max = Double.NEGATIVE_INFINITY; + for (double value : x) if (value > max) max = value; - for (int i = 0; i < x.length; i++) - x[i] /= sum; + double sum = 0.0; + for (int i = 0; i < x.length; i++) { + x[i] = Math.exp(x[i] - max); + sum += x[i]; + } - return x; - } - }; + for (int i = 0; i < x.length; i++) x[i] /= sum; + return x; + } + }; - public abstract double calculateActivation(double x); + public abstract double calculateActivation(double x); - public abstract double calculateDerivative(double x); + public abstract double calculateDerivative(double x); - public double[] calculateActivation(double[] x){ - throw new UnsupportedOperationException("Not implemented"); - } + public double[] calculateActivation(double[] x) { + throw 
new UnsupportedOperationException("Not implemented"); + } } diff --git a/lib/src/main/java/de/edux/functions/activation/package-info.java b/lib/src/main/java/de/edux/functions/activation/package-info.java index cff50ee..ffc2a95 100644 --- a/lib/src/main/java/de/edux/functions/activation/package-info.java +++ b/lib/src/main/java/de/edux/functions/activation/package-info.java @@ -1,21 +1,22 @@ /** * Provides the classes necessary to define various activation functions used in neural networks. * - *

<p>This package is part of the larger Edux framework for educational purposes in the realm of machine learning.
- * Within this package, you will find enumerations and possibly classes that represent a variety of standard
- * activation functions, such as Sigmoid, TanH, ReLU, and others. These functions are fundamental components
- * in the construction of neural networks, as they dictate how signals are processed as they pass from one
- * neuron (or node) to the next, essentially determining the output of each neuron.</p>
+ *
+ * <p>This package is part of the larger Edux framework for educational purposes in the realm of
+ * machine learning. Within this package, you will find enumerations and possibly classes that
+ * represent a variety of standard activation functions, such as Sigmoid, TanH, ReLU, and others.
+ * These functions are fundamental components in the construction of neural networks, as they
+ * dictate how signals are processed as they pass from one neuron (or node) to the next, essentially
+ * determining the output of each neuron.
 *
- * <p>Each activation function contained within this package has distinct characteristics and is useful in
- * different scenarios, depending on the nature of the input data, the specific architecture of the network,
- * and the learning task at hand. For instance, some functions are better suited for dealing with issues like
- * the vanishing gradient problem, while others might normalize input values into a certain range to aid with
- * the convergence of the learning algorithm.</p>
- *
- * <p>This package is designed to offer flexibility and ease of use for those constructing machine learning
- * models, as it allows for easy switching between different activation strategies, facilitating experimentation
- * and learning.</p>
+ *
+ * <p>Each activation function contained within this package has distinct characteristics and is
+ * useful in different scenarios, depending on the nature of the input data, the specific
+ * architecture of the network, and the learning task at hand. For instance, some functions are
+ * better suited for dealing with issues like the vanishing gradient problem, while others might
+ * normalize input values into a certain range to aid with the convergence of the learning
+ * algorithm.
+ *
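+ * <p>For instance, hidden-layer and output-layer activations are chosen independently when a
+ * network is configured; the call below mirrors the iris examples elsewhere in this patch and is
+ * illustrative only:
+ *
+ * <pre>{@code
+ * var config = new NetworkConfiguration(4, List.of(128, 256, 512), 3, 0.01, 300,
+ *     ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX,
+ *     LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);
+ * }</pre>
+ *
+ * <p>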

This package is designed to offer flexibility and ease of use for those constructing machine + * learning models, as it allows for easy switching between different activation strategies, + * facilitating experimentation and learning. */ package de.edux.functions.activation; diff --git a/lib/src/main/java/de/edux/functions/initialization/Initialization.java b/lib/src/main/java/de/edux/functions/initialization/Initialization.java index a174107..a607ab9 100644 --- a/lib/src/main/java/de/edux/functions/initialization/Initialization.java +++ b/lib/src/main/java/de/edux/functions/initialization/Initialization.java @@ -1,26 +1,26 @@ package de.edux.functions.initialization; public enum Initialization { - XAVIER { - @Override - public double[] weightInitialization(int inputSize, double[] weights) { - double xavier = Math.sqrt(6.0 / (inputSize + 1)); - for (int i = 0; i < weights.length; i++) { - weights[i] = Math.random() * 2 * xavier - xavier; - } - return weights; - } - }, - HE { - @Override - public double[] weightInitialization(int inputSize, double[] weights) { - double he = Math.sqrt(2.0 / inputSize); - for (int i = 0; i < weights.length; i++) { - weights[i] = Math.random() * 2 * he - he; - } - return weights; - } - }; + XAVIER { + @Override + public double[] weightInitialization(int inputSize, double[] weights) { + double xavier = Math.sqrt(6.0 / (inputSize + 1)); + for (int i = 0; i < weights.length; i++) { + weights[i] = Math.random() * 2 * xavier - xavier; + } + return weights; + } + }, + HE { + @Override + public double[] weightInitialization(int inputSize, double[] weights) { + double he = Math.sqrt(2.0 / inputSize); + for (int i = 0; i < weights.length; i++) { + weights[i] = Math.random() * 2 * he - he; + } + return weights; + } + }; - public abstract double[] weightInitialization(int inputSize, double[] weights); + public abstract double[] weightInitialization(int inputSize, double[] weights); } diff --git a/lib/src/main/java/de/edux/functions/loss/LossFunction.java b/lib/src/main/java/de/edux/functions/loss/LossFunction.java index c0d0265..9c9e7b5 100644 --- a/lib/src/main/java/de/edux/functions/loss/LossFunction.java +++ b/lib/src/main/java/de/edux/functions/loss/LossFunction.java @@ -1,68 +1,66 @@ package de.edux.functions.loss; public enum LossFunction { + CATEGORICAL_CROSS_ENTROPY { + @Override + public double calculateError(double[] output, double[] target) { + double error = 0; + for (int i = 0; i < target.length; i++) { + error += target[i] * Math.log(output[i]); + } + return -error; + } + }, + MEAN_SQUARED_ERROR { + @Override + public double calculateError(double[] output, double[] target) { + double error = 0; + for (int i = 0; i < target.length; i++) { + error += Math.pow(target[i] - output[i], 2); + } + return error / target.length; + } + }, + MEAN_ABSOLUTE_ERROR { + @Override + public double calculateError(double[] output, double[] target) { + double error = 0; + for (int i = 0; i < target.length; i++) { + error += Math.abs(target[i] - output[i]); + } + return error / target.length; + } + }, + HINGE_LOSS { + @Override + public double calculateError(double[] output, double[] target) { + double error = 0; + for (int i = 0; i < target.length; i++) { + error += Math.max(0, 1 - target[i] * output[i]); + } + return error / target.length; + } + }, + SQUARED_HINGE_LOSS { + @Override + public double calculateError(double[] output, double[] target) { + double error = 0; + for (int i = 0; i < target.length; i++) { + error += Math.pow(Math.max(0, 1 - target[i] * output[i]), 2); + } 
+ return error / target.length; + } + }, + BINARY_CROSS_ENTROPY { + @Override + public double calculateError(double[] output, double[] target) { + double error = 0; + for (int i = 0; i < target.length; i++) { + error += target[i] * Math.log(output[i]) + (1 - target[i]) * Math.log(1 - output[i]); + } + return -error; + } + }; - CATEGORICAL_CROSS_ENTROPY { - @Override - public double calculateError(double[] output, double[] target) { - double error = 0; - for (int i = 0; i < target.length; i++) { - error += target[i] * Math.log(output[i]); - } - return -error; - } - }, - MEAN_SQUARED_ERROR { - @Override - public double calculateError(double[] output, double[] target) { - double error = 0; - for (int i = 0; i < target.length; i++) { - error += Math.pow(target[i] - output[i], 2); - } - return error / target.length; - } - }, - MEAN_ABSOLUTE_ERROR { - @Override - public double calculateError(double[] output, double[] target) { - double error = 0; - for (int i = 0; i < target.length; i++) { - error += Math.abs(target[i] - output[i]); - } - return error / target.length; - } - }, - HINGE_LOSS { - @Override - public double calculateError(double[] output, double[] target) { - double error = 0; - for (int i = 0; i < target.length; i++) { - error += Math.max(0, 1 - target[i] * output[i]); - } - return error / target.length; - } - }, - SQUARED_HINGE_LOSS { - @Override - public double calculateError(double[] output, double[] target) { - double error = 0; - for (int i = 0; i < target.length; i++) { - error += Math.pow(Math.max(0, 1 - target[i] * output[i]), 2); - } - return error / target.length; - } - }, - BINARY_CROSS_ENTROPY { - @Override - public double calculateError(double[] output, double[] target) { - double error = 0; - for (int i = 0; i < target.length; i++) { - error += target[i] * Math.log(output[i]) + (1 - target[i]) * Math.log(1 - output[i]); - } - return -error; - } - }; - - public abstract double calculateError(double[] target, double[] output); + public abstract double calculateError(double[] target, double[] output); } - diff --git a/lib/src/main/java/de/edux/math/Entity.java b/lib/src/main/java/de/edux/math/Entity.java index 2f1ceb0..04d3794 100644 --- a/lib/src/main/java/de/edux/math/Entity.java +++ b/lib/src/main/java/de/edux/math/Entity.java @@ -2,12 +2,11 @@ public interface Entity { - T add(T another); + T add(T another); - T subtract(T another); + T subtract(T another); - T multiply(T another); - - T scalarMultiply(double n); + T multiply(T another); + T scalarMultiply(double n); } diff --git a/lib/src/main/java/de/edux/math/MathUtil.java b/lib/src/main/java/de/edux/math/MathUtil.java index 162c1fc..d7963d4 100644 --- a/lib/src/main/java/de/edux/math/MathUtil.java +++ b/lib/src/main/java/de/edux/math/MathUtil.java @@ -2,15 +2,14 @@ public final class MathUtil { - public static double[] unwrap(double[][] matrix) { - double[] result = new double[matrix.length * matrix[0].length]; - int i = 0; - for (double[] arr : matrix) { - for (double val : arr) { - result[i++] = val; - } - } - return result; + public static double[] unwrap(double[][] matrix) { + double[] result = new double[matrix.length * matrix[0].length]; + int i = 0; + for (double[] arr : matrix) { + for (double val : arr) { + result[i++] = val; + } } - + return result; + } } diff --git a/lib/src/main/java/de/edux/math/Validations.java b/lib/src/main/java/de/edux/math/Validations.java index 7fbc8e8..076a2e9 100644 --- a/lib/src/main/java/de/edux/math/Validations.java +++ b/lib/src/main/java/de/edux/math/Validations.java @@ 
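Both enums above are small strategy APIs: Initialization fills a weight array in place and returns it, and LossFunction reduces an output/target pair to a scalar error. One caveat worth knowing before calling them: the abstract calculateError(double[] target, double[] output) declares its parameters in the opposite order from every implementation's (double[] output, double[] target), so arguments are matched purely by position and the prediction must be passed where the implementations expect "output"; the cross-entropy variants also assume outputs strictly inside (0, 1), e.g. from a softmax. A minimal usage sketch (class and variable names are illustrative, not part of the patch):

import de.edux.functions.initialization.Initialization;
import de.edux.functions.loss.LossFunction;

public class FunctionsDemo {
  public static void main(String[] args) {
    // Xavier-initialize the weights of a neuron with four inputs; the array is filled in place.
    double[] weights = Initialization.XAVIER.weightInitialization(4, new double[4]);

    // One-hot target and a softmax-like prediction for a three-class problem.
    double[] target = {0.0, 1.0, 0.0};
    double[] prediction = {0.2, 0.7, 0.1};

    // The first argument lands in the implementations' "output" slot:
    // categorical cross-entropy reduces to -log(0.7), roughly 0.357.
    double ce = LossFunction.CATEGORICAL_CROSS_ENTROPY.calculateError(prediction, target);
    double mse = LossFunction.MEAN_SQUARED_ERROR.calculateError(prediction, target); // 0.14 / 3

    System.out.printf("w0=%.4f, crossEntropy=%.4f, mse=%.4f%n", weights[0], ce, mse);
  }
}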
-2,16 +2,15 @@ public final class Validations { - public static void size(double[] first, double[] second) { - if (first.length != second.length) { - throw new IllegalArgumentException("sizes mismatch"); - } + public static void size(double[] first, double[] second) { + if (first.length != second.length) { + throw new IllegalArgumentException("sizes mismatch"); } + } - public static void sizeMatrix(double[][] first, double[][] second) { - if (first.length != second.length || first[0].length != second[0].length) { - throw new IllegalArgumentException("sizes mismatch"); - } + public static void sizeMatrix(double[][] first, double[][] second) { + if (first.length != second.length || first[0].length != second[0].length) { + throw new IllegalArgumentException("sizes mismatch"); } - + } } diff --git a/lib/src/main/java/de/edux/math/entity/Matrix.java b/lib/src/main/java/de/edux/math/entity/Matrix.java index c643a39..804aa05 100644 --- a/lib/src/main/java/de/edux/math/entity/Matrix.java +++ b/lib/src/main/java/de/edux/math/entity/Matrix.java @@ -3,159 +3,155 @@ import de.edux.math.Entity; import de.edux.math.MathUtil; import de.edux.math.Validations; - import java.util.Iterator; import java.util.NoSuchElementException; public class Matrix implements Entity, Iterable { - private final double[][] raw; + private final double[][] raw; - public Matrix(double[][] matrix) { - this.raw = matrix; - } + public Matrix(double[][] matrix) { + this.raw = matrix; + } - @Override - public Matrix add(Matrix another) { - return add(another.raw()); - } + @Override + public Matrix add(Matrix another) { + return add(another.raw()); + } - public Matrix add(double[][] another) { - Validations.sizeMatrix(raw, another); + public Matrix add(double[][] another) { + Validations.sizeMatrix(raw, another); - double[][] result = new double[raw.length][raw[0].length]; + double[][] result = new double[raw.length][raw[0].length]; - for (int i = 0; i < result.length; i++) { - for (int a = 0; a < result[0].length; a++) { - result[i][a] = raw[i][a] + another[i][a]; - } - } - - return new Matrix(result); + for (int i = 0; i < result.length; i++) { + for (int a = 0; a < result[0].length; a++) { + result[i][a] = raw[i][a] + another[i][a]; + } } - @Override - public Matrix subtract(Matrix another) { - return subtract(another.raw()); - } + return new Matrix(result); + } - @Override - public Matrix multiply(Matrix another) { - return multiply(another.raw()); - } + @Override + public Matrix subtract(Matrix another) { + return subtract(another.raw()); + } - public Matrix multiply(double[][] another) { - return null; // TODO optimized algorithm for matrix multiplication - } + @Override + public Matrix multiply(Matrix another) { + return multiply(another.raw()); + } - @Override - public Matrix scalarMultiply(double n) { - double[][] result = new double[raw.length][raw[0].length]; + public Matrix multiply(double[][] another) { + return null; // TODO optimized algorithm for matrix multiplication + } - for (int i = 0; i < result.length; i++) { - for (int a = 0; a < result[0].length; a++) { - result[i][a] = raw[i][a] * n; - } - } + @Override + public Matrix scalarMultiply(double n) { + double[][] result = new double[raw.length][raw[0].length]; - return new Matrix(result); + for (int i = 0; i < result.length; i++) { + for (int a = 0; a < result[0].length; a++) { + result[i][a] = raw[i][a] * n; + } } - public Matrix subtract(double[][] another) { - Validations.sizeMatrix(raw, another); + return new Matrix(result); + } - double[][] result = new 
double[raw.length][raw[0].length]; + public Matrix subtract(double[][] another) { + Validations.sizeMatrix(raw, another); - for (int i = 0; i < result.length; i++) { - for (int a = 0; a < result[0].length; a++) { - result[i][a] = raw[i][a] - another[i][a]; - } - } + double[][] result = new double[raw.length][raw[0].length]; - return new Matrix(result); + for (int i = 0; i < result.length; i++) { + for (int a = 0; a < result[0].length; a++) { + result[i][a] = raw[i][a] - another[i][a]; + } } - public boolean isSquare() { - return rows() == columns(); - } + return new Matrix(result); + } - public int rows() { - return raw.length; - } + public boolean isSquare() { + return rows() == columns(); + } - public int columns() { - return raw[0].length; - } + public int rows() { + return raw.length; + } - public double[][] raw() { - return raw; - } + public int columns() { + return raw[0].length; + } - @Override - public boolean equals(Object obj) { - if (obj instanceof Matrix matrix) { - if (matrix.rows() != rows() || matrix.columns() != columns()) { - return false; - } - for (int i = 0; i < raw.length; i++) { - for (int a = 0; a < raw[i].length; a++) { - if (matrix.raw()[i][a] != raw[i][a]) { - return false; - } - } - } - return true; - } - return false; - } + public double[][] raw() { + return raw; + } - @Override - public String toString() { - StringBuilder builder = new StringBuilder("[").append("\n"); - for (int i = 0; i < raw.length; i++) { - builder.append(" ").append("["); - for (int a = 0; a < raw[i].length; a++) { - builder.append(raw[i][a]); - if (a != raw[i].length - 1) { - builder.append(", "); - } - } - builder.append("]"); - if (i != raw.length - 1) { - builder.append(","); - } - builder.append("\n"); + @Override + public boolean equals(Object obj) { + if (obj instanceof Matrix matrix) { + if (matrix.rows() != rows() || matrix.columns() != columns()) { + return false; + } + for (int i = 0; i < raw.length; i++) { + for (int a = 0; a < raw[i].length; a++) { + if (matrix.raw()[i][a] != raw[i][a]) { + return false; + } } - return builder.append("]").toString(); - } - - @Override - public Iterator iterator() { - return new MatrixIterator(raw); + } + return true; + } + return false; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder("[").append("\n"); + for (int i = 0; i < raw.length; i++) { + builder.append(" ").append("["); + for (int a = 0; a < raw[i].length; a++) { + builder.append(raw[i][a]); + if (a != raw[i].length - 1) { + builder.append(", "); + } + } + builder.append("]"); + if (i != raw.length - 1) { + builder.append(","); + } + builder.append("\n"); } + return builder.append("]").toString(); + } - public static class MatrixIterator implements Iterator { + @Override + public Iterator iterator() { + return new MatrixIterator(raw); + } - private final double[] data; - private int current; + public static class MatrixIterator implements Iterator { - public MatrixIterator(double[][] data) { - this.data = MathUtil.unwrap(data); - this.current = 0; - } + private final double[] data; + private int current; - @Override - public boolean hasNext() { - return current < data.length; - } - - @Override - public Double next() { - if (!hasNext()) - throw new NoSuchElementException(); - return data[current++]; - } + public MatrixIterator(double[][] data) { + this.data = MathUtil.unwrap(data); + this.current = 0; + } + @Override + public boolean hasNext() { + return current < data.length; } + @Override + public Double next() { + if (!hasNext()) throw new 
NoSuchElementException(); + return data[current++]; + } + } } diff --git a/lib/src/main/java/de/edux/math/entity/Vector.java b/lib/src/main/java/de/edux/math/entity/Vector.java index a7cf59d..e19e321 100644 --- a/lib/src/main/java/de/edux/math/entity/Vector.java +++ b/lib/src/main/java/de/edux/math/entity/Vector.java @@ -2,153 +2,149 @@ import de.edux.math.Entity; import de.edux.math.Validations; - import java.util.Arrays; import java.util.Iterator; import java.util.NoSuchElementException; public class Vector implements Entity, Iterable { - private final double[] raw; - - public Vector(double[] vector) { - this.raw = vector; - } + private final double[] raw; - @Override - public Vector add(Vector another) { - return add(another.raw()); - } + public Vector(double[] vector) { + this.raw = vector; + } - public Vector add(double[] another) { - Validations.size(raw, another); + @Override + public Vector add(Vector another) { + return add(another.raw()); + } - double[] result = new double[length()]; - for (int i = 0; i < result.length; i++) { - result[i] = raw[i] + another[i]; - } + public Vector add(double[] another) { + Validations.size(raw, another); - return new Vector(result); + double[] result = new double[length()]; + for (int i = 0; i < result.length; i++) { + result[i] = raw[i] + another[i]; } - @Override - public Vector subtract(Vector another) { - return subtract(another.raw()); - } + return new Vector(result); + } - public Vector subtract(double[] another) { - Validations.size(raw, another); + @Override + public Vector subtract(Vector another) { + return subtract(another.raw()); + } - double[] result = new double[length()]; - for (int i = 0; i < result.length; i++) { - result[i] = raw[i] - another[i]; - } + public Vector subtract(double[] another) { + Validations.size(raw, another); - return new Vector(result); + double[] result = new double[length()]; + for (int i = 0; i < result.length; i++) { + result[i] = raw[i] - another[i]; } - @Override - public Vector multiply(Vector another) { - return multiply(another.raw()); - } + return new Vector(result); + } - public Vector multiply(double[] another) { - Validations.size(raw, another); + @Override + public Vector multiply(Vector another) { + return multiply(another.raw()); + } - double[] result = new double[length()]; - for (int i = 0; i < result.length; i++) { - result[i] = raw[i] * another[i]; - if (result[i] == 0) { // Avoiding -0 result - result[i] = 0; - } - } + public Vector multiply(double[] another) { + Validations.size(raw, another); - return new Vector(result); + double[] result = new double[length()]; + for (int i = 0; i < result.length; i++) { + result[i] = raw[i] * another[i]; + if (result[i] == 0) { // Avoiding -0 result + result[i] = 0; + } } - @Override - public Vector scalarMultiply(double n) { - double[] result = new double[length()]; - for (int i = 0; i < result.length; i++) { - result[i] = raw[i] * n; - if (result[i] == 0) { // Avoiding -0 result - result[i] = 0; - } - } - - return new Vector(result); - } + return new Vector(result); + } - public double dot(Vector another) { - return dot(another.raw()); + @Override + public Vector scalarMultiply(double n) { + double[] result = new double[length()]; + for (int i = 0; i < result.length; i++) { + result[i] = raw[i] * n; + if (result[i] == 0) { // Avoiding -0 result + result[i] = 0; + } } - public double dot(double[] another) { - Validations.size(raw, another); + return new Vector(result); + } - double result = 0; - for (int i = 0; i < raw.length; i++) { - result += 
raw[i] * another[i]; - } + public double dot(Vector another) { + return dot(another.raw()); + } - return result; - } + public double dot(double[] another) { + Validations.size(raw, another); - public int length() { - return raw.length; + double result = 0; + for (int i = 0; i < raw.length; i++) { + result += raw[i] * another[i]; } - public double[] raw() { - return raw.clone(); - } + return result; + } - @Override - public boolean equals(Object obj) { - if (obj instanceof Vector) { - return Arrays.equals(raw, ((Vector) obj).raw()); - } - return false; - } + public int length() { + return raw.length; + } - @Override - public String toString() { - StringBuilder builder = new StringBuilder("["); - for (int i = 0; i < raw.length; i++) { - builder.append(raw[i]); - if (i != raw.length - 1) { - builder.append(", "); - } - } - return builder.append("]").toString(); - } + public double[] raw() { + return raw.clone(); + } - @Override - public Iterator iterator() { - return new VectorIterator(raw); + @Override + public boolean equals(Object obj) { + if (obj instanceof Vector) { + return Arrays.equals(raw, ((Vector) obj).raw()); } + return false; + } - public static class VectorIterator implements Iterator { + @Override + public String toString() { + StringBuilder builder = new StringBuilder("["); + for (int i = 0; i < raw.length; i++) { + builder.append(raw[i]); + if (i != raw.length - 1) { + builder.append(", "); + } + } + return builder.append("]").toString(); + } - private final double[] data; - private int current; + @Override + public Iterator iterator() { + return new VectorIterator(raw); + } - public VectorIterator(double[] data) { - this.data = data; - this.current = 0; - } + public static class VectorIterator implements Iterator { - @Override - public boolean hasNext() { - return current < data.length; - } + private final double[] data; + private int current; - @Override - public Double next() { - if (!hasNext()) - throw new NoSuchElementException(); - return data[current++]; - } + public VectorIterator(double[] data) { + this.data = data; + this.current = 0; + } + @Override + public boolean hasNext() { + return current < data.length; } + @Override + public Double next() { + if (!hasNext()) throw new NoSuchElementException(); + return data[current++]; + } + } } diff --git a/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java b/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java index 84f04a5..9ca8e5a 100644 --- a/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java +++ b/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java @@ -1,306 +1,336 @@ package de.edux.ml.decisiontree; import de.edux.api.Classifier; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A Decision Tree classifier for predictive modeling. * - *

<p>The {@code DecisionTree} class is a binary tree where each node represents a decision
- * on a particular feature from the input feature vector, effectively partitioning the
- * input space into regions with similar output labels. The tree is built recursively
- * by selecting splits that minimize the Gini impurity of the resultant partitions.
+ * <p>The {@code DecisionTree} class is a binary tree where each node represents a decision on a
+ * particular feature from the input feature vector, effectively partitioning the input space into
+ * regions with similar output labels. The tree is built recursively by selecting splits that
+ * minimize the Gini impurity of the resultant partitions.
 *
- * <p>Features:
- * <ul>
- *   <li>Supports binary classification problems.</li>
- *   <li>Utilizes the Gini impurity to determine optimal feature splits.</li>
- *   <li>Enables control over tree depth and complexity through various hyperparameters.</li>
- * </ul>
+ * <p>Features:
+ *
+ * <ul>
+ *   <li>Supports binary classification problems.
+ *   <li>Utilizes the Gini impurity to determine optimal feature splits.
+ *   <li>Enables control over tree depth and complexity through various hyperparameters.
+ * </ul>
 *
- * <p>Hyperparameters include:
- * <ul>
- *   <li>{@code maxDepth}: The maximum depth of the tree.</li>
- *   <li>{@code minSamplesSplit}: The minimum number of samples required to split an internal node.</li>
- *   <li>{@code minSamplesLeaf}: The minimum number of samples required to be at a leaf node.</li>
- *   <li>{@code maxLeafNodes}: The maximum number of leaf nodes in the tree.</li>
- * </ul>
+ * <p>Hyperparameters include:
+ *
+ * <ul>
+ *   <li>{@code maxDepth}: The maximum depth of the tree.
+ *   <li>{@code minSamplesSplit}: The minimum number of samples required to split an internal node.
+ *   <li>{@code minSamplesLeaf}: The minimum number of samples required to be at a leaf node.
+ *   <li>{@code maxLeafNodes}: The maximum number of leaf nodes in the tree.
+ * </ul>
 *
- * <p>Usage example:
+ * <p>Usage example:
+ *
 * <pre>{@code
 * DecisionTree classifier = new DecisionTree(10, 2, 1, 50);
 * classifier.train(trainingFeatures, trainingLabels);
 * double accuracy = classifier.evaluate(testFeatures, testLabels);
 * }</pre>
 *
- * <p>Note: This class requires a thorough validation of input data and parameters, ensuring
- * they are never {@code null}, have appropriate dimensions, and adhere to any other
- * prerequisites or assumptions, to guarantee robustness and avoid runtime exceptions.
+ *
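Since split selection above is driven by Gini impurity, a worked example may help; the helper below mirrors the arithmetic of the private computeImpurity and computeGini methods later in this file, but is a standalone illustration rather than part of the patch:

public class GiniDemo {
  // Gini impurity of one partition: 1 minus the sum over classes of p(class) squared.
  static double impurity(int[] classCounts, int total) {
    double impurity = 1.0;
    for (int count : classCounts) {
      double p = (double) count / total;
      impurity -= p * p;
    }
    return impurity;
  }

  public static void main(String[] args) {
    // Candidate split over 10 samples and 2 classes: left gets 4/0, right gets 1/5.
    double left = impurity(new int[] {4, 0}, 4); // 0.0, a pure partition
    double right = impurity(new int[] {1, 5}, 6); // 1 - (1/36 + 25/36) = 10/36
    // Partitions are weighted by their share of the samples, as in computeGini.
    double weighted = (4.0 / 10) * left + (6.0 / 10) * right; // about 0.167
    System.out.printf("left=%.3f right=%.3f weighted=%.3f%n", left, right, weighted);
  }
}

The split search in findBestSplit simply tries every (feature, threshold) pair drawn from the training rows and keeps the split with the lowest weighted impurity.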

Note: This class requires a thorough validation of input data and parameters, ensuring they + * are never {@code null}, have appropriate dimensions, and adhere to any other prerequisites or + * assumptions, to guarantee robustness and avoid runtime exceptions. * * @see Classifier */ public class DecisionTree implements Classifier { - private static final Logger LOG = LoggerFactory.getLogger(DecisionTree.class); - private Node root; - private final int maxDepth; - private final int minSamplesSplit; - private final int minSamplesLeaf; - private final int maxLeafNodes; - private int currentLeafNodes; - - private final Map featureImportances; - - public DecisionTree(int maxDepth, - int minSamplesSplit, - int minSamplesLeaf, - int maxLeafNodes) { - this.maxDepth = maxDepth; - this.minSamplesSplit = minSamplesSplit; - this.minSamplesLeaf = minSamplesLeaf; - this.maxLeafNodes = maxLeafNodes; - this.currentLeafNodes = 0; - this.featureImportances = new HashMap<>(); - } - - @Override - public boolean train(double[][] features, double[][] labels) { - try { - if (features == null || labels == null || features.length == 0 || labels.length == 0 || features.length != labels.length) { - LOG.error("Invalid training data"); - return false; - } + private static final Logger LOG = LoggerFactory.getLogger(DecisionTree.class); + private final int maxDepth; + private final int minSamplesSplit; + private final int minSamplesLeaf; + private final int maxLeafNodes; + private final Map featureImportances; + private Node root; + private int currentLeafNodes; + + public DecisionTree(int maxDepth, int minSamplesSplit, int minSamplesLeaf, int maxLeafNodes) { + this.maxDepth = maxDepth; + this.minSamplesSplit = minSamplesSplit; + this.minSamplesLeaf = minSamplesLeaf; + this.maxLeafNodes = maxLeafNodes; + this.currentLeafNodes = 0; + this.featureImportances = new HashMap<>(); + } + + @Override + public boolean train(double[][] features, double[][] labels) { + try { + if (features == null + || labels == null + || features.length == 0 + || labels.length == 0 + || features.length != labels.length) { + LOG.error("Invalid training data"); + return false; + } - this.root = buildTree(features, labels, 0); + this.root = buildTree(features, labels, 0); - return true; - } catch (Exception e) { - LOG.error("An error occurred during training", e); - return false; - } + return true; + } catch (Exception e) { + LOG.error("An error occurred during training", e); + return false; } + } - private Node buildTree(double[][] features, double[][] labels, int depth) { - Node node = new Node(features); - node.predictedLabel = getMajorityLabel(labels); + private Node buildTree(double[][] features, double[][] labels, int depth) { + Node node = new Node(features); + node.predictedLabel = getMajorityLabel(labels); - if (shouldTerminate(features, depth)) { - currentLeafNodes++; - return node; - } - - SplitResult bestSplit = findBestSplit(features, labels); - if (bestSplit != null) { - applyBestSplit(node, bestSplit, features, labels, depth); - } else { - currentLeafNodes++; - } - - return node; + if (shouldTerminate(features, depth)) { + currentLeafNodes++; + return node; } - private boolean shouldTerminate(double[][] features, int depth) { - boolean maxDepthReached = depth >= maxDepth; - boolean tooFewSamples = features.length < minSamplesSplit; - boolean maxLeafNodesReached = currentLeafNodes >= maxLeafNodes; - - if (maxDepthReached || tooFewSamples || maxLeafNodesReached) { - return true; - } - return false; + SplitResult bestSplit = 
findBestSplit(features, labels); + if (bestSplit != null) { + applyBestSplit(node, bestSplit, features, labels, depth); + } else { + currentLeafNodes++; } - private SplitResult findBestSplit(double[][] features, double[][] labels) { - double bestGini = Double.MAX_VALUE; - SplitResult bestSplit = null; - - for (int featureIndex = 0; featureIndex < features[0].length; featureIndex++) { - for (double[] feature : features) { - double[][] leftFeatures = filterRows(features, featureIndex, feature[featureIndex], true); - double[][] rightFeatures = filterRows(features, featureIndex, feature[featureIndex], false); - - double[][] leftLabels = filterRows(labels, leftFeatures, features); - double[][] rightLabels = filterRows(labels, rightFeatures, features); + return node; + } - double gini = computeGini(leftLabels, rightLabels); - - if (gini < bestGini) { - bestGini = gini; - updateFeatureImportances(featureIndex, gini); - bestSplit = new SplitResult(featureIndex, feature[featureIndex], leftFeatures, rightFeatures, leftLabels, rightLabels); - } - } - } + private boolean shouldTerminate(double[][] features, int depth) { + boolean maxDepthReached = depth >= maxDepth; + boolean tooFewSamples = features.length < minSamplesSplit; + boolean maxLeafNodesReached = currentLeafNodes >= maxLeafNodes; - return bestSplit; + if (maxDepthReached || tooFewSamples || maxLeafNodesReached) { + return true; } - - private void applyBestSplit(Node node, SplitResult bestSplit, double[][] features, double[][] labels, int depth) { - node.splitFeatureIndex = bestSplit.featureIndex; - node.splitValue = bestSplit.splitValue; - - if (bestSplit.bestLeftFeatures != null && bestSplit.bestRightFeatures != null && - bestSplit.bestLeftFeatures.length >= minSamplesLeaf && bestSplit.bestRightFeatures.length >= minSamplesLeaf) { - - if (currentLeafNodes + 2 <= maxLeafNodes) { - node.left = buildTree(bestSplit.bestLeftFeatures, bestSplit.bestLeftLabels, depth + 1); - node.right = buildTree(bestSplit.bestRightFeatures, bestSplit.bestRightLabels, depth + 1); - currentLeafNodes += 2; - } else { - currentLeafNodes++; - } - } else { - currentLeafNodes++; + return false; + } + + private SplitResult findBestSplit(double[][] features, double[][] labels) { + double bestGini = Double.MAX_VALUE; + SplitResult bestSplit = null; + + for (int featureIndex = 0; featureIndex < features[0].length; featureIndex++) { + for (double[] feature : features) { + double[][] leftFeatures = filterRows(features, featureIndex, feature[featureIndex], true); + double[][] rightFeatures = filterRows(features, featureIndex, feature[featureIndex], false); + + double[][] leftLabels = filterRows(labels, leftFeatures, features); + double[][] rightLabels = filterRows(labels, rightFeatures, features); + + double gini = computeGini(leftLabels, rightLabels); + + if (gini < bestGini) { + bestGini = gini; + updateFeatureImportances(featureIndex, gini); + bestSplit = + new SplitResult( + featureIndex, + feature[featureIndex], + leftFeatures, + rightFeatures, + leftLabels, + rightLabels); } + } } - private void updateFeatureImportances(int featureIndex, double giniReduction) { - featureImportances.merge(featureIndex, giniReduction, Double::sum); - } - - public Map getFeatureImportances() { - double totalImportance = featureImportances.values().stream().mapToDouble(Double::doubleValue).sum(); - return featureImportances.entrySet().stream() - .collect(Collectors.toMap( - Map.Entry::getKey, - e -> e.getValue() / totalImportance)); + return bestSplit; + } + + private void applyBestSplit( 
+ Node node, SplitResult bestSplit, double[][] features, double[][] labels, int depth) { + node.splitFeatureIndex = bestSplit.featureIndex; + node.splitValue = bestSplit.splitValue; + + if (bestSplit.bestLeftFeatures != null + && bestSplit.bestRightFeatures != null + && bestSplit.bestLeftFeatures.length >= minSamplesLeaf + && bestSplit.bestRightFeatures.length >= minSamplesLeaf) { + + if (currentLeafNodes + 2 <= maxLeafNodes) { + node.left = buildTree(bestSplit.bestLeftFeatures, bestSplit.bestLeftLabels, depth + 1); + node.right = buildTree(bestSplit.bestRightFeatures, bestSplit.bestRightLabels, depth + 1); + currentLeafNodes += 2; + } else { + currentLeafNodes++; + } + } else { + currentLeafNodes++; } - - private double[][] filterRows(double[][] matrix, int featureIndex, double value, boolean lessThan) { - return Arrays.stream(matrix) - .filter(row -> (lessThan && row[featureIndex] < value) || (!lessThan && row[featureIndex] >= value)) - .toArray(double[][]::new); - } - - private double[][] filterRows(double[][] labels, double[][] filteredFeatures, double[][] originalFeatures) { - List filteredLabelsList = new ArrayList<>(); - for (double[] filteredFeature : filteredFeatures) { - for (int i = 0; i < originalFeatures.length; i++) { - if (Arrays.equals(filteredFeature, originalFeatures[i])) { - filteredLabelsList.add(labels[i]); - break; - } - } + } + + private void updateFeatureImportances(int featureIndex, double giniReduction) { + featureImportances.merge(featureIndex, giniReduction, Double::sum); + } + + public Map getFeatureImportances() { + double totalImportance = + featureImportances.values().stream().mapToDouble(Double::doubleValue).sum(); + return featureImportances.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue() / totalImportance)); + } + + private double[][] filterRows( + double[][] matrix, int featureIndex, double value, boolean lessThan) { + return Arrays.stream(matrix) + .filter( + row -> + (lessThan && row[featureIndex] < value) + || (!lessThan && row[featureIndex] >= value)) + .toArray(double[][]::new); + } + + private double[][] filterRows( + double[][] labels, double[][] filteredFeatures, double[][] originalFeatures) { + List filteredLabelsList = new ArrayList<>(); + for (double[] filteredFeature : filteredFeatures) { + for (int i = 0; i < originalFeatures.length; i++) { + if (Arrays.equals(filteredFeature, originalFeatures[i])) { + filteredLabelsList.add(labels[i]); + break; } - return filteredLabelsList.toArray(new double[0][0]); + } + } + return filteredLabelsList.toArray(new double[0][0]); + } + + private double[] getMajorityLabel(double[][] labels) { + return Arrays.stream(labels) + .map(Arrays::toString) // Convert double[] to String for grouping + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())) + .entrySet() + .stream() + .max(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .map( + str -> + Arrays.stream(str.substring(1, str.length() - 1).split(", ")) + .mapToDouble(Double::parseDouble) + .toArray()) // Convert String back to double[] + .orElseThrow(RuntimeException::new); + } + + @Override + public double evaluate(double[][] testInputs, double[][] testTargets) { + if (testInputs == null + || testTargets == null + || testInputs.length == 0 + || testTargets.length == 0 + || testInputs.length != testTargets.length) { + LOG.error("Invalid test data"); + return 0; } - private double[] getMajorityLabel(double[][] labels) { - return Arrays.stream(labels) - .map(Arrays::toString) // Convert 
double[] to String for grouping - .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())) - .entrySet().stream() - .max(Map.Entry.comparingByValue()) - .map(Map.Entry::getKey) - .map(str -> Arrays.stream(str.substring(1, str.length() - 1).split(", ")) - .mapToDouble(Double::parseDouble).toArray()) // Convert String back to double[] - .orElseThrow(RuntimeException::new); + long correctPredictions = 0; + for (int i = 0; i < testInputs.length; i++) { + double[] prediction = predict(testInputs[i]); + if (Arrays.equals(prediction, testTargets[i])) { + correctPredictions++; + } } + double accuracy = (double) correctPredictions / testInputs.length; + LOG.info(String.format("Decision Tree - accuracy: %.2f%%", accuracy * 100)); + return accuracy; + } - @Override - public double evaluate(double[][] testInputs, double[][] testTargets) { - if (testInputs == null || testTargets == null || testInputs.length == 0 || testTargets.length == 0 || testInputs.length != testTargets.length) { - LOG.error("Invalid test data"); - return 0; - } + @Override + public double[] predict(double[] feature) { + return predictRecursive(root, feature); + } - long correctPredictions = 0; - for (int i = 0; i < testInputs.length; i++) { - double[] prediction = predict(testInputs[i]); - if (Arrays.equals(prediction, testTargets[i])) { - correctPredictions++; - } - } - - double accuracy = (double) correctPredictions / testInputs.length; - LOG.info(String.format("Decision Tree - accuracy: %.2f%%", accuracy * 100)); - return accuracy; + private double[] predictRecursive(Node node, double[] feature) { + if (node == null || feature == null) { + throw new IllegalArgumentException("Node and feature cannot be null"); } - @Override - public double[] predict(double[] feature) { - return predictRecursive(root, feature); + if (node.left == null && node.right == null) { + return node.predictedLabel; } - private double[] predictRecursive(Node node, double[] feature) { - if (node == null || feature == null) { - throw new IllegalArgumentException("Node and feature cannot be null"); - } - - if (node.left == null && node.right == null) { - return node.predictedLabel; - } - - if (node.splitFeatureIndex >= feature.length) { - throw new IllegalArgumentException("splitFeatureIndex is out of bounds of feature array"); - } - - if (feature[node.splitFeatureIndex] < node.splitValue) { - if (node.left == null) { - throw new IllegalStateException("Left node is null when trying to traverse left"); - } - return predictRecursive(node.left, feature); - } else { - if (node.right == null) { - throw new IllegalStateException("Right node is null when trying to traverse right"); - } - return predictRecursive(node.right, feature); - } + if (node.splitFeatureIndex >= feature.length) { + throw new IllegalArgumentException("splitFeatureIndex is out of bounds of feature array"); } - private double computeGini(double[][] leftLabels, double[][] rightLabels) { - double leftImpurity = computeImpurity(leftLabels); - double rightImpurity = computeImpurity(rightLabels); - double leftWeight = ((double) leftLabels.length) / (leftLabels.length + rightLabels.length); - double rightWeight = ((double) rightLabels.length) / (leftLabels.length + rightLabels.length); - return leftWeight * leftImpurity + rightWeight * rightImpurity; + if (feature[node.splitFeatureIndex] < node.splitValue) { + if (node.left == null) { + throw new IllegalStateException("Left node is null when trying to traverse left"); + } + return predictRecursive(node.left, feature); + } else { + if 
(node.right == null) { + throw new IllegalStateException("Right node is null when trying to traverse right"); + } + return predictRecursive(node.right, feature); } - - private double computeImpurity(double[][] labels) { - double impurity = 1.0; - Map labelCounts = Arrays.stream(labels) - .map(Arrays::toString) - .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); - for (Long count : labelCounts.values()) { - double p = ((double) count) / labels.length; - impurity -= p * p; - } - return impurity; + } + + private double computeGini(double[][] leftLabels, double[][] rightLabels) { + double leftImpurity = computeImpurity(leftLabels); + double rightImpurity = computeImpurity(rightLabels); + double leftWeight = ((double) leftLabels.length) / (leftLabels.length + rightLabels.length); + double rightWeight = ((double) rightLabels.length) / (leftLabels.length + rightLabels.length); + return leftWeight * leftImpurity + rightWeight * rightImpurity; + } + + private double computeImpurity(double[][] labels) { + double impurity = 1.0; + Map labelCounts = + Arrays.stream(labels) + .map(Arrays::toString) + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); + for (Long count : labelCounts.values()) { + double p = ((double) count) / labels.length; + impurity -= p * p; } - - private static class SplitResult { - int featureIndex; - double splitValue; - double[][] bestLeftFeatures; - double[][] bestRightFeatures; - double[][] bestLeftLabels; - double[][] bestRightLabels; - - SplitResult(int featureIndex, double splitValue, double[][] bestLeftFeatures, double[][] bestRightFeatures, - double[][] bestLeftLabels, double[][] bestRightLabels) { - this.featureIndex = featureIndex; - this.splitValue = splitValue; - this.bestLeftFeatures = bestLeftFeatures; - this.bestRightFeatures = bestRightFeatures; - this.bestLeftLabels = bestLeftLabels; - this.bestRightLabels = bestRightLabels; - } + return impurity; + } + + private static class SplitResult { + int featureIndex; + double splitValue; + double[][] bestLeftFeatures; + double[][] bestRightFeatures; + double[][] bestLeftLabels; + double[][] bestRightLabels; + + SplitResult( + int featureIndex, + double splitValue, + double[][] bestLeftFeatures, + double[][] bestRightFeatures, + double[][] bestLeftLabels, + double[][] bestRightLabels) { + this.featureIndex = featureIndex; + this.splitValue = splitValue; + this.bestLeftFeatures = bestLeftFeatures; + this.bestRightFeatures = bestRightFeatures; + this.bestLeftLabels = bestLeftLabels; + this.bestRightLabels = bestRightLabels; } - private static class Node { - double[][] data; - Node left; - Node right; - int splitFeatureIndex; - double splitValue; - double[] predictedLabel; - - public Node(double[][] data) { - this.data = data; - } + } + + private static class Node { + double[][] data; + Node left; + Node right; + int splitFeatureIndex; + double splitValue; + double[] predictedLabel; + + public Node(double[][] data) { + this.data = data; } -} \ No newline at end of file + } +} diff --git a/lib/src/main/java/de/edux/ml/decisiontree/Node.java b/lib/src/main/java/de/edux/ml/decisiontree/Node.java index 2f2a636..6702e33 100644 --- a/lib/src/main/java/de/edux/ml/decisiontree/Node.java +++ b/lib/src/main/java/de/edux/ml/decisiontree/Node.java @@ -1,15 +1,15 @@ package de.edux.ml.decisiontree; class Node { - double[][] data; - double value; - int splitFeature; - Node left; - Node right; - boolean isLeaf; + double[][] data; + double value; + int splitFeature; + Node left; + Node 
right; + boolean isLeaf; - Node(double[][] data) { - this.data = data; - isLeaf = false; - } -} \ No newline at end of file + Node(double[][] data) { + this.data = data; + isLeaf = false; + } +} diff --git a/lib/src/main/java/de/edux/ml/decisiontree/package-info.java b/lib/src/main/java/de/edux/ml/decisiontree/package-info.java index 40d7258..997416c 100644 --- a/lib/src/main/java/de/edux/ml/decisiontree/package-info.java +++ b/lib/src/main/java/de/edux/ml/decisiontree/package-info.java @@ -1,4 +1,2 @@ -/** - * Decision tree implementation. - */ +/** Decision tree implementation. */ package de.edux.ml.decisiontree; diff --git a/lib/src/main/java/de/edux/ml/knn/KnnClassifier.java b/lib/src/main/java/de/edux/ml/knn/KnnClassifier.java index 56f60e6..cf34664 100644 --- a/lib/src/main/java/de/edux/ml/knn/KnnClassifier.java +++ b/lib/src/main/java/de/edux/ml/knn/KnnClassifier.java @@ -1,19 +1,20 @@ package de.edux.ml.knn; import de.edux.api.Classifier; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.Arrays; import java.util.PriorityQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * The {@code KnnClassifier} class provides an implementation of the k-Nearest Neighbors algorithm for classification tasks. - * It stores the training dataset and predicts the label for new data points based on the majority label of its k-nearest neighbors in the feature space. - * Distance between data points is computed using the Euclidean distance metric. - * Optionally, predictions can be weighted by the inverse of the distance to give closer neighbors higher influence. + * The {@code KnnClassifier} class provides an implementation of the k-Nearest Neighbors algorithm + * for classification tasks. It stores the training dataset and predicts the label for new data + * points based on the majority label of its k-nearest neighbors in the feature space. Distance + * between data points is computed using the Euclidean distance metric. Optionally, predictions can + * be weighted by the inverse of the distance to give closer neighbors higher influence. + * + *

<p>Example usage:
 *
- * <p>Example usage:</p>
 *
 * <pre>{@code
 * int k = 3;  // Specify the number of neighbors to consider
 * KnnClassifier knn = new KnnClassifier(k);
@@ -23,97 +24,97 @@
 * double accuracy = knn.evaluate(testFeatures, testLabels);
 * }</pre>
 *
- * <p>Note: The label arrays should be in one-hot encoding format.</p>
- *
- *
+ *
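The inverse-distance weighting described above is the heart of predict below: each of the k nearest neighbors votes with weight 1 / (distance + EPSILON), the votes are accumulated per class, and the sums are normalized. A compact sketch of that aggregation with made-up distances (names are illustrative, not part of the patch):

public class KnnWeightingDemo {
  public static void main(String[] args) {
    double epsilon = 1e-10; // mirrors EPSILON, guarding against division by zero

    // Three nearest neighbors (k = 3) with one-hot labels for two classes.
    double[] distances = {0.5, 1.0, 2.0};
    double[][] labels = {{1, 0}, {0, 1}, {0, 1}};

    double[] aggregated = new double[2];
    double totalWeight = 0;
    for (int n = 0; n < distances.length; n++) {
      double weight = 1 / (distances[n] + epsilon); // closer neighbors count for more
      for (int c = 0; c < aggregated.length; c++) {
        aggregated[c] += labels[n][c] * weight;
      }
      totalWeight += weight;
    }
    for (int c = 0; c < aggregated.length; c++) {
      aggregated[c] /= totalWeight; // normalize into a distribution over classes
    }
    // Weights 2.0, 1.0 and 0.5 give class 0 a score of 2.0/3.5 and class 1 a score of 1.5/3.5,
    // so the single closest neighbor narrowly outvotes the two farther ones combined.
    System.out.printf("class0=%.3f class1=%.3f%n", aggregated[0], aggregated[1]);
  }
}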

Note: The label arrays should be in one-hot encoding format. */ public class KnnClassifier implements Classifier { - Logger LOG = LoggerFactory.getLogger(KnnClassifier.class); - private double[][] trainFeatures; - private double[][] trainLabels; - private int k; - private static final double EPSILON = 1e-10; + private static final double EPSILON = 1e-10; + Logger LOG = LoggerFactory.getLogger(KnnClassifier.class); + private double[][] trainFeatures; + private double[][] trainLabels; + private int k; - /** - * Initializes a new instance of {@code KnnClassifier} with specified k. - * - * @param k an integer value representing the number of neighbors to consider during classification - * @throws IllegalArgumentException if k is not a positive integer - */ - public KnnClassifier(int k) { - if (k <= 0) { - throw new IllegalArgumentException("k must be a positive integer"); - } - this.k = k; + /** + * Initializes a new instance of {@code KnnClassifier} with specified k. + * + * @param k an integer value representing the number of neighbors to consider during + * classification + * @throws IllegalArgumentException if k is not a positive integer + */ + public KnnClassifier(int k) { + if (k <= 0) { + throw new IllegalArgumentException("k must be a positive integer"); } + this.k = k; + } - @Override - public boolean train(double[][] features, double[][] labels) { - if (features.length == 0 || features.length != labels.length) { - return false; - } - this.trainFeatures = features; - this.trainLabels = labels; - return true; + @Override + public boolean train(double[][] features, double[][] labels) { + if (features.length == 0 || features.length != labels.length) { + return false; } + this.trainFeatures = features; + this.trainLabels = labels; + return true; + } - @Override - public double evaluate(double[][] testInputs, double[][] testTargets) { - LOG.info("Evaluating..."); - int correct = 0; - for (int i = 0; i < testInputs.length; i++) { - double[] prediction = predict(testInputs[i]); - if (Arrays.equals(prediction, testTargets[i])) { - correct++; - } - } - double accuracy = (double) correct / testInputs.length; - LOG.info("KNN - Accuracy: " + accuracy * 100 + "%"); - return accuracy; + @Override + public double evaluate(double[][] testInputs, double[][] testTargets) { + LOG.info("Evaluating..."); + int correct = 0; + for (int i = 0; i < testInputs.length; i++) { + double[] prediction = predict(testInputs[i]); + if (Arrays.equals(prediction, testTargets[i])) { + correct++; + } } + double accuracy = (double) correct / testInputs.length; + LOG.info("KNN - Accuracy: " + accuracy * 100 + "%"); + return accuracy; + } - @Override - public double[] predict(double[] feature) { - PriorityQueue pq = new PriorityQueue<>((a, b) -> Double.compare(b.distance, a.distance)); - for (int i = 0; i < trainFeatures.length; i++) { - double distance = calculateDistance(trainFeatures[i], feature); - pq.offer(new Neighbor(distance, trainLabels[i])); - if (pq.size() > k) { - pq.poll(); - } - } + @Override + public double[] predict(double[] feature) { + PriorityQueue pq = + new PriorityQueue<>((a, b) -> Double.compare(b.distance, a.distance)); + for (int i = 0; i < trainFeatures.length; i++) { + double distance = calculateDistance(trainFeatures[i], feature); + pq.offer(new Neighbor(distance, trainLabels[i])); + if (pq.size() > k) { + pq.poll(); + } + } - double[] aggregatedLabel = new double[trainLabels[0].length]; - double totalWeight = 0; - for (Neighbor neighbor : pq) { - double weight = 1 / (neighbor.distance + EPSILON); 
- for (int i = 0; i < aggregatedLabel.length; i++) { - aggregatedLabel[i] += neighbor.label[i] * weight; - } - totalWeight += weight; - } + double[] aggregatedLabel = new double[trainLabels[0].length]; + double totalWeight = 0; + for (Neighbor neighbor : pq) { + double weight = 1 / (neighbor.distance + EPSILON); + for (int i = 0; i < aggregatedLabel.length; i++) { + aggregatedLabel[i] += neighbor.label[i] * weight; + } + totalWeight += weight; + } - for (int i = 0; i < aggregatedLabel.length; i++) { - aggregatedLabel[i] /= totalWeight; - } - return aggregatedLabel; + for (int i = 0; i < aggregatedLabel.length; i++) { + aggregatedLabel[i] /= totalWeight; } + return aggregatedLabel; + } - private double calculateDistance(double[] a, double[] b) { - double sum = 0; - for (int i = 0; i < a.length; i++) { - sum += Math.pow(a[i] - b[i], 2); - } - return Math.sqrt(sum); + private double calculateDistance(double[] a, double[] b) { + double sum = 0; + for (int i = 0; i < a.length; i++) { + sum += Math.pow(a[i] - b[i], 2); } + return Math.sqrt(sum); + } - private static class Neighbor { - private double distance; - private double[] label; + private static class Neighbor { + private double distance; + private double[] label; - public Neighbor(double distance, double[] label) { - this.distance = distance; - this.label = label; - } + public Neighbor(double distance, double[] label) { + this.distance = distance; + this.label = label; } + } } diff --git a/lib/src/main/java/de/edux/ml/nn/config/NetworkConfiguration.java b/lib/src/main/java/de/edux/ml/nn/config/NetworkConfiguration.java index ae5c05e..4864d01 100644 --- a/lib/src/main/java/de/edux/ml/nn/config/NetworkConfiguration.java +++ b/lib/src/main/java/de/edux/ml/nn/config/NetworkConfiguration.java @@ -3,10 +3,16 @@ import de.edux.functions.activation.ActivationFunction; import de.edux.functions.initialization.Initialization; import de.edux.functions.loss.LossFunction; - import java.util.List; -public record NetworkConfiguration(int inputSize, List hiddenLayersSize, int outputSize, double learningRate, int epochs, - ActivationFunction hiddenLayerActivationFunction, ActivationFunction outputLayerActivationFunction, LossFunction lossFunction, Initialization hiddenLayerWeightInitialization, Initialization outputLayerWeightInitialization) { - -} \ No newline at end of file +public record NetworkConfiguration( + int inputSize, + List hiddenLayersSize, + int outputSize, + double learningRate, + int epochs, + ActivationFunction hiddenLayerActivationFunction, + ActivationFunction outputLayerActivationFunction, + LossFunction lossFunction, + Initialization hiddenLayerWeightInitialization, + Initialization outputLayerWeightInitialization) {} diff --git a/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java b/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java index e51645f..3627514 100644 --- a/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java +++ b/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java @@ -3,30 +3,30 @@ import de.edux.api.Classifier; import de.edux.functions.activation.ActivationFunction; import de.edux.ml.nn.config.NetworkConfiguration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.ArrayList; import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * The {@code MultilayerPerceptron} class represents a simple feedforward neural network, - * which consists of input, hidden, and output layers. 
It implements the {@code Classifier}
- * interface, facilitating both the training and prediction processes on a given dataset.
+ * The {@code MultilayerPerceptron} class represents a simple feedforward neural network, which
+ * consists of input, hidden, and output layers. It implements the {@code Classifier} interface,
+ * facilitating both the training and prediction processes on a given dataset.
 *
- * <p>This implementation utilizes a backpropagation algorithm for training the neural network
- * to adjust weights and biases, considering a set configuration defined by {@link NetworkConfiguration}.
- * The network's architecture is multi-layered, comprising one or more hidden layers in addition
- * to the input and output layers. Neurons within these layers utilize activation functions defined
- * per layer through the configuration.</p>
+ * <p>This implementation utilizes a backpropagation algorithm for training the neural network to
+ * adjust weights and biases, considering a set configuration defined by {@link
+ * NetworkConfiguration}. The network's architecture is multi-layered, comprising one or more hidden
+ * layers in addition to the input and output layers. Neurons within these layers utilize activation
+ * functions defined per layer through the configuration.
 *
- * <p>The training process adjusts the weights and biases of neurons within the network based on
- * the error between predicted and expected outputs. Additionally, the implementation provides functionality
- * to save and restore the best model achieved during training based on accuracy. Early stopping is applied
- * during training to prevent overfitting and unnecessary computational expense by monitoring the performance
- * improvement across epochs.</p>
+ * <p>The training process adjusts the weights and biases of neurons within the network based on the
+ * error between predicted and expected outputs. Additionally, the implementation provides
+ * functionality to save and restore the best model achieved during training based on accuracy.
+ * Early stopping is applied during training to prevent overfitting and unnecessary computational
+ * expense by monitoring the performance improvement across epochs.
 *
- * <p>Usage example:</p>
+ * <p>Usage example:
+ *
 * <pre>
 *    NetworkConfiguration config = ... ;
 *    double[][] testFeatures = ... ;
@@ -39,7 +39,8 @@
 *    double[] prediction = mlp.predict(singleInput);
 * </pre>
 *
- * <p>Note: This implementation logs informative messages, such as accuracy per epoch, using SLF4J logging.</p>
+ *
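Because train evaluates against the held-out test set after every epoch to drive checkpointing and early stopping, the test data is supplied at construction time rather than to train itself. A hedged end-to-end sketch (array contents and layer sizes are placeholders; PATIENCE is fixed at 10 inside train):

import de.edux.functions.activation.ActivationFunction;
import de.edux.functions.initialization.Initialization;
import de.edux.functions.loss.LossFunction;
import de.edux.ml.nn.config.NetworkConfiguration;
import de.edux.ml.nn.network.MultilayerPerceptron;
import java.util.List;

public class MlpDemo {
  public static void main(String[] args) {
    // Placeholder arrays for a 4-feature, 3-class problem; substitute a real split.
    double[][] trainFeatures = new double[100][4];
    double[][] trainLabels = new double[100][3];
    double[][] testFeatures = new double[30][4];
    double[][] testLabels = new double[30][3];

    NetworkConfiguration config =
        new NetworkConfiguration(
            4, // inputSize
            List.of(32, 16), // hiddenLayersSize
            3, // outputSize
            0.01, // learningRate
            300, // epochs; early stopping may end training sooner
            ActivationFunction.LEAKY_RELU,
            ActivationFunction.SOFTMAX,
            LossFunction.CATEGORICAL_CROSS_ENTROPY,
            Initialization.XAVIER,
            Initialization.XAVIER);

    MultilayerPerceptron mlp = new MultilayerPerceptron(config, testFeatures, testLabels);
    mlp.train(trainFeatures, trainLabels);
    double accuracy = mlp.evaluate(testFeatures, testLabels);
    System.out.printf("accuracy=%.2f%%%n", accuracy * 100);
  }
}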

Note: This implementation logs informative messages, such as accuracy per epoch, using SLF4J + * logging. * * @see de.edux.api.Classifier * @see de.edux.ml.nn.network.Neuron @@ -47,203 +48,226 @@ * @see de.edux.functions.activation.ActivationFunction */ public class MultilayerPerceptron implements Classifier { - private static final Logger LOG = LoggerFactory.getLogger(MultilayerPerceptron.class); - - private final NetworkConfiguration config; - private final ActivationFunction hiddenLayerActivationFunction; - private final ActivationFunction outputLayerActivationFunction; - private List hiddenLayers; - private Neuron[] outputLayer; - private final double[][] testFeatures; - private final double[][] testLabels; - private double bestAccuracy; - private ArrayList bestHiddenLayers; - private Neuron[] bestOutputLayer; - - public MultilayerPerceptron(NetworkConfiguration config, double[][] testFeatures, double[][] testLabels) { - this.config = config; - this.testFeatures = testFeatures; - this.testLabels = testLabels; - - hiddenLayerActivationFunction = config.hiddenLayerActivationFunction(); - outputLayerActivationFunction = config.outputLayerActivationFunction(); - - hiddenLayers = new ArrayList<>(); - - int inputSizeForCurrentLayer = config.inputSize(); - for (int layerSize : config.hiddenLayersSize()) { - Neuron[] hiddenLayer = new Neuron[layerSize]; - for (int i = 0; i < layerSize; i++) { - hiddenLayer[i] = new Neuron(inputSizeForCurrentLayer, hiddenLayerActivationFunction, this.config.hiddenLayerWeightInitialization()); - } - hiddenLayers.add(hiddenLayer); - inputSizeForCurrentLayer = layerSize; - } + private static final Logger LOG = LoggerFactory.getLogger(MultilayerPerceptron.class); + + private final NetworkConfiguration config; + private final ActivationFunction hiddenLayerActivationFunction; + private final ActivationFunction outputLayerActivationFunction; + private final double[][] testFeatures; + private final double[][] testLabels; + private List hiddenLayers; + private Neuron[] outputLayer; + private double bestAccuracy; + private ArrayList bestHiddenLayers; + private Neuron[] bestOutputLayer; + + public MultilayerPerceptron( + NetworkConfiguration config, double[][] testFeatures, double[][] testLabels) { + this.config = config; + this.testFeatures = testFeatures; + this.testLabels = testLabels; + + hiddenLayerActivationFunction = config.hiddenLayerActivationFunction(); + outputLayerActivationFunction = config.outputLayerActivationFunction(); + + hiddenLayers = new ArrayList<>(); + + int inputSizeForCurrentLayer = config.inputSize(); + for (int layerSize : config.hiddenLayersSize()) { + Neuron[] hiddenLayer = new Neuron[layerSize]; + for (int i = 0; i < layerSize; i++) { + hiddenLayer[i] = + new Neuron( + inputSizeForCurrentLayer, + hiddenLayerActivationFunction, + this.config.hiddenLayerWeightInitialization()); + } + hiddenLayers.add(hiddenLayer); + inputSizeForCurrentLayer = layerSize; + } - outputLayer = new Neuron[config.outputSize()]; - for (int i = 0; i < config.outputSize(); i++) { - outputLayer[i] = new Neuron(inputSizeForCurrentLayer, outputLayerActivationFunction, this.config.outputLayerWeightInitialization()); - } + outputLayer = new Neuron[config.outputSize()]; + for (int i = 0; i < config.outputSize(); i++) { + outputLayer[i] = + new Neuron( + inputSizeForCurrentLayer, + outputLayerActivationFunction, + this.config.outputLayerWeightInitialization()); } + } - private double[] feedforward(double[] input) { + private double[] feedforward(double[] input) { - double[] 
currentInput = passInputTroughAllHiddenLayers(input); + double[] currentInput = passInputTroughAllHiddenLayers(input); - double[] output = passInputTroughOutputLayer(currentInput); + double[] output = passInputTroughOutputLayer(currentInput); - return outputLayerActivationFunction.calculateActivation(output); - } + return outputLayerActivationFunction.calculateActivation(output); + } - private double[] passInputTroughAllHiddenLayers(double[] input) { - double[] currentInput = input; - for (Neuron[] layer : hiddenLayers) { - double[] hiddenOutputs = new double[layer.length]; - for (int i = 0; i < layer.length; i++) { - hiddenOutputs[i] = layer[i].calculateOutput(currentInput); - } - currentInput = hiddenOutputs; - } - return currentInput; + private double[] passInputTroughAllHiddenLayers(double[] input) { + double[] currentInput = input; + for (Neuron[] layer : hiddenLayers) { + double[] hiddenOutputs = new double[layer.length]; + for (int i = 0; i < layer.length; i++) { + hiddenOutputs[i] = layer[i].calculateOutput(currentInput); + } + currentInput = hiddenOutputs; } + return currentInput; + } - private double[] passInputTroughOutputLayer(double[] currentInput) { - double[] output = new double[config.outputSize()]; - for (int i = 0; i < config.outputSize(); i++) { - output[i] = outputLayer[i].calculateOutput(currentInput); - } - return output; + private double[] passInputTroughOutputLayer(double[] currentInput) { + double[] output = new double[config.outputSize()]; + for (int i = 0; i < config.outputSize(); i++) { + output[i] = outputLayer[i].calculateOutput(currentInput); } + return output; + } - @Override - public boolean train(double[][] features, double[][] labels) { - bestAccuracy = 0; - int epochsWithoutImprovement = 0; - final int PATIENCE = 10; - - for (int epoch = 0; epoch < config.epochs(); epoch++) { - for (int i = 0; i < features.length; i++) { - double[] output = feedforward(features[i]); - - double[] output_error_signal = new double[config.outputSize()]; - for (int j = 0; j < config.outputSize(); j++) { - output_error_signal[j] = labels[i][j] - output[j]; - } - - List hidden_error_signals = new ArrayList<>(); - for (int j = hiddenLayers.size() - 1; j >= 0; j--) { - double[] hidden_error_signal = new double[hiddenLayers.get(j).length]; - for (int k = 0; k < hiddenLayers.get(j).length; k++) { - for (int l = 0; l < output_error_signal.length; l++) { - hidden_error_signal[k] += output_error_signal[l] * (j == hiddenLayers.size() - 1 ? 
outputLayer[l].getWeight(k) : hiddenLayers.get(j + 1)[l].getWeight(k)); - } - } - hidden_error_signals.add(0, hidden_error_signal); - output_error_signal = hidden_error_signal; - } - - - updateWeights(i, output_error_signal, hidden_error_signals, features); - } + @Override + public boolean train(double[][] features, double[][] labels) { + bestAccuracy = 0; + int epochsWithoutImprovement = 0; + final int PATIENCE = 10; - double accuracy = evaluate(testFeatures, testLabels); - LOG.info("Epoch: {} - Accuracy: {}%", epoch, String.format("%.2f", accuracy * 100)); + for (int epoch = 0; epoch < config.epochs(); epoch++) { + for (int i = 0; i < features.length; i++) { + double[] output = feedforward(features[i]); - if (accuracy > bestAccuracy) { - bestAccuracy = accuracy; - epochsWithoutImprovement = 0; - saveBestModel(hiddenLayers, outputLayer); - } else { - epochsWithoutImprovement++; - } - - if (epochsWithoutImprovement >= PATIENCE) { - LOG.info("Early stopping: Stopping training as the model has not improved in the last {} epochs.", PATIENCE); - loadBestModel(); - LOG.info("Best accuracy after restoring best MLP model: {}%", String.format("%.2f", bestAccuracy * 100)); - break; - } + double[] output_error_signal = new double[config.outputSize()]; + for (int j = 0; j < config.outputSize(); j++) { + output_error_signal[j] = labels[i][j] - output[j]; } - return true; - } - - private void loadBestModel() { - this.hiddenLayers = this.bestHiddenLayers; - this.outputLayer = this.bestOutputLayer; - } - private void saveBestModel(List hiddenLayers, Neuron[] outputLayer) { - this.bestHiddenLayers = new ArrayList<>(); - this.bestOutputLayer = new Neuron[outputLayer.length]; - for (int i = 0; i < hiddenLayers.size(); i++) { - Neuron[] layer = hiddenLayers.get(i); - Neuron[] newLayer = new Neuron[layer.length]; - for (int j = 0; j < layer.length; j++) { - newLayer[j] = new Neuron(layer[j].getWeights().length, layer[j].getActivationFunction(), layer[j].getInitialization()); - newLayer[j].setBias(layer[j].getBias()); - for (int k = 0; k < layer[j].getWeights().length; k++) { - newLayer[j].getWeights()[k] = layer[j].getWeight(k); - } - } - this.bestHiddenLayers.add(newLayer); - } - for (int i = 0; i < outputLayer.length; i++) { - this.bestOutputLayer[i] = new Neuron(outputLayer[i].getWeights().length, outputLayer[i].getActivationFunction(), outputLayer[i].getInitialization()); - this.bestOutputLayer[i].setBias(outputLayer[i].getBias()); - for (int j = 0; j < outputLayer[i].getWeights().length; j++) { - this.bestOutputLayer[i].getWeights()[j] = outputLayer[i].getWeight(j); + List hidden_error_signals = new ArrayList<>(); + for (int j = hiddenLayers.size() - 1; j >= 0; j--) { + double[] hidden_error_signal = new double[hiddenLayers.get(j).length]; + for (int k = 0; k < hiddenLayers.get(j).length; k++) { + for (int l = 0; l < output_error_signal.length; l++) { + hidden_error_signal[k] += + output_error_signal[l] + * (j == hiddenLayers.size() - 1 + ? 
outputLayer[l].getWeight(k) + : hiddenLayers.get(j + 1)[l].getWeight(k)); } + } + hidden_error_signals.add(0, hidden_error_signal); + output_error_signal = hidden_error_signal; } + updateWeights(i, output_error_signal, hidden_error_signals, features); + } + + double accuracy = evaluate(testFeatures, testLabels); + LOG.info("Epoch: {} - Accuracy: {}%", epoch, String.format("%.2f", accuracy * 100)); + + if (accuracy > bestAccuracy) { + bestAccuracy = accuracy; + epochsWithoutImprovement = 0; + saveBestModel(hiddenLayers, outputLayer); + } else { + epochsWithoutImprovement++; + } + + if (epochsWithoutImprovement >= PATIENCE) { + LOG.info( + "Early stopping: Stopping training as the model has not improved in the last {} epochs.", + PATIENCE); + loadBestModel(); + LOG.info( + "Best accuracy after restoring best MLP model: {}%", + String.format("%.2f", bestAccuracy * 100)); + break; + } } - - private void updateWeights(int i, double[] output_error_signal, List hidden_error_signals, double[][] features) { - double[] currentInput = features[i]; - - for (int j = 0; j < hiddenLayers.size(); j++) { - Neuron[] layer = hiddenLayers.get(j); - double[] errorSignal = hidden_error_signals.get(j); - for (int k = 0; k < layer.length; k++) { - layer[k].adjustBias(errorSignal[k], config.learningRate()); - layer[k].adjustWeights(currentInput, errorSignal[k], config.learningRate()); - } - currentInput = new double[layer.length]; - for (int k = 0; k < layer.length; k++) { - currentInput[k] = layer[k].calculateOutput(features[i]); - } - } - - for (int j = 0; j < config.outputSize(); j++) { - outputLayer[j].adjustBias(output_error_signal[j], config.learningRate()); - outputLayer[j].adjustWeights(currentInput, output_error_signal[j], config.learningRate()); + return true; + } + + private void loadBestModel() { + this.hiddenLayers = this.bestHiddenLayers; + this.outputLayer = this.bestOutputLayer; + } + + private void saveBestModel(List hiddenLayers, Neuron[] outputLayer) { + this.bestHiddenLayers = new ArrayList<>(); + this.bestOutputLayer = new Neuron[outputLayer.length]; + for (int i = 0; i < hiddenLayers.size(); i++) { + Neuron[] layer = hiddenLayers.get(i); + Neuron[] newLayer = new Neuron[layer.length]; + for (int j = 0; j < layer.length; j++) { + newLayer[j] = + new Neuron( + layer[j].getWeights().length, + layer[j].getActivationFunction(), + layer[j].getInitialization()); + newLayer[j].setBias(layer[j].getBias()); + for (int k = 0; k < layer[j].getWeights().length; k++) { + newLayer[j].getWeights()[k] = layer[j].getWeight(k); } + } + this.bestHiddenLayers.add(newLayer); + } + for (int i = 0; i < outputLayer.length; i++) { + this.bestOutputLayer[i] = + new Neuron( + outputLayer[i].getWeights().length, + outputLayer[i].getActivationFunction(), + outputLayer[i].getInitialization()); + this.bestOutputLayer[i].setBias(outputLayer[i].getBias()); + for (int j = 0; j < outputLayer[i].getWeights().length; j++) { + this.bestOutputLayer[i].getWeights()[j] = outputLayer[i].getWeight(j); + } + } + } + + private void updateWeights( + int i, + double[] output_error_signal, + List hidden_error_signals, + double[][] features) { + double[] currentInput = features[i]; + + for (int j = 0; j < hiddenLayers.size(); j++) { + Neuron[] layer = hiddenLayers.get(j); + double[] errorSignal = hidden_error_signals.get(j); + for (int k = 0; k < layer.length; k++) { + layer[k].adjustBias(errorSignal[k], config.learningRate()); + layer[k].adjustWeights(currentInput, errorSignal[k], config.learningRate()); + } + currentInput = new 
double[layer.length]; + for (int k = 0; k < layer.length; k++) { + currentInput[k] = layer[k].calculateOutput(features[i]); + } } - @Override - public double evaluate(double[][] testInputs, double[][] testTargets) { - int correctCount = 0; + for (int j = 0; j < config.outputSize(); j++) { + outputLayer[j].adjustBias(output_error_signal[j], config.learningRate()); + outputLayer[j].adjustWeights(currentInput, output_error_signal[j], config.learningRate()); + } + } - for (int i = 0; i < testInputs.length; i++) { - double[] predicted = predict(testInputs[i]); - int predictedIndex = 0; - int targetIndex = 0; + @Override + public double evaluate(double[][] testInputs, double[][] testTargets) { + int correctCount = 0; - for (int j = 0; j < predicted.length; j++) { - if (predicted[j] > predicted[predictedIndex]) - predictedIndex = j; - if (testTargets[i][j] > testTargets[i][targetIndex]) - targetIndex = j; - } + for (int i = 0; i < testInputs.length; i++) { + double[] predicted = predict(testInputs[i]); + int predictedIndex = 0; + int targetIndex = 0; - if (predictedIndex == targetIndex) - correctCount++; - } + for (int j = 0; j < predicted.length; j++) { + if (predicted[j] > predicted[predictedIndex]) predictedIndex = j; + if (testTargets[i][j] > testTargets[i][targetIndex]) targetIndex = j; + } - return (double) correctCount / testInputs.length; + if (predictedIndex == targetIndex) correctCount++; } - public double[] predict(double[] input) { - return feedforward(input); - } + return (double) correctCount / testInputs.length; + } + public double[] predict(double[] input) { + return feedforward(input); + } } diff --git a/lib/src/main/java/de/edux/ml/nn/network/Neuron.java b/lib/src/main/java/de/edux/ml/nn/network/Neuron.java index 05d8d72..6fdb4bc 100644 --- a/lib/src/main/java/de/edux/ml/nn/network/Neuron.java +++ b/lib/src/main/java/de/edux/ml/nn/network/Neuron.java @@ -4,63 +4,63 @@ import de.edux.functions.initialization.Initialization; class Neuron { - private final Initialization initialization; - private double[] weights; - private double bias; - private final ActivationFunction activationFunction; + private final Initialization initialization; + private final ActivationFunction activationFunction; + private double[] weights; + private double bias; - public Neuron(int inputSize, ActivationFunction activationFunction, Initialization initialization) { - this.weights = new double[inputSize]; - this.activationFunction = activationFunction; - this.initialization = initialization; - this.bias = initialization.weightInitialization(inputSize, new double[1])[0]; - this.weights = initialization.weightInitialization(inputSize, weights); + public Neuron( + int inputSize, ActivationFunction activationFunction, Initialization initialization) { + this.weights = new double[inputSize]; + this.activationFunction = activationFunction; + this.initialization = initialization; + this.bias = initialization.weightInitialization(inputSize, new double[1])[0]; + this.weights = initialization.weightInitialization(inputSize, weights); + } - } + public Initialization getInitialization() { + return initialization; + } - public Initialization getInitialization() { - return initialization; + public double calculateOutput(double[] input) { + double output = bias; + for (int i = 0; i < input.length; i++) { + output += input[i] * weights[i]; } + return activationFunction.calculateActivation(output); + } - public double calculateOutput(double[] input) { - double output = bias; - for (int i = 0; i < input.length; i++) { - 
output += input[i] * weights[i]; - } - return activationFunction.calculateActivation(output); + public void adjustWeights(double[] input, double error, double learningRate) { + for (int i = 0; i < weights.length; i++) { + weights[i] += learningRate * input[i] * error; } + } - public void adjustWeights(double[] input, double error, double learningRate) { - for (int i = 0; i < weights.length; i++) { - weights[i] += learningRate * input[i] * error; - } - } + public void adjustBias(double error, double learningRate) { + bias += learningRate * error; + } - public void adjustBias(double error, double learningRate) { - bias += learningRate * error; - } - - public double getWeight(int index) { - return weights[index]; - } + public double getWeight(int index) { + return weights[index]; + } - public double[] getWeights() { - return weights; - } + public double[] getWeights() { + return weights; + } - public double getBias() { - return bias; - } + public void setWeights(double[] weights) { + this.weights = weights; + } - public ActivationFunction getActivationFunction() { - return activationFunction; - } + public double getBias() { + return bias; + } - public void setWeights(double[] weights) { - this.weights = weights; - } + public void setBias(double bias) { + this.bias = bias; + } - public void setBias(double bias) { - this.bias = bias; - } + public ActivationFunction getActivationFunction() { + return activationFunction; + } } diff --git a/lib/src/main/java/de/edux/ml/nn/network/api/INeuron.java b/lib/src/main/java/de/edux/ml/nn/network/api/INeuron.java index 2ec2e39..88c63aa 100644 --- a/lib/src/main/java/de/edux/ml/nn/network/api/INeuron.java +++ b/lib/src/main/java/de/edux/ml/nn/network/api/INeuron.java @@ -1,10 +1,9 @@ package de.edux.ml.nn.network.api; public interface INeuron { - double calculateOutput(double[] inputs); + double calculateOutput(double[] inputs); - double calculateError(double targetOutput); + double calculateError(double targetOutput); - void updateWeights(double[] inputs, double error); - -} \ No newline at end of file + void updateWeights(double[] inputs, double error); +} diff --git a/lib/src/main/java/de/edux/ml/nn/network/api/IPerceptron.java b/lib/src/main/java/de/edux/ml/nn/network/api/IPerceptron.java index 2f474cb..486fcb1 100644 --- a/lib/src/main/java/de/edux/ml/nn/network/api/IPerceptron.java +++ b/lib/src/main/java/de/edux/ml/nn/network/api/IPerceptron.java @@ -1,11 +1,11 @@ package de.edux.ml.nn.network.api; public interface IPerceptron { - void train(double[][] inputs, double[][] targetOutputs); + void train(double[][] inputs, double[][] targetOutputs); - double[] predict(double[] inputs); + double[] predict(double[] inputs); - void backpropagate(double[] inputs, double target); + void backpropagate(double[] inputs, double target); - double evaluate(double[][] inputs, double[][] targetOutputs); + double evaluate(double[][] inputs, double[][] targetOutputs); } diff --git a/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java b/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java index e4c29b2..0480572 100644 --- a/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java +++ b/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java @@ -2,188 +2,190 @@ import de.edux.api.Classifier; import de.edux.ml.decisiontree.DecisionTree; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.*; +import org.slf4j.Logger; 
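The Neuron shown above applies the classic delta rule: each weight moves by learningRate * input[i] * error, and the bias by learningRate * error. A minimal standalone sketch of that update with hypothetical numbers (the class name is illustrative, not part of the library):

public class DeltaRuleSketch {
  public static void main(String[] args) {
    double[] weights = {0.1, -0.2}; // hypothetical starting weights
    double bias = 0.05;
    double[] input = {1.0, 0.5};
    double error = 0.3; // target minus prediction, as computed in MultilayerPerceptron.train
    double learningRate = 0.01;

    // the same update Neuron.adjustWeights and Neuron.adjustBias perform
    for (int i = 0; i < weights.length; i++) {
      weights[i] += learningRate * input[i] * error;
    }
    bias += learningRate * error;

    System.out.printf("weights=[%.4f, %.4f], bias=%.4f%n", weights[0], weights[1], bias);
  }
}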
+import org.slf4j.LoggerFactory;
 
 /**
- * RandomForest Classifier
- * RandomForest is an ensemble learning method, which constructs a multitude of decision trees
- * at training time and outputs the class that is the mode of the classes output by
- * individual trees, or a mean prediction of the individual trees (regression).
- * <p>
- * Note: Training and prediction are performed in a parallel manner using thread pooling.
+ * RandomForest Classifier RandomForest is an ensemble learning method, which constructs a multitude
+ * of decision trees at training time and outputs the class that is the mode of the classes output
+ * by individual trees, or a mean prediction of the individual trees (regression).
+ *
+ * <p>Note: Training and prediction are performed in a parallel manner using thread pooling.
 * RandomForest handles the training of individual decision trees and their predictions, and
- * determines the final prediction by voting (classification) or averaging (regression) the
- * outputs of all the decision trees in the forest. RandomForest is particularly well suited
- * for multiclass classification and regression on datasets with complex structures.
- * <p>
- * Usage example:
- * <p>
- * {@code
+ * determines the final prediction by voting (classification) or averaging (regression) the outputs
+ * of all the decision trees in the forest. RandomForest is particularly well suited for multiclass
+ * classification and regression on datasets with complex structures.
+ *
+ * <p>Usage example:
+ *
+ * <pre>{@code
  * RandomForest forest = new RandomForest();
  * forest.train(numTrees, features, labels, maxDepth, minSamplesSplit, minSamplesLeaf,
  *              maxLeafNodes, numberOfFeatures);
  * double prediction = forest.predict(sampleFeatures);
  * double accuracy = forest.evaluate(testFeatures, testLabels);
- * }
- * </pre>
- * <p>
- * Thread Safety: This class uses concurrent features but may not be entirely thread-safe
+ * }</pre>
+ *
+ * <p>Thread Safety: This class uses concurrent features but may not be entirely thread-safe
 * and should be used with caution in a multithreaded environment.
- * <p>
- * Use {@link #train(double[][], double[][])} to train the forest,
- * {@link #predict(double[])} to predict a single sample, and {@link #evaluate(double[][], double[][])}
- * to evaluate accuracy against a test set.
+ *
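To make the voting scheme described in the Javadoc above concrete: every tree returns a one-hot prediction, the index of the largest entry counts as that tree's vote, and the class with the most votes wins. A minimal sketch under those assumptions (hypothetical predictions and class name; the real class gathers predictions from futures):

import java.util.HashMap;
import java.util.Map;

public class MajorityVoteSketch {
  public static void main(String[] args) {
    // hypothetical one-hot predictions from three trees
    double[][] treePredictions = {
      {0.0, 1.0, 0.0},
      {0.0, 1.0, 0.0},
      {1.0, 0.0, 0.0}
    };

    Map<Integer, Long> votes = new HashMap<>();
    for (double[] prediction : treePredictions) {
      int label = 0;
      for (int i = 1; i < prediction.length; i++) {
        if (prediction[i] > prediction[label]) label = i; // index of the highest value
      }
      votes.merge(label, 1L, Long::sum); // same idea as voteMap in RandomForest.predict
    }

    int winner = votes.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey();
    System.out.println("Majority vote: class " + winner); // prints "class 1"
  }
}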
Use {@link #train(double[][], double[][])} to train the forest, {@link #predict(double[])} to + * predict a single sample, and {@link #evaluate(double[][], double[][])} to evaluate accuracy + * against a test set. */ public class RandomForest implements Classifier { - private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class); - - private final List trees = new ArrayList<>(); - private final ThreadLocalRandom threadLocalRandom = ThreadLocalRandom.current(); - private final int numTrees; - private final int maxDepth; - private final int minSamplesSplit; - private final int minSamplesLeaf; - private final int maxLeafNodes; - private final int numberOfFeatures; - - public RandomForest(int numTrees, int maxDepth, - int minSamplesSplit, - int minSamplesLeaf, - int maxLeafNodes, - int numberOfFeatures) { - this.numTrees = numTrees; - this.maxDepth = maxDepth; - this.minSamplesSplit = minSamplesSplit; - this.minSamplesLeaf = minSamplesLeaf; - this.maxLeafNodes = maxLeafNodes; - this.numberOfFeatures = numberOfFeatures; - } - - public boolean train(double[][] features, double[][] labels) { - ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); - - List> futures = new ArrayList<>(); - - for (int i = 0; i < numTrees; i++) { - futures.add(executor.submit(() -> { - Classifier tree = new DecisionTree(maxDepth, minSamplesSplit, minSamplesLeaf, maxLeafNodes); + private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class); + + private final List trees = new ArrayList<>(); + private final ThreadLocalRandom threadLocalRandom = ThreadLocalRandom.current(); + private final int numTrees; + private final int maxDepth; + private final int minSamplesSplit; + private final int minSamplesLeaf; + private final int maxLeafNodes; + private final int numberOfFeatures; + + public RandomForest( + int numTrees, + int maxDepth, + int minSamplesSplit, + int minSamplesLeaf, + int maxLeafNodes, + int numberOfFeatures) { + this.numTrees = numTrees; + this.maxDepth = maxDepth; + this.minSamplesSplit = minSamplesSplit; + this.minSamplesLeaf = minSamplesLeaf; + this.maxLeafNodes = maxLeafNodes; + this.numberOfFeatures = numberOfFeatures; + } + + public boolean train(double[][] features, double[][] labels) { + ExecutorService executor = + Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); + + List> futures = new ArrayList<>(); + + for (int i = 0; i < numTrees; i++) { + futures.add( + executor.submit( + () -> { + Classifier tree = + new DecisionTree(maxDepth, minSamplesSplit, minSamplesLeaf, maxLeafNodes); Sample subsetSample = getRandomSubset(numberOfFeatures, features, labels); tree.train(subsetSample.featureSamples(), subsetSample.labelSamples()); return tree; - })); - } - - for (Future future : futures) { - try { - trees.add(future.get()); - } catch (ExecutionException | InterruptedException e) { - LOG.error("Failed to train a decision tree. 
Thread: " + - Thread.currentThread().getName(), e); - } - } - executor.shutdown(); - try { - if (!executor.awaitTermination(60, TimeUnit.SECONDS)) { - executor.shutdownNow(); - } - } catch (InterruptedException ex) { - executor.shutdownNow(); - Thread.currentThread().interrupt(); - } - return true; + })); } - private Sample getRandomSubset(int numberOfFeatures, double[][] features, double[][] labels) { - if (numberOfFeatures > features.length) { - throw new IllegalArgumentException("Number of feature must be between 1 and amount of features"); - } - double[][] subFeatures = new double[numberOfFeatures][]; - double[][] subLabels = new double[numberOfFeatures][]; - for (int i = 0; i < numberOfFeatures; i++) { - int randomIndex = threadLocalRandom.nextInt(numberOfFeatures); - subFeatures[i] = features[randomIndex]; - subLabels[i] = labels[randomIndex]; - } - - return new Sample(subFeatures, subLabels); + for (Future future : futures) { + try { + trees.add(future.get()); + } catch (ExecutionException | InterruptedException e) { + LOG.error( + "Failed to train a decision tree. Thread: " + Thread.currentThread().getName(), e); + } } - - - @Override - public double[] predict(double[] feature) { - ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); - List> futures = new ArrayList<>(); - - for (Classifier tree : trees) { - futures.add(executor.submit(() -> tree.predict(feature))); - } - - Map voteMap = new HashMap<>(); - for (Future future : futures) { - try { - double[] prediction = future.get(); - double label = getIndexOfHighestValue(prediction); - voteMap.merge(label, 1L, Long::sum); - } catch (InterruptedException | ExecutionException e) { - LOG.error("Failed to retrieve prediction from future task. Thread: " + - Thread.currentThread().getName(), e); - } - } - - executor.shutdown(); - try { - if (!executor.awaitTermination(60, TimeUnit.SECONDS)) { - executor.shutdownNow(); - } - } catch (InterruptedException ex) { - executor.shutdownNow(); - Thread.currentThread().interrupt(); - } - double predictionLabel = voteMap.entrySet().stream() - .max(Map.Entry.comparingByValue()) - .get() - .getKey(); - - double[] prediction = new double[trees.get(0).predict(feature).length]; - prediction[(int) predictionLabel] = 1; - return prediction; + executor.shutdown(); + try { + if (!executor.awaitTermination(60, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException ex) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); } + return true; + } - @Override - public double evaluate(double[][] features, double[][] labels) { - int correctPredictions = 0; - for (int i = 0; i < features.length; i++) { - double[] predictedLabelProbabilities = predict(features[i]); - double predictedLabel = getIndexOfHighestValue(predictedLabelProbabilities); - double actualLabel = getIndexOfHighestValue(labels[i]); - if (predictedLabel == actualLabel) { - correctPredictions++; - } - } - double accuracy = (double) correctPredictions / features.length; - LOG.info("RandomForest - Accuracy: " + String.format("%.4f", accuracy * 100) + "%"); - return accuracy; + private Sample getRandomSubset(int numberOfFeatures, double[][] features, double[][] labels) { + if (numberOfFeatures > features.length) { + throw new IllegalArgumentException( + "Number of feature must be between 1 and amount of features"); } + double[][] subFeatures = new double[numberOfFeatures][]; + double[][] subLabels = new double[numberOfFeatures][]; + for (int i = 0; i < 
numberOfFeatures; i++) { + int randomIndex = threadLocalRandom.nextInt(numberOfFeatures); + subFeatures[i] = features[randomIndex]; + subLabels[i] = labels[randomIndex]; + } + + return new Sample(subFeatures, subLabels); + } + @Override + public double[] predict(double[] feature) { + ExecutorService executor = + Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); + List> futures = new ArrayList<>(); - private double getIndexOfHighestValue(double[] labels) { - int maxIndex = 0; - double maxValue = labels[0]; + for (Classifier tree : trees) { + futures.add(executor.submit(() -> tree.predict(feature))); + } - for (int i = 1; i < labels.length; i++) { - if (labels[i] > maxValue) { - maxValue = labels[i]; - maxIndex = i; - } - } - return maxIndex; + Map voteMap = new HashMap<>(); + for (Future future : futures) { + try { + double[] prediction = future.get(); + double label = getIndexOfHighestValue(prediction); + voteMap.merge(label, 1L, Long::sum); + } catch (InterruptedException | ExecutionException e) { + LOG.error( + "Failed to retrieve prediction from future task. Thread: " + + Thread.currentThread().getName(), + e); + } } + executor.shutdown(); + try { + if (!executor.awaitTermination(60, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException ex) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + double predictionLabel = + voteMap.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey(); + + double[] prediction = new double[trees.get(0).predict(feature).length]; + prediction[(int) predictionLabel] = 1; + return prediction; + } + + @Override + public double evaluate(double[][] features, double[][] labels) { + int correctPredictions = 0; + for (int i = 0; i < features.length; i++) { + double[] predictedLabelProbabilities = predict(features[i]); + double predictedLabel = getIndexOfHighestValue(predictedLabelProbabilities); + double actualLabel = getIndexOfHighestValue(labels[i]); + if (predictedLabel == actualLabel) { + correctPredictions++; + } + } + double accuracy = (double) correctPredictions / features.length; + LOG.info("RandomForest - Accuracy: " + String.format("%.4f", accuracy * 100) + "%"); + return accuracy; + } + + private double getIndexOfHighestValue(double[] labels) { + int maxIndex = 0; + double maxValue = labels[0]; + + for (int i = 1; i < labels.length; i++) { + if (labels[i] > maxValue) { + maxValue = labels[i]; + maxIndex = i; + } + } + return maxIndex; + } } diff --git a/lib/src/main/java/de/edux/ml/randomforest/package-info.java b/lib/src/main/java/de/edux/ml/randomforest/package-info.java index 4ded934..9ad688e 100644 --- a/lib/src/main/java/de/edux/ml/randomforest/package-info.java +++ b/lib/src/main/java/de/edux/ml/randomforest/package-info.java @@ -1,4 +1,2 @@ -/** - * Random Forest implementation. - */ +/** Random Forest implementation. 
*/ package de.edux.ml.randomforest; diff --git a/lib/src/main/java/de/edux/ml/svm/ISupportVectorMachine.java b/lib/src/main/java/de/edux/ml/svm/ISupportVectorMachine.java index ef4b1da..7485a27 100644 --- a/lib/src/main/java/de/edux/ml/svm/ISupportVectorMachine.java +++ b/lib/src/main/java/de/edux/ml/svm/ISupportVectorMachine.java @@ -2,9 +2,9 @@ public interface ISupportVectorMachine { - void train(double[][] features, int[] labels); + void train(double[][] features, int[] labels); - int predict(double[] features); + int predict(double[] features); - double evaluate(double[][] features, int[] labels); + double evaluate(double[][] features, int[] labels); } diff --git a/lib/src/main/java/de/edux/ml/svm/SVMKernel.java b/lib/src/main/java/de/edux/ml/svm/SVMKernel.java index bf6fbd1..59a0b3a 100644 --- a/lib/src/main/java/de/edux/ml/svm/SVMKernel.java +++ b/lib/src/main/java/de/edux/ml/svm/SVMKernel.java @@ -1,5 +1,5 @@ package de.edux.ml.svm; public enum SVMKernel { - LINEAR; + LINEAR; } diff --git a/lib/src/main/java/de/edux/ml/svm/SVMModel.java b/lib/src/main/java/de/edux/ml/svm/SVMModel.java index 32ad670..f586a81 100644 --- a/lib/src/main/java/de/edux/ml/svm/SVMModel.java +++ b/lib/src/main/java/de/edux/ml/svm/SVMModel.java @@ -2,46 +2,46 @@ public class SVMModel { - private final SVMKernel kernel; - private double c; - private double[] weights; - private double bias = 0.0; + private final SVMKernel kernel; + private double c; + private double[] weights; + private double bias = 0.0; - public SVMModel(SVMKernel kernel, double c) { - this.kernel = kernel; - this.c = c; - } + public SVMModel(SVMKernel kernel, double c) { + this.kernel = kernel; + this.c = c; + } - public void train(double[][] features, int[] labels) { - int n = features[0].length; - weights = new double[n]; - int iterations = 10000; + public void train(double[][] features, int[] labels) { + int n = features[0].length; + weights = new double[n]; + int iterations = 10000; - for (int iter = 0; iter < iterations; iter++) { - for (int i = 0; i < features.length; i++) { - double[] xi = features[i]; - int target = labels[i]; - double prediction = predict(xi); + for (int iter = 0; iter < iterations; iter++) { + for (int i = 0; i < features.length; i++) { + double[] xi = features[i]; + int target = labels[i]; + double prediction = predict(xi); - if (target * prediction < 1) { - for (int j = 0; j < n; j++) { - weights[j] = weights[j] + c * (target * xi[j] - 2 * (1/iterations) * weights[j]); - } - bias += c * target; - } else { - for (int j = 0; j < n; j++) { - weights[j] = weights[j] - c * 2 * (1/iterations) * weights[j]; - } - } - } + if (target * prediction < 1) { + for (int j = 0; j < n; j++) { + weights[j] = weights[j] + c * (target * xi[j] - 2 * (1 / iterations) * weights[j]); + } + bias += c * target; + } else { + for (int j = 0; j < n; j++) { + weights[j] = weights[j] - c * 2 * (1 / iterations) * weights[j]; + } } + } } + } - public int predict(double[] features) { - double result = bias; - for (int i = 0; i < weights.length; i++) { - result += weights[i] * features[i]; - } - return (result >= 1) ? 1 : -1; + public int predict(double[] features) { + double result = bias; + for (int i = 0; i < weights.length; i++) { + result += weights[i] * features[i]; } + return (result >= 1) ? 
1 : -1; + } } diff --git a/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java b/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java index 3fb5ae3..fa25beb 100644 --- a/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java +++ b/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java @@ -1,18 +1,21 @@ package de.edux.ml.svm; import de.edux.api.Classifier; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.*; import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * The {@code SupportVectorMachine} class is an implementation of a Support Vector Machine (SVM) classifier, utilizing the one-vs-one strategy for multi-class classification. - * This SVM implementation accepts a kernel function and trains separate binary classifiers for each pair of classes in the training set, using provided kernel function and regularization parameter C. - * During the prediction, each model in the pair casts a vote and the final predicted class is the one that gets the most votes among all binary classifiers. + * The {@code SupportVectorMachine} class is an implementation of a Support Vector Machine (SVM) + * classifier, utilizing the one-vs-one strategy for multi-class classification. This SVM + * implementation accepts a kernel function and trains separate binary classifiers for each pair of + * classes in the training set, using provided kernel function and regularization parameter C. + * During the prediction, each model in the pair casts a vote and the final predicted class is the + * one that gets the most votes among all binary classifiers. + * + *
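One substantive detail the reformat of SVMModel above preserves unchanged: the weight-decay factor (1 / iterations) is integer division, which evaluates to 0 for any iterations greater than 1, so that term is a no-op. A two-line demonstration (assuming the intent was a small floating-point decay):

public class IntegerDivisionNote {
  public static void main(String[] args) {
    int iterations = 10000; // the constant used in SVMModel.train
    System.out.println(1 / iterations);   // 0      -> the decay term vanishes
    System.out.println(1.0 / iterations); // 1.0E-4 -> a small decay, likely the intent
  }
}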

<p>Example usage:
 *
- * <p>Example usage:
- * <pre>{@code
+ *
+ * <pre>{@code
  * SVMKernel kernel = ... ;  // Define an appropriate SVM kernel function
  * double c = ... ;  // Define an appropriate regularization parameter
@@ -24,113 +27,121 @@
  * double accuracy = svm.evaluate(testFeatures, testLabels);
  * }</pre>
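To illustrate the one-vs-one voting behind the usage example above: one binary model is trained per pair of classes, and at prediction time each pair model casts a vote for one of its two classes. A minimal, self-contained sketch (hypothetical pair decisions and class name; the real class queries trained SVMModel instances):

import java.util.HashMap;
import java.util.Map;

public class OneVsOneVoteSketch {
  public static void main(String[] args) {
    // hypothetical decisions of the three pair models for classes {1, 2, 3}:
    // value +1 means the model picked the first class of the pair, -1 the second
    Map<String, Integer> pairDecisions = Map.of("1-2", 1, "1-3", 1, "2-3", -1);

    Map<Integer, Integer> votes = new HashMap<>();
    for (Map.Entry<String, Integer> entry : pairDecisions.entrySet()) {
      String[] classes = entry.getKey().split("-");
      int vote = entry.getValue() == 1 ? Integer.parseInt(classes[0]) : Integer.parseInt(classes[1]);
      votes.merge(vote, 1, Integer::sum);
    }

    int prediction = votes.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey();
    System.out.println("Predicted class: " + prediction); // class 1 wins with two votes
  }
}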
 *
- * <p>Note: Label arrays are expected to be in one-hot encoding format and will be internally converted to single label format for training.</p>
+ *
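The Note above refers to the same index-plus-one flattening that convert2DLabelArrayTo1DLabelArray performs further down in this class. A tiny sketch of that conversion (hypothetical rows and class name):

import java.util.Arrays;

public class OneHotConversionSketch {
  public static void main(String[] args) {
    double[][] oneHot = {{1, 0, 0}, {0, 0, 1}, {0, 1, 0}}; // hypothetical one-hot rows
    int[] labels = new int[oneHot.length];
    for (int i = 0; i < oneHot.length; i++) {
      for (int j = 0; j < oneHot[i].length; j++) {
        if (oneHot[i][j] == 1) {
          labels[i] = j + 1; // same j + 1 convention as convert2DLabelArrayTo1DLabelArray
        }
      }
    }
    System.out.println(Arrays.toString(labels)); // [1, 3, 2]
  }
}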
Note: Label arrays are expected to be in one-hot encoding format and will be internally + * converted to single label format for training. * * @see de.edux.api.Classifier */ public class SupportVectorMachine implements Classifier { - private static final Logger LOG = LoggerFactory.getLogger(SupportVectorMachine.class); - private final SVMKernel kernel; - private final double c; - private final Map models; - - /** - * Constructs a new instance of SupportVectorMachine with a specified kernel and regularization parameter. - * - * This constructor initializes a new Support Vector Machine (SVM) for classification tasks. The SVM employs a one-vs-one strategy - * for multi-class classification. Each model pair within the SVM is trained using the provided kernel function and - * the regularization parameter C. - * - * The kernel is crucial for handling non-linearly separable data by defining a new space in which data points are projected. The - * correct choice of a kernel significantly impacts the performance of the SVM. The regularization parameter C controls the trade-off - * between achieving a low training error and a low testing error that is the ability of the SVM to generalize to unseen data. - * - * @param kernel The kernel to be used for the transformation of the input space. This is necessary for achieving an optimal - * separation in a higher-dimensional space when data is not linearly separable in the original space. The kernel - * defines how data points in space are interpreted based on their similarity. - * @param c The regularization parameter that controls the trade-off between allowing training errors and enforcing rigid margins. - * It helps to prevent overfitting by controlling the strength of the penalty for errors. A higher value of C tries to - * minimize the classification error, potentially at the expense of simplicity, while a lower value of C prioritizes - * simplicity, potentially allowing some misclassifications. - */ - public SupportVectorMachine(SVMKernel kernel, double c) { - this.kernel = kernel; - this.c = c; - this.models = new HashMap<>(); - } - - @Override - public boolean train(double[][] features, double[][] labels) { - var oneDLabels = convert2DLabelArrayTo1DLabelArray(labels); - Set uniqueLabels = Arrays.stream(oneDLabels).boxed().collect(Collectors.toSet()); - Integer[] uniqueLabelsArray = uniqueLabels.toArray(new Integer[0]); - - for (int i = 0; i < uniqueLabelsArray.length; i++) { - for (int j = i + 1; j < uniqueLabelsArray.length; j++) { - String key = uniqueLabelsArray[i] + "-" + uniqueLabelsArray[j]; - SVMModel model = new SVMModel(kernel, c); - - List list = new ArrayList<>(); - List pairLabelsList = new ArrayList<>(); - for (int k = 0; k < features.length; k++) { - if (oneDLabels[k] == uniqueLabelsArray[i] || oneDLabels[k] == uniqueLabelsArray[j]) { - list.add(features[k]); - pairLabelsList.add(oneDLabels[k] == uniqueLabelsArray[i] ? 1 : -1); - } - } - double[][] pairFeatures = list.toArray(new double[0][]); - int[] pairLabels = pairLabelsList.stream().mapToInt(Integer::intValue).toArray(); - - model.train(pairFeatures, pairLabels); - models.put(key, model); - } + private static final Logger LOG = LoggerFactory.getLogger(SupportVectorMachine.class); + private final SVMKernel kernel; + private final double c; + private final Map models; + + /** + * Constructs a new instance of SupportVectorMachine with a specified kernel and regularization + * parameter. + * + *
This constructor initializes a new Support Vector Machine (SVM) for classification tasks. + * The SVM employs a one-vs-one strategy for multi-class classification. Each model pair within + * the SVM is trained using the provided kernel function and the regularization parameter C. + * + *
The kernel is crucial for handling non-linearly separable data by defining a new space in + * which data points are projected. The correct choice of a kernel significantly impacts the + * performance of the SVM. The regularization parameter C controls the trade-off between achieving + * a low training error and a low testing error that is the ability of the SVM to generalize to + * unseen data. + * + * @param kernel The kernel to be used for the transformation of the input space. This is + * necessary for achieving an optimal separation in a higher-dimensional space when data is + * not linearly separable in the original space. The kernel defines how data points in space + * are interpreted based on their similarity. + * @param c The regularization parameter that controls the trade-off between allowing training + * errors and enforcing rigid margins. It helps to prevent overfitting by controlling the + * strength of the penalty for errors. A higher value of C tries to minimize the + * classification error, potentially at the expense of simplicity, while a lower value of C + * prioritizes simplicity, potentially allowing some misclassifications. + */ + public SupportVectorMachine(SVMKernel kernel, double c) { + this.kernel = kernel; + this.c = c; + this.models = new HashMap<>(); + } + + @Override + public boolean train(double[][] features, double[][] labels) { + var oneDLabels = convert2DLabelArrayTo1DLabelArray(labels); + Set uniqueLabels = Arrays.stream(oneDLabels).boxed().collect(Collectors.toSet()); + Integer[] uniqueLabelsArray = uniqueLabels.toArray(new Integer[0]); + + for (int i = 0; i < uniqueLabelsArray.length; i++) { + for (int j = i + 1; j < uniqueLabelsArray.length; j++) { + String key = uniqueLabelsArray[i] + "-" + uniqueLabelsArray[j]; + SVMModel model = new SVMModel(kernel, c); + + List list = new ArrayList<>(); + List pairLabelsList = new ArrayList<>(); + for (int k = 0; k < features.length; k++) { + if (oneDLabels[k] == uniqueLabelsArray[i] || oneDLabels[k] == uniqueLabelsArray[j]) { + list.add(features[k]); + pairLabelsList.add(oneDLabels[k] == uniqueLabelsArray[i] ? 1 : -1); + } } - return true; + double[][] pairFeatures = list.toArray(new double[0][]); + int[] pairLabels = pairLabelsList.stream().mapToInt(Integer::intValue).toArray(); + + model.train(pairFeatures, pairLabels); + models.put(key, model); + } } + return true; + } - @Override - public double[] predict(double[] features) { - Map voteCount = new HashMap<>(); + @Override + public double[] predict(double[] features) { + Map voteCount = new HashMap<>(); - for (Map.Entry entry : models.entrySet()) { - int prediction = entry.getValue().predict(features); + for (Map.Entry entry : models.entrySet()) { + int prediction = entry.getValue().predict(features); - String[] classes = entry.getKey().split("-"); - int classLabel = (prediction == 1) ? Integer.parseInt(classes[0]) : Integer.parseInt(classes[1]); + String[] classes = entry.getKey().split("-"); + int classLabel = + (prediction == 1) ? 
Integer.parseInt(classes[0]) : Integer.parseInt(classes[1]); - voteCount.put(classLabel, voteCount.getOrDefault(classLabel, 0) + 1); - } + voteCount.put(classLabel, voteCount.getOrDefault(classLabel, 0) + 1); + } - int prediction = voteCount.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey(); - double[] result = new double[models.size()]; - result[prediction - 1] = 1; - return result; + int prediction = voteCount.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey(); + double[] result = new double[models.size()]; + result[prediction - 1] = 1; + return result; + } + + @Override + public double evaluate(double[][] features, double[][] labels) { + int correct = 0; + for (int i = 0; i < features.length; i++) { + boolean match = Arrays.equals(predict(features[i]), labels[i]); + if (match) { + correct++; + } } - @Override - public double evaluate(double[][] features, double[][] labels) { - int correct = 0; - for (int i = 0; i < features.length; i++) { - boolean match = Arrays.equals(predict(features[i]), labels[i]); - if (match) { - correct++; - } - } - double accuracy = (double) correct / features.length; + double accuracy = (double) correct / features.length; - LOG.info("SVM - Accuracy: " + String.format("%.4f", accuracy * 100) + "%"); - return accuracy; - } - private int[] convert2DLabelArrayTo1DLabelArray(double[][] labels) { - int[] decisionTreeTrainLabels = new int[labels.length]; - for (int i = 0; i < labels.length; i++) { - for (int j = 0; j < labels[i].length; j++) { - if (labels[i][j] == 1) { - decisionTreeTrainLabels[i] = (j+1); - } - } + LOG.info("SVM - Accuracy: " + String.format("%.4f", accuracy * 100) + "%"); + return accuracy; + } + + private int[] convert2DLabelArrayTo1DLabelArray(double[][] labels) { + int[] decisionTreeTrainLabels = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + for (int j = 0; j < labels[i].length; j++) { + if (labels[i][j] == 1) { + decisionTreeTrainLabels[i] = (j + 1); } - return decisionTreeTrainLabels; + } } + return decisionTreeTrainLabels; + } } - diff --git a/lib/src/main/java/de/edux/ml/svm/package-info.java b/lib/src/main/java/de/edux/ml/svm/package-info.java index 73d779b..564a770 100644 --- a/lib/src/main/java/de/edux/ml/svm/package-info.java +++ b/lib/src/main/java/de/edux/ml/svm/package-info.java @@ -1,4 +1,2 @@ -/** - * Support Vector Machine (SVM) implementation. - */ +/** Support Vector Machine (SVM) implementation. 
*/ package de.edux.ml.svm; diff --git a/lib/src/main/java/de/edux/util/LabelDimensionConverter.java b/lib/src/main/java/de/edux/util/LabelDimensionConverter.java index 0b61598..67005da 100644 --- a/lib/src/main/java/de/edux/util/LabelDimensionConverter.java +++ b/lib/src/main/java/de/edux/util/LabelDimensionConverter.java @@ -1,15 +1,15 @@ package de.edux.util; public class LabelDimensionConverter { - public static int[] convert2DLabelArrayTo1DLabelArray(double[][] labels) { - int[] decisionTreeTrainLabels = new int[labels.length]; - for (int i = 0; i < labels.length; i++) { - for (int j = 0; j < labels[i].length; j++) { - if (labels[i][j] == 1) { - decisionTreeTrainLabels[i] = (j+1); - } - } + public static int[] convert2DLabelArrayTo1DLabelArray(double[][] labels) { + int[] decisionTreeTrainLabels = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + for (int j = 0; j < labels[i].length; j++) { + if (labels[i][j] == 1) { + decisionTreeTrainLabels[i] = (j + 1); } - return decisionTreeTrainLabels; + } } + return decisionTreeTrainLabels; + } } diff --git a/lib/src/main/java/de/edux/util/math/ConcurrentMatrixMultiplication.java b/lib/src/main/java/de/edux/util/math/ConcurrentMatrixMultiplication.java index 05a255e..7737419 100644 --- a/lib/src/main/java/de/edux/util/math/ConcurrentMatrixMultiplication.java +++ b/lib/src/main/java/de/edux/util/math/ConcurrentMatrixMultiplication.java @@ -2,16 +2,15 @@ public interface ConcurrentMatrixMultiplication { - /** - * Multiplies two matrices and returns the resulting matrix. - * - * @param a The first matrix. - * @param b The second matrix. - * @return The product of the two matrices. - * @throws IllegalArgumentException If the matrices cannot be multiplied due to incompatible dimensions. - */ - double[][] multiplyMatrices(double[][] a, double[][] b) throws IllegalArgumentException, IncompatibleDimensionsException; - - - + /** + * Multiplies two matrices and returns the resulting matrix. + * + * @param a The first matrix. + * @param b The second matrix. + * @return The product of the two matrices. + * @throws IllegalArgumentException If the matrices cannot be multiplied due to incompatible + * dimensions. 
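For reference, the contract documented above amounts to the standard row-by-column product. A minimal sequential sketch, using a plain IllegalArgumentException in place of the library's IncompatibleDimensionsException (the library's MathMatrix parallelizes the outer loop with virtual threads):

public class MatMulSketch {
  public static void main(String[] args) {
    double[][] a = {{1, 2}, {3, 4}};
    double[][] b = {{5, 6}, {7, 8}};
    if (a[0].length != b.length) {
      throw new IllegalArgumentException("Cannot multiply matrices with incompatible dimensions");
    }
    double[][] result = new double[a.length][b[0].length];
    for (int row = 0; row < a.length; row++) {
      for (int col = 0; col < b[0].length; col++) {
        for (int i = 0; i < a[0].length; i++) {
          result[row][col] += a[row][i] * b[i][col]; // same per-cell sum MathMatrix computes
        }
      }
    }
    System.out.println(java.util.Arrays.deepToString(result)); // [[19.0, 22.0], [43.0, 50.0]]
  }
}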
+ */ + double[][] multiplyMatrices(double[][] a, double[][] b) + throws IllegalArgumentException, IncompatibleDimensionsException; } diff --git a/lib/src/main/java/de/edux/util/math/IncompatibleDimensionsException.java b/lib/src/main/java/de/edux/util/math/IncompatibleDimensionsException.java index e236e74..471d575 100644 --- a/lib/src/main/java/de/edux/util/math/IncompatibleDimensionsException.java +++ b/lib/src/main/java/de/edux/util/math/IncompatibleDimensionsException.java @@ -1,7 +1,7 @@ package de.edux.util.math; -public class IncompatibleDimensionsException extends Exception{ - public IncompatibleDimensionsException(String message) { - super(message); - } +public class IncompatibleDimensionsException extends Exception { + public IncompatibleDimensionsException(String message) { + super(message); + } } diff --git a/lib/src/main/java/de/edux/util/math/MathMatrix.java b/lib/src/main/java/de/edux/util/math/MathMatrix.java index ccace34..ecd45e3 100644 --- a/lib/src/main/java/de/edux/util/math/MathMatrix.java +++ b/lib/src/main/java/de/edux/util/math/MathMatrix.java @@ -1,58 +1,67 @@ package de.edux.util.math; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.ArrayList; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class MathMatrix implements ConcurrentMatrixMultiplication { - private static final Logger LOG = LoggerFactory.getLogger(MathMatrix.class); - - @Override - public double[][] multiplyMatrices(double[][] a, double[][] b) throws IncompatibleDimensionsException { - LOG.info("Multiplying matrices of size {}x{} and {}x{}", a.length, a[0].length, b.length, b[0].length); - int aRows = a.length; - int aCols = a[0].length; - int bCols = b[0].length; - - if (aCols != b.length) { - throw new IncompatibleDimensionsException("Cannot multiply matrices with incompatible dimensions"); - } - - double[][] result = new double[aRows][bCols]; - - try(var executor = Executors.newVirtualThreadPerTaskExecutor()) { - List> futures = new ArrayList<>(aRows); - - for (int i = 0; i < aRows; i++) { - final int rowIndex = i; - futures.add(executor.submit(() -> { - for (int colIndex = 0; colIndex < bCols; colIndex++) { - result[rowIndex][colIndex] = multiplyMatrixRowByColumn(a, b, rowIndex, colIndex); - } - return null; + private static final Logger LOG = LoggerFactory.getLogger(MathMatrix.class); + + @Override + public double[][] multiplyMatrices(double[][] a, double[][] b) + throws IncompatibleDimensionsException { + LOG.info( + "Multiplying matrices of size {}x{} and {}x{}", + a.length, + a[0].length, + b.length, + b[0].length); + int aRows = a.length; + int aCols = a[0].length; + int bCols = b[0].length; + + if (aCols != b.length) { + throw new IncompatibleDimensionsException( + "Cannot multiply matrices with incompatible dimensions"); + } + + double[][] result = new double[aRows][bCols]; + + try (var executor = Executors.newVirtualThreadPerTaskExecutor()) { + List> futures = new ArrayList<>(aRows); + + for (int i = 0; i < aRows; i++) { + final int rowIndex = i; + futures.add( + executor.submit( + () -> { + for (int colIndex = 0; colIndex < bCols; colIndex++) { + result[rowIndex][colIndex] = + multiplyMatrixRowByColumn(a, b, rowIndex, colIndex); + } + return null; })); - } - for (var future : futures) { - future.get(); - } - } catch (ExecutionException | InterruptedException e) { - LOG.error("Error while 
multiplying matrices", e); - } - - LOG.info("Finished multiplying matrices"); - return result; + } + for (var future : futures) { + future.get(); + } + } catch (ExecutionException | InterruptedException e) { + LOG.error("Error while multiplying matrices", e); } - private double multiplyMatrixRowByColumn(double[][] a, double[][] b, int row, int col) { - double sum = 0; - for (int i = 0; i < a[0].length; i++) { - sum += a[row][i] * b[i][col]; - } - return sum; + LOG.info("Finished multiplying matrices"); + return result; + } + + private double multiplyMatrixRowByColumn(double[][] a, double[][] b, int row, int col) { + double sum = 0; + for (int i = 0; i < a[0].length; i++) { + sum += a[row][i] * b[i][col]; } + return sum; + } } diff --git a/lib/src/main/java/de/edux/util/math/MatrixOperations.java b/lib/src/main/java/de/edux/util/math/MatrixOperations.java index af9bf6c..61fcd8e 100644 --- a/lib/src/main/java/de/edux/util/math/MatrixOperations.java +++ b/lib/src/main/java/de/edux/util/math/MatrixOperations.java @@ -1,49 +1,49 @@ package de.edux.util.math; public interface MatrixOperations { - /** - * Adds two matrices and returns the resulting matrix. - * - * @param a The first matrix. - * @param b The second matrix. - * @return The sum of the two matrices. - * @throws IllegalArgumentException If the matrices are not of the same dimension. - */ - double[][] addMatrices(double[][] a, double[][] b) throws IllegalArgumentException; + /** + * Adds two matrices and returns the resulting matrix. + * + * @param a The first matrix. + * @param b The second matrix. + * @return The sum of the two matrices. + * @throws IllegalArgumentException If the matrices are not of the same dimension. + */ + double[][] addMatrices(double[][] a, double[][] b) throws IllegalArgumentException; - /** - * Subtracts matrix b from matrix a and returns the resulting matrix. - * - * @param a The first matrix. - * @param b The second matrix. - * @return The result of a - b. - * @throws IllegalArgumentException If the matrices are not of the same dimension. - */ - double[][] subtractMatrices(double[][] a, double[][] b) throws IllegalArgumentException; + /** + * Subtracts matrix b from matrix a and returns the resulting matrix. + * + * @param a The first matrix. + * @param b The second matrix. + * @return The result of a - b. + * @throws IllegalArgumentException If the matrices are not of the same dimension. + */ + double[][] subtractMatrices(double[][] a, double[][] b) throws IllegalArgumentException; - /** - * Transposes the given matrix and returns the resulting matrix. - * - * @param a The matrix to transpose. - * @return The transposed matrix. - */ - double[][] transposeMatrix(double[][] a); + /** + * Transposes the given matrix and returns the resulting matrix. + * + * @param a The matrix to transpose. + * @return The transposed matrix. + */ + double[][] transposeMatrix(double[][] a); - /** - * Inverts the given matrix and returns the resulting matrix. - * - * @param a The matrix to invert. - * @return The inverted matrix. - * @throws IllegalArgumentException If the matrix is not invertible. - */ - double[][] invertMatrix(double[][] a) throws IllegalArgumentException; + /** + * Inverts the given matrix and returns the resulting matrix. + * + * @param a The matrix to invert. + * @return The inverted matrix. + * @throws IllegalArgumentException If the matrix is not invertible. + */ + double[][] invertMatrix(double[][] a) throws IllegalArgumentException; - /** - * Calculates and returns the determinant of the given matrix. 
- * - * @param a The matrix. - * @return The determinant of the matrix. - * @throws IllegalArgumentException If the matrix is not square. - */ - double determinant(double[][] a) throws IllegalArgumentException; + /** + * Calculates and returns the determinant of the given matrix. + * + * @param a The matrix. + * @return The determinant of the matrix. + * @throws IllegalArgumentException If the matrix is not square. + */ + double determinant(double[][] a) throws IllegalArgumentException; } diff --git a/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java index 8a2fb02..6df2b9b 100644 --- a/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java +++ b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java @@ -19,172 +19,196 @@ @ExtendWith(MockitoExtension.class) class DataProcessorTest { - private static final boolean SKIP_HEAD = true; - @Mock - IDataReader dataReader; - List dummyDatasetForImputationTest; - private List dummyDataset; - private DataProcessor dataProcessor; - - @BeforeEach - void setUp() { - dummyDataset = new ArrayList<>(); - dummyDataset.add(new String[]{"col1", "col2", "Name", "col4", "col5"}); - dummyDataset.add(new String[]{"1", "2", "3", "Anna", "5"}); - dummyDataset.add(new String[]{"6", "7", "8", "Nina", "10"}); - dummyDataset.add(new String[]{"11", "12", "13", "Johanna", "15"}); - dummyDataset.add(new String[]{"16", "17", "18", "Isabela", "20"}); - when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDataset); - - dummyDatasetForImputationTest = new ArrayList<>(); - dummyDatasetForImputationTest.add(new String[] {"Fruit", "Quantity", "Price"}); - dummyDatasetForImputationTest.add(new String[] {"Apple", "", "8"}); - dummyDatasetForImputationTest.add(new String[] {"Apple", "2", "9"}); - dummyDatasetForImputationTest.add(new String[] {"", "3", "10"}); - dummyDatasetForImputationTest.add(new String[] {"Peach", "3", ""}); - dummyDatasetForImputationTest.add(new String[] {"Kiwi", "5", ""}); - dummyDatasetForImputationTest.add(new String[] {"", "3", "11"}); - dummyDatasetForImputationTest.add(new String[] {"Banana", "7", "12"}); - - dataProcessor = new DataProcessor(dataReader); - } - - @Test - void shouldSkipHead() { - dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1, 2, 4}, 3); - assertEquals(4, dataProcessor.getDataset().size(), "Number of rows does not match."); - } + private static final boolean SKIP_HEAD = true; + @Mock IDataReader dataReader; + List dummyDatasetForImputationTest; + private List dummyDataset; + private DataProcessor dataProcessor; + + @BeforeEach + void setUp() { + dummyDataset = new ArrayList<>(); + dummyDataset.add(new String[] {"col1", "col2", "Name", "col4", "col5"}); + dummyDataset.add(new String[] {"1", "2", "3", "Anna", "5"}); + dummyDataset.add(new String[] {"6", "7", "8", "Nina", "10"}); + dummyDataset.add(new String[] {"11", "12", "13", "Johanna", "15"}); + dummyDataset.add(new String[] {"16", "17", "18", "Isabela", "20"}); + when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDataset); + + dummyDatasetForImputationTest = new ArrayList<>(); + dummyDatasetForImputationTest.add(new String[] {"Fruit", "Quantity", "Price"}); + dummyDatasetForImputationTest.add(new String[] {"Apple", "", "8"}); + dummyDatasetForImputationTest.add(new String[] {"Apple", "2", "9"}); + dummyDatasetForImputationTest.add(new String[] {"", "3", "10"}); + dummyDatasetForImputationTest.add(new String[] {"Peach", "3", ""}); + 
dummyDatasetForImputationTest.add(new String[] {"Kiwi", "5", ""}); + dummyDatasetForImputationTest.add(new String[] {"", "3", "11"}); + dummyDatasetForImputationTest.add(new String[] {"Banana", "7", "12"}); + + dataProcessor = new DataProcessor(dataReader); + } - @Test - void shouldNotSkipHead() { - dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', false, new int[]{0, 1, 2, 4}, 3); - assertEquals(5, dataProcessor.getDataset().size(), "Number of rows does not match."); - } + @Test + void shouldSkipHead() { + dataProcessor.loadDataSetFromCSV( + new File("mockpathhere"), ',', SKIP_HEAD, new int[] {0, 1, 2, 4}, 3); + assertEquals(4, dataProcessor.getDataset().size(), "Number of rows does not match."); + } + @Test + void shouldNotSkipHead() { + dataProcessor.loadDataSetFromCSV( + new File("mockpathhere"), ',', false, new int[] {0, 1, 2, 4}, 3); + assertEquals(5, dataProcessor.getDataset().size(), "Number of rows does not match."); + } - @Test - void getTargets() { - when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDataset); - dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1, 2, 4}, 3) - .split(0.5); - dummyDataset.add(new String[]{"21", "22", "23", "Isabela", "25"}); - - double[][] targets = dataProcessor.getTargets(dummyDataset, 3); - double[][] expectedTargets = { - {1.0, 0.0, 0.0, 0.0}, // Anna - {0.0, 1.0, 0.0, 0.0}, // Nina - {0.0, 0.0, 1.0, 0.0}, // Johanna - {0.0, 0.0, 0.0, 1.0}, // Isabela - {0.0, 0.0, 0.0, 1.0} // Isabela - }; - - for (int i = 0; i < expectedTargets.length; i++) { - assertArrayEquals(expectedTargets[i], targets[i], "Die Zielzeile " + i + " stimmt nicht überein."); - } - - Map classMap = dataProcessor.getClassMap(); - Map expectedClassMap = Map.of( - "Anna", 0, - "Nina", 1, - "Johanna", 2, - "Isabela", 3); - - assertEquals(expectedClassMap, classMap, "Die Klassen stimmen nicht überein."); + @Test + void getTargets() { + when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDataset); + dataProcessor + .loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[] {0, 1, 2, 4}, 3) + .split(0.5); + dummyDataset.add(new String[] {"21", "22", "23", "Isabela", "25"}); + + double[][] targets = dataProcessor.getTargets(dummyDataset, 3); + double[][] expectedTargets = { + {1.0, 0.0, 0.0, 0.0}, // Anna + {0.0, 1.0, 0.0, 0.0}, // Nina + {0.0, 0.0, 1.0, 0.0}, // Johanna + {0.0, 0.0, 0.0, 1.0}, // Isabela + {0.0, 0.0, 0.0, 1.0} // Isabela + }; + + for (int i = 0; i < expectedTargets.length; i++) { + assertArrayEquals( + expectedTargets[i], targets[i], "Die Zielzeile " + i + " stimmt nicht überein."); } - @Test - void getInputs() { - dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1, 2, 4}, 3) - .split(0.5); - double[][] inputs = dataProcessor.getInputs(dummyDataset, new int[]{0, 1, 2, 4}); - - double[][] expectedInputs = { - {1.0, 2.0, 3.0, 5.0}, - {6.0, 7.0, 8.0, 10.0}, - {11.0, 12.0, 13.0, 15.0}, - {16.0, 17.0, 18.0, 20.0} - }; - - assertEquals(expectedInputs.length, inputs.length, "Die Anzahl der Zeilen stimmt nicht überein."); + Map classMap = dataProcessor.getClassMap(); + Map expectedClassMap = + Map.of( + "Anna", 0, + "Nina", 1, + "Johanna", 2, + "Isabela", 3); - for (int i = 0; i < expectedInputs.length; i++) { - assertArrayEquals(expectedInputs[i], inputs[i], "Die Zeile " + i + " entspricht nicht den erwarteten Werten."); - } - } + assertEquals(expectedClassMap, classMap, "Die Klassen stimmen nicht überein."); + } - private List duplicateList(List list) { 
- List duplicate = new ArrayList<>(); - for (String[] row : list) { - duplicate.add(row.clone()); - } - return duplicate; + @Test + void getInputs() { + dataProcessor + .loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[] {0, 1, 2, 4}, 3) + .split(0.5); + double[][] inputs = dataProcessor.getInputs(dummyDataset, new int[] {0, 1, 2, 4}); + + double[][] expectedInputs = { + {1.0, 2.0, 3.0, 5.0}, + {6.0, 7.0, 8.0, 10.0}, + {11.0, 12.0, 13.0, 15.0}, + {16.0, 17.0, 18.0, 20.0} + }; + + assertEquals( + expectedInputs.length, inputs.length, "Die Anzahl der Zeilen stimmt nicht überein."); + + for (int i = 0; i < expectedInputs.length; i++) { + assertArrayEquals( + expectedInputs[i], + inputs[i], + "Die Zeile " + i + " entspricht nicht den erwarteten Werten."); } + } - @Test - void shouldNormalize() { - dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1, 2, 4}, 3) - .split(0.5); - List normalizedDataset = dataProcessor.normalize().getDataset(); - - String[][] expectedNormalizedValues = { - {"0.0", "0.0", "0.0", "Anna", "0.0"}, - {"0.3333333333333333", "0.3333333333333333", "0.3333333333333333", "Nina", "0.3333333333333333"}, - {"0.6666666666666666", "0.6666666666666666", "0.6666666666666666", "Johanna", "0.6666666666666666"}, - {"1.0", "1.0", "1.0", "Isabela", "1.0"} - }; - - for (int i = 1; i < normalizedDataset.size(); i++) { - String[] row = normalizedDataset.get(i); - assertArrayEquals(expectedNormalizedValues[i], row, "Die Zeile " + i + " entspricht nicht den erwarteten normalisierten Werten."); - } + private List duplicateList(List list) { + List duplicate = new ArrayList<>(); + for (String[] row : list) { + duplicate.add(row.clone()); } + return duplicate; + } - @Test - void shouldShuffle() { - List originalDataset = duplicateList(dummyDataset); - dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', false, new int[]{0, 1, 2, 4}, 3) - .split(0.5); - List shuffledDataset = dataProcessor.shuffle().getDataset(); - - assertNotEquals(originalDataset, shuffledDataset, "Die Reihenfolge der Zeilen hat sich nicht geändert."); + @Test + void shouldNormalize() { + dataProcessor + .loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[] {0, 1, 2, 4}, 3) + .split(0.5); + List normalizedDataset = dataProcessor.normalize().getDataset(); + + String[][] expectedNormalizedValues = { + {"0.0", "0.0", "0.0", "Anna", "0.0"}, + { + "0.3333333333333333", + "0.3333333333333333", + "0.3333333333333333", + "Nina", + "0.3333333333333333" + }, + { + "0.6666666666666666", + "0.6666666666666666", + "0.6666666666666666", + "Johanna", + "0.6666666666666666" + }, + {"1.0", "1.0", "1.0", "Isabela", "1.0"} + }; + + for (int i = 1; i < normalizedDataset.size(); i++) { + String[] row = normalizedDataset.get(i); + assertArrayEquals( + expectedNormalizedValues[i], + row, + "Die Zeile " + i + " entspricht nicht den erwarteten normalisierten Werten."); } + } + @Test + void shouldShuffle() { + List originalDataset = duplicateList(dummyDataset); + dataProcessor + .loadDataSetFromCSV(new File("mockpathhere"), ',', false, new int[] {0, 1, 2, 4}, 3) + .split(0.5); + List shuffledDataset = dataProcessor.shuffle().getDataset(); - @Test - void shouldReturnTrainTestDataset() { - dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', false, new int[]{0, 1, 2, 4}, 3); - dataProcessor.split(0.5); + assertNotEquals( + originalDataset, shuffledDataset, "Die Reihenfolge der Zeilen hat sich nicht geändert."); + } - int[] inputColumns = new int[]{0, 1, 2, 
+    assertNotEquals(
+        originalDataset, shuffledDataset, "The row order has not changed.");
+  }

-        int[] inputColumns = new int[]{0, 1, 2, 4};
-        double[][] trainFeatures = dataProcessor.getTrainFeatures(inputColumns);
-        double[][] testFeatures = dataProcessor.getTestFeatures(inputColumns);
+  @Test
+  void shouldReturnTrainTestDataset() {
+    dataProcessor.loadDataSetFromCSV(
+        new File("mockpathhere"), ',', false, new int[] {0, 1, 2, 4}, 3);
+    dataProcessor.split(0.5);

-        double[][] trainLabels = dataProcessor.getTrainLabels(3);
-        double[][] testLabels = dataProcessor.getTestLabels(3);
+    int[] inputColumns = new int[] {0, 1, 2, 4};
+    double[][] trainFeatures = dataProcessor.getTrainFeatures(inputColumns);
+    double[][] testFeatures = dataProcessor.getTestFeatures(inputColumns);

-    }
+    double[][] trainLabels = dataProcessor.getTrainLabels(3);
+    double[][] testLabels = dataProcessor.getTestLabels(3);
+  }

-    @Test
-    void shouldPerformImputationOnDataset() {
-        when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDatasetForImputationTest);
-        dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1}, 2);
-
-        ImputationStrategy modeImputter = ImputationStrategy.MODE;
-        ImputationStrategy averageImputter = ImputationStrategy.AVERAGE;
-
-        dataProcessor.imputation("Fruit",modeImputter);
-        dataProcessor.imputation("Quantity",modeImputter);
-        dataProcessor.imputation("Price",averageImputter);
-        var imputtedDataset = dataProcessor.getDataset();
-
-        assertAll(
-                () -> assertArrayEquals(new String[] {"Apple", "3", "8"}, imputtedDataset.get(0)),
-                () -> assertArrayEquals(new String[] {"Apple", "2", "9"}, imputtedDataset.get(1)),
-                () -> assertArrayEquals(new String[] {"Apple", "3", "10"}, imputtedDataset.get(2)),
-                () -> assertArrayEquals(new String[] {"Peach", "3", "10.0"}, imputtedDataset.get(3)),
-                () -> assertArrayEquals(new String[] {"Kiwi", "5", "10.0"}, imputtedDataset.get(4)),
-                () -> assertArrayEquals(new String[] {"Apple", "3", "11"}, imputtedDataset.get(5)),
-                () -> assertArrayEquals(new String[] {"Banana", "7", "12"}, imputtedDataset.get(6))
-        );
+  @Test
+  void shouldPerformImputationOnDataset() {
+    when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDatasetForImputationTest);
+    dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[] {0, 1}, 2);
+
+    ImputationStrategy modeImputter = ImputationStrategy.MODE;
+    ImputationStrategy averageImputter = ImputationStrategy.AVERAGE;
+
+    dataProcessor.imputation("Fruit", modeImputter);
+    dataProcessor.imputation("Quantity", modeImputter);
+    dataProcessor.imputation("Price", averageImputter);
+    var imputtedDataset = dataProcessor.getDataset();
+
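+    // MODE fills a missing value with the column's most frequent entry ("Apple", "3");
+    // AVERAGE fills the mean of the present numeric values: (8 + 9 + 10 + 11 + 12) / 5 = 10.0.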
+    assertAll(
+        () -> assertArrayEquals(new String[] {"Apple", "3", "8"}, imputtedDataset.get(0)),
+        () -> assertArrayEquals(new String[] {"Apple", "2", "9"}, imputtedDataset.get(1)),
+        () -> assertArrayEquals(new String[] {"Apple", "3", "10"}, imputtedDataset.get(2)),
+        () -> assertArrayEquals(new String[] {"Peach", "3", "10.0"}, imputtedDataset.get(3)),
+        () -> assertArrayEquals(new String[] {"Kiwi", "5", "10.0"}, imputtedDataset.get(4)),
+        () -> assertArrayEquals(new String[] {"Apple", "3", "11"}, imputtedDataset.get(5)),
+        () -> assertArrayEquals(new String[] {"Banana", "7", "12"}, imputtedDataset.get(6)));
   }
-}
\ No newline at end of file
+}
diff --git a/lib/src/test/java/de/edux/edux/activation/ActivationFunctionTest.java b/lib/src/test/java/de/edux/edux/activation/ActivationFunctionTest.java
index deb45b3..51afb10 100644
--- a/lib/src/test/java/de/edux/edux/activation/ActivationFunctionTest.java
+++ b/lib/src/test/java/de/edux/edux/activation/ActivationFunctionTest.java
@@ -1,83 +1,88 @@
 package de.edux.edux.activation;

+import static org.junit.jupiter.api.Assertions.assertEquals;
+
 import de.edux.functions.activation.ActivationFunction;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;

-import static org.junit.jupiter.api.Assertions.assertEquals;
-
 public class ActivationFunctionTest {

-    private static final double DELTA = 1e-6;
-
-    @Test
-    public void testSigmoid() {
-        double input = 0.5;
-        double expectedActivation = 1 / (1 + Math.exp(-input));
-        double expectedDerivative = expectedActivation * (1 - expectedActivation);
-
-        Assertions.assertEquals(expectedActivation, ActivationFunction.SIGMOID.calculateActivation(input), DELTA);
-        assertEquals(expectedDerivative, ActivationFunction.SIGMOID.calculateDerivative(input), DELTA);
+  private static final double DELTA = 1e-6;
+
+  @Test
+  public void testSigmoid() {
+    double input = 0.5;
+    double expectedActivation = 1 / (1 + Math.exp(-input));
+    double expectedDerivative = expectedActivation * (1 - expectedActivation);
+
+    Assertions.assertEquals(
+        expectedActivation, ActivationFunction.SIGMOID.calculateActivation(input), DELTA);
+    assertEquals(expectedDerivative, ActivationFunction.SIGMOID.calculateDerivative(input), DELTA);
+  }
+
+  @Test
+  public void testRelu() {
+    double inputPositive = 0.5;
+    double inputNegative = -0.5;
+
+    assertEquals(inputPositive, ActivationFunction.RELU.calculateActivation(inputPositive), DELTA);
+    assertEquals(0.0, ActivationFunction.RELU.calculateActivation(inputNegative), DELTA);
+
+    assertEquals(1.0, ActivationFunction.RELU.calculateDerivative(inputPositive), DELTA);
+    assertEquals(0.0, ActivationFunction.RELU.calculateDerivative(inputNegative), DELTA);
+  }
+
+  @Test
+  public void testLeakyRelu() {
+    double inputPositive = 0.5;
+    double inputNegative = -0.5;
+
+    assertEquals(
+        inputPositive, ActivationFunction.LEAKY_RELU.calculateActivation(inputPositive), DELTA);
+    assertEquals(
+        0.01 * inputNegative,
+        ActivationFunction.LEAKY_RELU.calculateActivation(inputNegative),
+        DELTA);
+
+    assertEquals(1.0, ActivationFunction.LEAKY_RELU.calculateDerivative(inputPositive), DELTA);
+    assertEquals(0.01, ActivationFunction.LEAKY_RELU.calculateDerivative(inputNegative), DELTA);
+  }
+
+  @Test
+  public void testTanh() {
+    double input = 0.5;
+    double expectedActivation = Math.tanh(input);
+    double expectedDerivative = 1 - Math.pow(expectedActivation, 2);
+
+    assertEquals(expectedActivation, ActivationFunction.TANH.calculateActivation(input), DELTA);
+    assertEquals(expectedDerivative, ActivationFunction.TANH.calculateDerivative(input), DELTA);
+  }
+
+  @Test
+  public void testSoftmax() {
+    double input = 0.5;
+    double expectedActivation = Math.exp(input);
+    double expectedDerivative = expectedActivation * (1 - expectedActivation);
+
+    assertEquals(expectedActivation, ActivationFunction.SOFTMAX.calculateActivation(input), DELTA);
+    assertEquals(expectedDerivative, ActivationFunction.SOFTMAX.calculateDerivative(input), DELTA);
+
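+    // For the vector form, softmax(x_i) = exp(x_i) / sum_j exp(x_j), so the outputs are
+    // positive and sum to 1; the scalar overload above only exponentiates, without normalizing.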
+    double[] inputs = {0.1, 0.2, 0.3};
+    double[] expectedOutputs = new double[inputs.length];
+    double sum = 0.0;
+    for (int i = 0; i < inputs.length; i++) {
+      expectedOutputs[i] = Math.exp(inputs[i]);
+      sum += expectedOutputs[i];
     }

-    @Test
-    public void testRelu() {
-        double inputPositive = 0.5;
-        double inputNegative = -0.5;
-
-        assertEquals(inputPositive, ActivationFunction.RELU.calculateActivation(inputPositive), DELTA);
-        assertEquals(0.0, ActivationFunction.RELU.calculateActivation(inputNegative), DELTA);
-
-        assertEquals(1.0, ActivationFunction.RELU.calculateDerivative(inputPositive), DELTA);
-        assertEquals(0.0, ActivationFunction.RELU.calculateDerivative(inputNegative), DELTA);
+    for (int i = 0; i < expectedOutputs.length; i++) {
+      expectedOutputs[i] /= sum;
     }

-    @Test
-    public void testLeakyRelu() {
-        double inputPositive = 0.5;
-        double inputNegative = -0.5;
-
-        assertEquals(inputPositive, ActivationFunction.LEAKY_RELU.calculateActivation(inputPositive), DELTA);
-        assertEquals(0.01 * inputNegative, ActivationFunction.LEAKY_RELU.calculateActivation(inputNegative), DELTA);
-
-        assertEquals(1.0, ActivationFunction.LEAKY_RELU.calculateDerivative(inputPositive), DELTA);
-        assertEquals(0.01, ActivationFunction.LEAKY_RELU.calculateDerivative(inputNegative), DELTA);
-    }
-
-    @Test
-    public void testTanh() {
-        double input = 0.5;
-        double expectedActivation = Math.tanh(input);
-        double expectedDerivative = 1 - Math.pow(expectedActivation, 2);
-
-        assertEquals(expectedActivation, ActivationFunction.TANH.calculateActivation(input), DELTA);
-        assertEquals(expectedDerivative, ActivationFunction.TANH.calculateDerivative(input), DELTA);
-    }
+    double[] outputs = ActivationFunction.SOFTMAX.calculateActivation(inputs);

-    @Test
-    public void testSoftmax() {
-        double input = 0.5;
-        double expectedActivation = Math.exp(input);
-        double expectedDerivative = expectedActivation * (1 - expectedActivation);
-
-        assertEquals(expectedActivation, ActivationFunction.SOFTMAX.calculateActivation(input), DELTA);
-        assertEquals(expectedDerivative, ActivationFunction.SOFTMAX.calculateDerivative(input), DELTA);
-
-        double[] inputs = {0.1, 0.2, 0.3};
-        double[] expectedOutputs = new double[inputs.length];
-        double sum = 0.0;
-        for (int i = 0; i < inputs.length; i++) {
-            expectedOutputs[i] = Math.exp(inputs[i]);
-            sum += expectedOutputs[i];
-        }
-        for (int i = 0; i < expectedOutputs.length; i++) {
-            expectedOutputs[i] /= sum;
-        }
-
-        double[] outputs = ActivationFunction.SOFTMAX.calculateActivation(inputs);
-
-        for (int i = 0; i < inputs.length; i++) {
-            assertEquals(expectedOutputs[i], outputs[i], DELTA);
-        }
+    for (int i = 0; i < inputs.length; i++) {
+      assertEquals(expectedOutputs[i], outputs[i], DELTA);
     }
+  }
 }
diff --git a/lib/src/test/java/de/edux/functions/InitializationTest.java b/lib/src/test/java/de/edux/functions/InitializationTest.java
index f50424b..e8dd0cc 100644
--- a/lib/src/test/java/de/edux/functions/InitializationTest.java
+++ b/lib/src/test/java/de/edux/functions/InitializationTest.java
@@ -1,33 +1,36 @@
 package de.edux.functions;

+import static org.junit.jupiter.api.Assertions.assertTrue;
+
 import de.edux.functions.initialization.Initialization;
 import org.junit.jupiter.api.Test;

-import static org.junit.jupiter.api.Assertions.assertTrue;
-
 public class InitializationTest {

-    @Test
-    public void testXavierInitialization() {
-        int inputSize = 10;
-        double[] weights = new double[inputSize];
-        weights = Initialization.XAVIER.weightInitialization(inputSize, weights);
+  @Test
+  public void testXavierInitialization() {
+    int inputSize = 10;
+    double[] weights = new double[inputSize];
+    weights = Initialization.XAVIER.weightInitialization(inputSize, weights);

-        double xavier = Math.sqrt(6.0 / (inputSize + 1));
-        for (double weight : weights) {
-            assertTrue(weight >= -xavier && weight <= xavier, "Weight should be in the range of Xavier initialization");
-        }
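+    // Glorot/Xavier uniform draws weights from [-b, b] with b = sqrt(6 / (fanIn + fanOut));
+    // the bound below matches that formula with fanOut presumably taken as 1.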
Xavier initialization"); } + } - @Test - public void testHeInitialization() { - int inputSize = 10; - double[] weights = new double[inputSize]; - weights = Initialization.HE.weightInitialization(inputSize, weights); + @Test + public void testHeInitialization() { + int inputSize = 10; + double[] weights = new double[inputSize]; + weights = Initialization.HE.weightInitialization(inputSize, weights); - double he = Math.sqrt(2.0 / inputSize); - for (double weight : weights) { - assertTrue(weight >= -he && weight <= he, "Weight should be in the range of He initialization"); - } + double he = Math.sqrt(2.0 / inputSize); + for (double weight : weights) { + assertTrue( + weight >= -he && weight <= he, "Weight should be in the range of He initialization"); } + } } diff --git a/lib/src/test/java/de/edux/functions/loss/LossFunctionTest.java b/lib/src/test/java/de/edux/functions/loss/LossFunctionTest.java index d415c87..382ac17 100644 --- a/lib/src/test/java/de/edux/functions/loss/LossFunctionTest.java +++ b/lib/src/test/java/de/edux/functions/loss/LossFunctionTest.java @@ -1,64 +1,76 @@ package de.edux.functions.loss; -import org.junit.jupiter.api.Test; - import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.Test; + public class LossFunctionTest { - private static final double DELTA = 1e-6; // used to compare floating point numbers - - @Test - public void testCategoricalCrossEntropy() { - double[] output = {0.1, 0.2, 0.7}; - double[] target = {0, 0, 1}; - - double expectedError = -Math.log(0.7); - assertEquals(expectedError, LossFunction.CATEGORICAL_CROSS_ENTROPY.calculateError(output, target), DELTA); - } - - @Test - public void testMeanSquaredError() { - double[] output = {0.1, 0.2, 0.7}; - double[] target = {0.2, 0.1, 0.8}; - - double expectedError = (Math.pow(0.1, 2) + Math.pow(0.1, 2) + Math.pow(0.1, 2)) / 3; - assertEquals(expectedError, LossFunction.MEAN_SQUARED_ERROR.calculateError(output, target), DELTA); - } - - @Test - public void testMeanAbsoluteError() { - double[] output = {0.1, 0.2, 0.7}; - double[] target = {0.2, 0.1, 0.8}; - - double expectedError = (Math.abs(0.1) + Math.abs(0.1) + Math.abs(0.1)) / 3; - assertEquals(expectedError, LossFunction.MEAN_ABSOLUTE_ERROR.calculateError(output, target), DELTA); - } - - @Test - public void testHingeLoss() { - double[] output = {0.1, 0.2, 0.7}; - double[] target = {-1, -1, 1}; - - double expectedError = (Math.max(0, 1 - (-1) * 0.1) + Math.max(0, 1 - (-1) * 0.2) + Math.max(0, 1 - 1 * 0.7)) / 3; - assertEquals(expectedError, LossFunction.HINGE_LOSS.calculateError(output, target), DELTA); - } - - @Test - public void testSquaredHingeLoss() { - double[] output = {0.1, 0.2, 0.7}; - double[] target = {-1, -1, 1}; - - double expectedError = (Math.pow(Math.max(0, 1 - (-1) * 0.1), 2) + Math.pow(Math.max(0, 1 - (-1) * 0.2), 2) + Math.pow(Math.max(0, 1 - 1 * 0.7), 2)) / 3; - assertEquals(expectedError, LossFunction.SQUARED_HINGE_LOSS.calculateError(output, target), DELTA); - } - - @Test - public void testBinaryCrossEntropy() { - double[] output = {0.1, 0.2, 0.7}; - double[] target = {1, 0, 1}; - - double expectedError = - (Math.log(0.1) + Math.log(1 - 0.2) + Math.log(0.7)); - assertEquals(expectedError, LossFunction.BINARY_CROSS_ENTROPY.calculateError(output, target), DELTA); - } + private static final double DELTA = 1e-6; // used to compare floating point numbers + + @Test + public void testCategoricalCrossEntropy() { + double[] output = {0.1, 0.2, 0.7}; + double[] target = {0, 0, 1}; + + double expectedError = 
+    double expectedError = -Math.log(0.7);
+    assertEquals(
+        expectedError,
+        LossFunction.CATEGORICAL_CROSS_ENTROPY.calculateError(output, target),
+        DELTA);
+  }
+
+  @Test
+  public void testMeanSquaredError() {
+    double[] output = {0.1, 0.2, 0.7};
+    double[] target = {0.2, 0.1, 0.8};
+
+    double expectedError = (Math.pow(0.1, 2) + Math.pow(0.1, 2) + Math.pow(0.1, 2)) / 3;
+    assertEquals(
+        expectedError, LossFunction.MEAN_SQUARED_ERROR.calculateError(output, target), DELTA);
+  }
+
+  @Test
+  public void testMeanAbsoluteError() {
+    double[] output = {0.1, 0.2, 0.7};
+    double[] target = {0.2, 0.1, 0.8};
+
+    double expectedError = (Math.abs(0.1) + Math.abs(0.1) + Math.abs(0.1)) / 3;
+    assertEquals(
+        expectedError, LossFunction.MEAN_ABSOLUTE_ERROR.calculateError(output, target), DELTA);
+  }
+
+  @Test
+  public void testHingeLoss() {
+    double[] output = {0.1, 0.2, 0.7};
+    double[] target = {-1, -1, 1};
+
+    double expectedError =
+        (Math.max(0, 1 - (-1) * 0.1) + Math.max(0, 1 - (-1) * 0.2) + Math.max(0, 1 - 1 * 0.7)) / 3;
+    assertEquals(expectedError, LossFunction.HINGE_LOSS.calculateError(output, target), DELTA);
+  }
+
+  @Test
+  public void testSquaredHingeLoss() {
+    double[] output = {0.1, 0.2, 0.7};
+    double[] target = {-1, -1, 1};
+
+    double expectedError =
+        (Math.pow(Math.max(0, 1 - (-1) * 0.1), 2)
+                + Math.pow(Math.max(0, 1 - (-1) * 0.2), 2)
+                + Math.pow(Math.max(0, 1 - 1 * 0.7), 2))
+            / 3;
+    assertEquals(
+        expectedError, LossFunction.SQUARED_HINGE_LOSS.calculateError(output, target), DELTA);
+  }
+
+  @Test
+  public void testBinaryCrossEntropy() {
+    double[] output = {0.1, 0.2, 0.7};
+    double[] target = {1, 0, 1};
+
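+    // Binary cross-entropy sums -(t * log(p) + (1 - t) * log(1 - p)) per element; note this
+    // variant is the unaveraged sum, unlike the mean-based losses above.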
+    double expectedError = -(Math.log(0.1) + Math.log(1 - 0.2) + Math.log(0.7));
+    assertEquals(
+        expectedError, LossFunction.BINARY_CROSS_ENTROPY.calculateError(output, target), DELTA);
+  }
 }
diff --git a/lib/src/test/java/de/edux/math/entity/MatrixTest.java b/lib/src/test/java/de/edux/math/entity/MatrixTest.java
index ed590df..7f2cee3 100644
--- a/lib/src/test/java/de/edux/math/entity/MatrixTest.java
+++ b/lib/src/test/java/de/edux/math/entity/MatrixTest.java
@@ -1,59 +1,74 @@
 package de.edux.math.entity;

+import static org.junit.jupiter.api.Assertions.assertEquals;
+
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;

-import static org.junit.jupiter.api.Assertions.assertEquals;
-
 public class MatrixTest {

-    static Matrix first;
-    static Matrix second;
-
-    @BeforeEach
-    public void init() {
-        first = new Matrix(new double[][] {
-                {5, 3, -1},
-                {-2, 0, 6},
-                {5, 1, -9}
-        });
-        second = new Matrix(new double[][] {
-                {8, 7, 4},
-                {1, -5, 2},
-                {0, 3, 0}
-        });
-    }
-
-    @Test
-    public void testAdd() {
-        assertEquals(new Matrix(new double[][] {
-                {13, 10, 3},
-                {-1, -5, 8},
-                {5, 4, -9}
-        }), first.add(second));
-    }
-
-    @Test
-    public void testSubtract() {
-        assertEquals(new Matrix(new double[][] {
-                {-3, -4, -5},
-                {-3, 5, 4},
-                {5, -2, -9}
-        }), first.subtract(second));
-    }
-
-    @Test
-    public void testScalarMultiply() {
-        assertEquals(new Matrix(new double[][] {
-                {20, 12, -4},
-                {-8, 0, 24},
-                {20, 4, -36}
-        }), first.scalarMultiply(4));
-        assertEquals(new Matrix(new double[][] {
-                {-48, -42, -24},
-                {-6, 30, -12},
-                {0, -18, 0}
-        }), second.scalarMultiply(-6));
-    }
+  static Matrix first;
+  static Matrix second;
+
+  @BeforeEach
+  public void init() {
+    first =
+        new Matrix(
+            new double[][] {
+              {5, 3, -1},
+              {-2, 0, 6},
+              {5, 1, -9}
+            });
+    second =
+        new Matrix(
+            new double[][] {
+              {8, 7, 4},
+              {1, -5, 2},
+              {0, 3, 0}
+            });
+  }
+
+  @Test
+  public void testAdd() {
+    assertEquals(
+        new Matrix(
+            new double[][] {
+              {13, 10, 3},
+              {-1, -5, 8},
+              {5, 4, -9}
+            }),
+        first.add(second));
+  }
+
+  @Test
+  public void testSubtract() {
+    assertEquals(
+        new Matrix(
+            new double[][] {
+              {-3, -4, -5},
+              {-3, 5, 4},
+              {5, -2, -9}
+            }),
+        first.subtract(second));
+  }
+
+  @Test
+  public void testScalarMultiply() {
+    assertEquals(
+        new Matrix(
+            new double[][] {
+              {20, 12, -4},
+              {-8, 0, 24},
+              {20, 4, -36}
+            }),
+        first.scalarMultiply(4));
+    assertEquals(
+        new Matrix(
+            new double[][] {
+              {-48, -42, -24},
+              {-6, 30, -12},
+              {0, -18, 0}
+            }),
+        second.scalarMultiply(-6));
+  }
 }
diff --git a/lib/src/test/java/de/edux/math/entity/VectorTest.java b/lib/src/test/java/de/edux/math/entity/VectorTest.java
index f4fc013..fb924f3 100644
--- a/lib/src/test/java/de/edux/math/entity/VectorTest.java
+++ b/lib/src/test/java/de/edux/math/entity/VectorTest.java
@@ -1,45 +1,44 @@
 package de.edux.math.entity;

+import static org.junit.jupiter.api.Assertions.assertEquals;
+
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;

-import static org.junit.jupiter.api.Assertions.assertEquals;
-
 public class VectorTest {

-    static Vector first;
-    static Vector second;
-
-    @BeforeAll
-    public static void init() {
-        first = new Vector(new double[] {1, 5, 4});
-        second = new Vector(new double[] {3, 8, 0});
-    }
-
-    @Test
-    public void testAdd() {
-        assertEquals(new Vector(new double[] {4, 13, 4}), first.add(second));
-    }
-
-    @Test
-    public void testSubtract() {
-        assertEquals(new Vector(new double[] {-2, -3, 4}), first.subtract(second));
-    }
-
-    @Test
-    public void testMultiply() {
-        assertEquals(new Vector(new double[] {3, 40, 0}), first.multiply(second));
-    }
-
-    @Test
-    public void testScalarMultiply() {
-        assertEquals(new Vector(new double[] {3, 15, 12}), first.scalarMultiply(3)); // first by 3
-        assertEquals(new Vector(new double[] {-6, -16, 0}), second.scalarMultiply(-2)); // second by -2
-    }
-
-    @Test
-    public void testDot() {
-        assertEquals(43, first.dot(second));
-    }
-
+  static Vector first;
+  static Vector second;
+
+  @BeforeAll
+  public static void init() {
+    first = new Vector(new double[] {1, 5, 4});
+    second = new Vector(new double[] {3, 8, 0});
+  }
+
+  @Test
+  public void testAdd() {
+    assertEquals(new Vector(new double[] {4, 13, 4}), first.add(second));
+  }
+
+  @Test
+  public void testSubtract() {
+    assertEquals(new Vector(new double[] {-2, -3, 4}), first.subtract(second));
+  }
+
+  @Test
+  public void testMultiply() {
+    assertEquals(new Vector(new double[] {3, 40, 0}), first.multiply(second));
+  }
+
+  @Test
+  public void testScalarMultiply() {
+    assertEquals(new Vector(new double[] {3, 15, 12}), first.scalarMultiply(3)); // first by 3
+    assertEquals(new Vector(new double[] {-6, -16, 0}), second.scalarMultiply(-2)); // second by -2
+  }
+
+  @Test
+  public void testDot() {
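+    // Element-wise: 1 * 3 + 5 * 8 + 4 * 0 = 43.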
+    assertEquals(43, first.dot(second));
+  }
 }
diff --git a/lib/src/test/java/de/edux/ml/nn/network/NeuronTest.java b/lib/src/test/java/de/edux/ml/nn/network/NeuronTest.java
index b9e6682..8ddbad7 100644
--- a/lib/src/test/java/de/edux/ml/nn/network/NeuronTest.java
+++ b/lib/src/test/java/de/edux/ml/nn/network/NeuronTest.java
@@ -1,49 +1,50 @@
 package de.edux.ml.nn.network;

+import static org.junit.jupiter.api.Assertions.assertEquals;
+
 import de.edux.functions.activation.ActivationFunction;
 import de.edux.functions.initialization.Initialization;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;

-import static org.junit.jupiter.api.Assertions.assertEquals;
-
 public class NeuronTest {

-    private Neuron neuron;
-    private final int inputSize = 3;
-    private final ActivationFunction dummyActivationFunction = ActivationFunction.SOFTMAX;
+  private final int inputSize = 3;
+  private final ActivationFunction dummyActivationFunction = ActivationFunction.SOFTMAX;
+  private Neuron neuron;
+
+  @BeforeEach
+  public void setUp() {
+    neuron = new Neuron(inputSize, dummyActivationFunction, Initialization.XAVIER);
+  }

-    @BeforeEach
-    public void setUp() {
-        neuron = new Neuron(inputSize, dummyActivationFunction, Initialization.XAVIER);
+  @Test
+  public void testAdjustWeights() {
+    double[] initialWeights = new double[inputSize];
+    for (int i = 0; i < inputSize; i++) {
+      initialWeights[i] = neuron.getWeight(i);
     }

-    @Test
-    public void testAdjustWeights() {
-        double[] initialWeights = new double[inputSize];
-        for (int i = 0; i < inputSize; i++) {
-            initialWeights[i] = neuron.getWeight(i);
-        }
-
-        double[] input = {1.0, 2.0, 3.0};
-        double error = 0.5;
-        double learningRate = 0.1;
-        neuron.adjustWeights(input, error, learningRate);
-
-        for (int i = 0; i < inputSize; i++) {
-            double expectedWeight = initialWeights[i] + learningRate * input[i] * error;
-            assertEquals(expectedWeight, neuron.getWeight(i));
-        }
+    double[] input = {1.0, 2.0, 3.0};
+    double error = 0.5;
+    double learningRate = 0.1;
+    neuron.adjustWeights(input, error, learningRate);
+
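+    // Perceptron-style update asserted below: w_i <- w_i + learningRate * input_i * error.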
+    for (int i = 0; i < inputSize; i++) {
+      double expectedWeight = initialWeights[i] + learningRate * input[i] * error;
+      assertEquals(expectedWeight, neuron.getWeight(i));
     }
+  }

-    @Test
-    public void testAdjustBias() {
-        double initialBias = neuron.getBias();
+  @Test
+  public void testAdjustBias() {
+    double initialBias = neuron.getBias();

-        double error = 0.5;
-        double learningRate = 0.1;
-        neuron.adjustBias(error, learningRate);
+    double error = 0.5;
+    double learningRate = 0.1;
+    neuron.adjustBias(error, learningRate);

-        double expectedBias = initialBias + learningRate * error;
-        assertEquals(expectedBias, neuron.getBias());
-    }
+    double expectedBias = initialBias + learningRate * error;
+    assertEquals(expectedBias, neuron.getBias());
+  }
 }
diff --git a/lib/src/test/java/de/edux/util/math/MathMatrixTest.java b/lib/src/test/java/de/edux/util/math/MathMatrixTest.java
index d24fe20..7387f0b 100644
--- a/lib/src/test/java/de/edux/util/math/MathMatrixTest.java
+++ b/lib/src/test/java/de/edux/util/math/MathMatrixTest.java
@@ -1,8 +1,6 @@
 package de.edux.util.math;

-import org.junit.jupiter.api.Test;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import static org.junit.jupiter.api.Assertions.assertEquals;

 import java.util.ArrayList;
 import java.util.List;
@@ -10,106 +8,108 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;

-
-import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 class MathMatrixTest {

-    private static final long someMaximumValue = 1_000_000_000; // Example value
-    private static final Logger LOG = LoggerFactory.getLogger(MathMatrixTest.class);
-
-    @Test
-    void multiplyMatrices() throws IncompatibleDimensionsException {
-        long startTime = System.currentTimeMillis();
-        int size = 500;
-
-        double[][] matrixA = generateMatrix(size);
-        double[][] matrixB = generateMatrix(size);
-
-        ConcurrentMatrixMultiplication matrixMultiplier = new MathMatrix();
-        double[][] resultMatrix = matrixMultiplier.multiplyMatrices(matrixA, matrixB);
-
-        assertEquals(size, resultMatrix.length);
-        assertEquals(size, resultMatrix[0].length);
-
-        long endTime = System.currentTimeMillis();
-        long timeElapsed = endTime - startTime;
-        LOG.info("Time elapsed: " + timeElapsed / 1000 + " seconds");
-    }
-
-    @Test
-    void multiplyMatricesSmall() throws IncompatibleDimensionsException {
-        double[][] matrixA = {
-                {1, 2},
-                {3, 4}
-        };
-
-        double[][] matrixB = {
-                {2, 0},
-                {1, 3}
-        };
-
-        ConcurrentMatrixMultiplication matrixMultiplier = new MathMatrix();
-        double[][] resultMatrix = matrixMultiplier.multiplyMatrices(matrixA, matrixB);
-
-        double[][] expectedMatrix = {
-                {4, 6},
-                {10, 12}
-        };
-
-        assertArrayEquals(expectedMatrix, resultMatrix);
-    }
-
-    static void assertArrayEquals(double[][] expected, double[][] actual) {
-        assertEquals(expected.length, actual.length);
-
-        for (int i = 0; i < expected.length; i++) {
-            assertArrayEquals(expected[i], actual[i]);
-        }
-    }
-
-    static void assertArrayEquals(double[] expected, double[] actual) {
-        assertEquals(expected.length, actual.length);
-
-        for (int i = 0; i < expected.length; i++) {
-            assertEquals(expected[i], actual[i]);
-        }
-    }
-
-    double[][] generateMatrix(int size) {
-        double[][] matrix = new double[size][size];
+  private static final long someMaximumValue = 1_000_000_000; // Example value
+  private static final Logger LOG = LoggerFactory.getLogger(MathMatrixTest.class);
+
+  static void assertArrayEquals(double[][] expected, double[][] actual) {
+    assertEquals(expected.length, actual.length);
+
+    for (int i = 0; i < expected.length; i++) {
+      assertArrayEquals(expected[i], actual[i]);
+    }
+  }
+
+  static void assertArrayEquals(double[] expected, double[] actual) {
+    assertEquals(expected.length, actual.length);
+
+    for (int i = 0; i < expected.length; i++) {
+      assertEquals(expected[i], actual[i]);
+    }
+  }
+
+  @Test
+  void multiplyMatrices() throws IncompatibleDimensionsException {
+    long startTime = System.currentTimeMillis();
+    int size = 500;
+
+    double[][] matrixA = generateMatrix(size);
+    double[][] matrixB = generateMatrix(size);
+
+    ConcurrentMatrixMultiplication matrixMultiplier = new MathMatrix();
+    double[][] resultMatrix = matrixMultiplier.multiplyMatrices(matrixA, matrixB);
+
+    assertEquals(size, resultMatrix.length);
+    assertEquals(size, resultMatrix[0].length);
+
+    long endTime = System.currentTimeMillis();
+    long timeElapsed = endTime - startTime;
+    LOG.info("Time elapsed: " + timeElapsed / 1000 + " seconds");
+  }
+
+  @Test
+  void multiplyMatricesSmall() throws IncompatibleDimensionsException {
+    double[][] matrixA = {
+      {1, 2},
+      {3, 4}
+    };
+
+    double[][] matrixB = {
+      {2, 0},
+      {1, 3}
+    };
+
+    ConcurrentMatrixMultiplication matrixMultiplier = new MathMatrix();
+    double[][] resultMatrix = matrixMultiplier.multiplyMatrices(matrixA, matrixB);
+
+    double[][] expectedMatrix = {
+      {4, 6},
+      {10, 12}
+    };
+
+    assertArrayEquals(expectedMatrix, resultMatrix);
+  }
+
+  double[][] generateMatrix(int size) {
+    double[][] matrix = new double[size][size];
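+    // Rows are partitioned across at most MAX_THREADS tasks on a virtual-thread-per-task
+    // executor; each task fills its own [startRow, endRow) slice, so no two tasks write
+    // the same row.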
+    final int MAX_THREADS = 32;
+
+    ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
+    List<Future<Void>> futures = new ArrayList<>();
+
+    try {
+      int rowsPerThread = Math.max(size / MAX_THREADS, 1);
+
+      for (int i = 0; i < MAX_THREADS && i * rowsPerThread < size; i++) {
+        final int startRow = i * rowsPerThread;
+        final int endRow = Math.min((i + 1) * rowsPerThread, size);
+
+        futures.add(
+            executor.submit(
+                () -> {
+                  for (int row = startRow; row < endRow; row++) {
+                    for (int col = 0; col < size; col++) {
+                      matrix[row][col] = Math.random() * 10; // Random values between 0 and 10
+                    }
+                  }
+                  return null;
+                }));
-        final int MAX_THREADS = 32;
-
-        ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
-        List<Future<Void>> futures = new ArrayList<>();
-
-        try {
-            int rowsPerThread = Math.max(size / MAX_THREADS, 1);
-
-            for (int i = 0; i < MAX_THREADS && i * rowsPerThread < size; i++) {
-                final int startRow = i * rowsPerThread;
-                final int endRow = Math.min((i + 1) * rowsPerThread, size);
-
-                futures.add(executor.submit(() -> {
-                    for (int row = startRow; row < endRow; row++) {
-                        for (int col = 0; col < size; col++) {
-                            matrix[row][col] = Math.random() * 10; // Random values between 0 and 10
-                        }
-                    }
-                    return null;
-                }));
-            }
-
-            for (Future<Void> future : futures) {
-                future.get();
-            }
-        } catch (InterruptedException | ExecutionException e) {
-            e.printStackTrace();
-        } finally {
-            executor.shutdown();
-        }
-
-        LOG.info("Generated matrix with size: " + size);
-        return matrix;
+      }
+
+      for (Future<Void> future : futures) {
+        future.get();
+      }
+    } catch (InterruptedException | ExecutionException e) {
+      e.printStackTrace();
+    } finally {
+      executor.shutdown();
     }
-}
\ No newline at end of file
+
+    LOG.info("Generated matrix with size: " + size);
+    return matrix;
+  }
+}