diff --git a/example/src/main/java/de/example/decisiontree/DecisionTreeExampleOnPenguinsDatasetNoHead.java b/example/src/main/java/de/example/decisiontree/DecisionTreeExampleOnPenguinsDatasetNoHead.java index 678e54f..c004e85 100644 --- a/example/src/main/java/de/example/decisiontree/DecisionTreeExampleOnPenguinsDatasetNoHead.java +++ b/example/src/main/java/de/example/decisiontree/DecisionTreeExampleOnPenguinsDatasetNoHead.java @@ -18,7 +18,7 @@ public class DecisionTreeExampleOnPenguinsDatasetNoHead { + "seaborn-penguins" + File.separator + "penguins-no-head.csv"); - private static final boolean SKIP_HEAD = true; + private static final boolean SKIP_HEAD = false; private static final ImputationStrategy averageImputation = ImputationStrategy.AVERAGE; private static final ImputationStrategy modeImputation = ImputationStrategy.MODE; diff --git a/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixArithmetic.java b/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixArithmetic.java index d65e296..b9f3486 100644 --- a/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixArithmetic.java +++ b/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixArithmetic.java @@ -18,11 +18,30 @@ public MatrixArithmetic() { @Override public double[][] multiply(double[][] matrixA, double[][] matrixB) { + + if (matrixA == null || matrixB == null) { + throw new IllegalArgumentException("Matrices must not be null."); + } + if (matrixA.length == 0 || matrixB.length == 0) { + throw new IllegalArgumentException("Matrices must not be empty."); + } + if (matrixA[0].length != matrixB.length) { + throw new IllegalArgumentException("Matrix A columns must match Matrix B rows."); + } + return matrixProduct.multiply(matrixA, matrixB); } @Override public double[] multiply(double[][] matrix, double[] vector) { + + if (matrix.length == 0 || matrix[0].length == 0 || vector.length == 0) { + throw new IllegalArgumentException("Matrix and vector must not be empty."); + } + if (matrix[0].length != vector.length) { + throw new IllegalArgumentException("Matrix columns and vector size do not match."); + } + return matrixVectorProduct.multiply(matrix, vector); } } diff --git a/lib/src/main/java/de/edux/core/math/matrix/parallel/operations/MatrixProduct.java b/lib/src/main/java/de/edux/core/math/matrix/parallel/operations/MatrixProduct.java index 14569ba..b86d528 100644 --- a/lib/src/main/java/de/edux/core/math/matrix/parallel/operations/MatrixProduct.java +++ b/lib/src/main/java/de/edux/core/math/matrix/parallel/operations/MatrixProduct.java @@ -1,30 +1,17 @@ package de.edux.core.math.matrix.parallel.operations; import de.edux.core.math.IMatrixProduct; - import java.util.stream.IntStream; public class MatrixProduct implements IMatrixProduct { - private void checkSizeForMultiplication(double[][] matrixA, double[][] matrixB) { - int m = matrixB.length; - int n = matrixA[0].length; - if (m != n) { - throw new RuntimeException( - "\"The number of columns in the first matrix must be equal to the number of rows in the second matrix.\""); - } - } - @Override public double[][] multiply(double[][] matrixA, double[][] matrixB) { - checkSizeForMultiplication(matrixA, matrixB); - int aRows = matrixA.length; int aColumns = matrixA[0].length; int bColumns = matrixB[0].length; double[][] result = new double[aRows][bColumns]; - IntStream.range(0, aRows) .parallel() .forEach( diff --git a/lib/src/main/java/de/edux/core/math/matrix/parallel/operations/MatrixVectorProduct.java b/lib/src/main/java/de/edux/core/math/matrix/parallel/operations/MatrixVectorProduct.java index 9b4628b..933b888 100644 --- a/lib/src/main/java/de/edux/core/math/matrix/parallel/operations/MatrixVectorProduct.java +++ b/lib/src/main/java/de/edux/core/math/matrix/parallel/operations/MatrixVectorProduct.java @@ -1,10 +1,25 @@ package de.edux.core.math.matrix.parallel.operations; import de.edux.core.math.IMatrixVectorProduct; +import java.util.stream.IntStream; public class MatrixVectorProduct implements IMatrixVectorProduct { + @Override public double[] multiply(double[][] matrix, double[] vector) { - return new double[0]; + int matrixRows = matrix.length; + int matrixColumns = matrix[0].length; + + double[] result = new double[matrixRows]; + IntStream.range(0, matrixRows) + .parallel() + .forEach( + row -> { + for (int col = 0; col < matrixColumns; col++) { + result[row] += matrix[row][col] * vector[col]; + } + }); + + return result; } } diff --git a/lib/src/test/java/de/edux/core/math/matrix/parallel/MatrixArithmeticTest.java b/lib/src/test/java/de/edux/core/math/matrix/parallel/MatrixArithmeticTest.java index 573e341..b0dc14f 100644 --- a/lib/src/test/java/de/edux/core/math/matrix/parallel/MatrixArithmeticTest.java +++ b/lib/src/test/java/de/edux/core/math/matrix/parallel/MatrixArithmeticTest.java @@ -2,11 +2,18 @@ import static org.junit.jupiter.api.Assertions.*; -import de.edux.core.math.IMatrixArithmetic; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; class MatrixArithmeticTest { + private static MatrixArithmetic matrixArithmetic; + + @BeforeAll + static void setUp() { + matrixArithmetic = new MatrixArithmetic(); + } + @Test public void testLargeMatrixMultiplication() { int matrixSize = 200; @@ -19,8 +26,8 @@ public void testLargeMatrixMultiplication() { matrixB[i][j] = 1; } } - IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic(); - double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); + + double[][] result = matrixArithmetic.multiply(matrixA, matrixB); for (int i = 0; i < matrixSize; i++) { for (int j = 0; j < matrixSize; j++) { assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); @@ -45,15 +52,13 @@ public void shouldThrowRuntimeExceptionForMatricesWithIncompatibleSizes() { {-1, 6}, }; - IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic(); - assertAll( () -> assertThrows( - RuntimeException.class, () -> matrixParallelArithmetic.multiply(matrixA, matrixB)), + RuntimeException.class, () -> matrixArithmetic.multiply(matrixA, matrixB)), () -> assertThrows( - RuntimeException.class, () -> matrixParallelArithmetic.multiply(matrixC, matrixA))); + RuntimeException.class, () -> matrixArithmetic.multiply(matrixC, matrixA))); } @Test @@ -89,9 +94,7 @@ public void shouldMultiply8x8MatricesCorrectly() { {22, 18, 14, 10, 22, 18, 14, 10} }; - IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic(); - double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); - + double[][] result = matrixArithmetic.multiply(matrixA, matrixB); assertArrayEquals( expected, result, "The 8x8 matrix multiplication did not yield the correct result."); } @@ -111,9 +114,7 @@ public void shouldMultiplyWithZeroMatrix() { {0, 0} }; - IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic(); - double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); - + double[][] result = matrixArithmetic.multiply(matrixA, matrixB); assertArrayEquals(expected, result); } @@ -128,9 +129,7 @@ public void shouldMultiplyWithIdentityMatrix() { {7, 8} }; - IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic(); - double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); - + double[][] result = matrixArithmetic.multiply(matrixA, matrixB); assertArrayEquals(matrixB, result); } @@ -149,9 +148,56 @@ public void shouldMultiplySmallMatricesCorrectly() { {10, 8} }; - IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic(); - double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); - + double[][] result = matrixArithmetic.multiply(matrixA, matrixB); assertArrayEquals(expected, result); } + + @Test + void shouldSolveMatrixVectorProduct() { + int matrixSize = 2048; + double[][] matrixA = new double[matrixSize][matrixSize]; + double[] vector = new double[matrixSize]; + + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + matrixA[i][j] = 1; + vector[i] = 1; + } + } + + double[] resultVector = matrixArithmetic.multiply(matrixA, vector); + for (int i = 0; i < matrixSize; i++) { + assertEquals(matrixSize, resultVector[i], "Result on [" + i + "][" + i + "] not correct."); + } + } + + @Test + void shouldHandleEmptyMatrix() { + double[][] matrix = new double[0][0]; + double[] vector = new double[0]; + assertThrows(IllegalArgumentException.class, () -> matrixArithmetic.multiply(matrix, vector)); + } + + @Test + void shouldHandleMismatchedSizes() { + double[][] matrix = {{1, 2, 3}, {4, 5, 6}}; + double[] vector = {1, 2}; + assertThrows(IllegalArgumentException.class, () -> matrixArithmetic.multiply(matrix, vector)); + } + + @Test + void shouldHandleNullMatrix() { + double[] vector = {1, 2, 3}; + assertThrows(NullPointerException.class, () -> matrixArithmetic.multiply(null, vector)); + } + + @Test + void shouldMultiplyVectorWithIdentityMatrix() { + double[][] matrix = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}; + double[] vector = {1, 2, 3}; + double[] expected = {1, 2, 3}; + double[] result = matrixArithmetic.multiply(matrix, vector); + assertArrayEquals( + expected, result, "Multiplying with identity matrix should return the original vector."); + } }