diff --git a/lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java b/lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java index ff809f6..a4fe389 100644 --- a/lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java +++ b/lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java @@ -1,44 +1,39 @@ package de.edux.core.math.matrix.classic; import de.edux.core.math.IMatrixArithmetic; -import java.util.Arrays; +import java.util.stream.IntStream; public class MatrixArithmetic implements IMatrixArithmetic { - @Override - public double[][] multiply(double[][] matrixA, double[][] matrixB) { - int columnsInFirstMatrix = matrixA[0].length; - int rowsInSecondMatrix = matrixB.length; - if (columnsInFirstMatrix != rowsInSecondMatrix) { + + private void checkSizeForMultiplication(double[][] matrixA, double[][] matrixB) { + int m = matrixB.length; + int n = matrixA[0].length; + if (m != n) { throw new RuntimeException( "\"The number of columns in the first matrix must be equal to the number of rows in the second matrix.\""); } - - int rowsInMatrixC = matrixA.length; - int columnsInMatrixC = matrixB.length; - double[][] matrixC = new double[rowsInMatrixC][columnsInMatrixC]; - for (int row = 0; row < rowsInMatrixC; row++) { - for (int column = 0; column < columnsInMatrixC; column++) { - matrixC[row][column] = - vectorDotProduct(getRowVector(matrixA, row), getColumnVector(matrixB, column)); - } - } - return matrixC; } - public double[] getRowVector(double[][] matrix, int rowIndex) { - return matrix[rowIndex]; - } + @Override + public double[][] multiply(double[][] matrixA, double[][] matrixB) { + checkSizeForMultiplication(matrixA, matrixB); - public double[] getColumnVector(double[][] matrix, int columnIndex) { - return Arrays.stream(matrix).mapToDouble((row) -> row[columnIndex]).toArray(); - } + int aRows = matrixA.length; + int aColumns = matrixA[0].length; + int bColumns = matrixB[0].length; - public double vectorDotProduct(double[] rowVector, double[] columVector) { - int vectorSize = rowVector.length; - double vectorDotProduct = 0; - for (int i = 0; i < vectorSize; i++) { - vectorDotProduct += rowVector[i] * columVector[i]; - } - return vectorDotProduct; + double[][] result = new double[aRows][bColumns]; + + IntStream.range(0, aRows) + .forEach( + row -> { + for (int col = 0; col < bColumns; col++) { + for (int i = 0; i < aColumns; i++) { + result[row][col] += matrixA[row][i] * matrixB[i][col]; + } + } + }); + + return result; } } diff --git a/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java b/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java index 2986ff3..c312fb8 100644 --- a/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java +++ b/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java @@ -1,10 +1,40 @@ package de.edux.core.math.matrix.parallel; import de.edux.core.math.IMatrixArithmetic; +import java.util.stream.IntStream; public class MatrixParallelArithmetic implements IMatrixArithmetic { + + private void checkSizeForMultiplication(double[][] matrixA, double[][] matrixB) { + int m = matrixB.length; + int n = matrixA[0].length; + if (m != n) { + throw new RuntimeException( + "\"The number of columns in the first matrix must be equal to the number of rows in the second matrix.\""); + } + } + @Override public double[][] multiply(double[][] matrixA, double[][] matrixB) { - return new double[0][]; + checkSizeForMultiplication(matrixA, matrixB); + + int aRows = matrixA.length; + int aColumns = matrixA[0].length; + int bColumns = matrixB[0].length; + + double[][] result = new double[aRows][bColumns]; + + IntStream.range(0, aRows) + .parallel() + .forEach( + row -> { + for (int col = 0; col < bColumns; col++) { + for (int i = 0; i < aColumns; i++) { + result[row][col] += matrixA[row][i] * matrixB[i][col]; + } + } + }); + + return result; } } diff --git a/lib/src/test/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmeticTest.java b/lib/src/test/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmeticTest.java new file mode 100644 index 0000000..c1368b3 --- /dev/null +++ b/lib/src/test/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmeticTest.java @@ -0,0 +1,157 @@ +package de.edux.core.math.matrix.parallel; + +import static org.junit.jupiter.api.Assertions.*; + +import de.edux.core.math.IMatrixArithmetic; +import org.junit.jupiter.api.Test; + +class MatrixParallelArithmeticTest { + + @Test + public void testLargeMatrixMultiplication() { + int matrixSize = 200; + double[][] matrixA = new double[matrixSize][matrixSize]; + double[][] matrixB = new double[matrixSize][matrixSize]; + + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + matrixA[i][j] = 1; + matrixB[i][j] = 1; + } + } + IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic(); + double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); + } + } + } + + @Test + public void shouldThrowRuntimeExceptionForMatricesWithIncompatibleSizes() { + double[][] matrixA = { + {3, -5, 1}, + {-2, 0, 4}, + {-1, 6, 5}, + }; + double[][] matrixB = { + {7, 2, 4}, + {0, 1, -5}, + }; + double[][] matrixC = { + {3, -5}, + {-2, 0}, + {-1, 6}, + }; + + IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic(); + + assertAll( + () -> + assertThrows( + RuntimeException.class, () -> matrixParallelArithmetic.multiply(matrixA, matrixB)), + () -> + assertThrows( + RuntimeException.class, () -> matrixParallelArithmetic.multiply(matrixC, matrixA))); + } + + @Test + public void shouldMultiply8x8MatricesCorrectly() { + double[][] matrixA = { + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 4, 3, 2, 1}, + {2, 3, 4, 5, 6, 7, 8, 9}, + {9, 8, 7, 6, 5, 4, 3, 2}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {1, 3, 5, 7, 9, 11, 13, 15}, + {15, 13, 11, 9, 7, 5, 3, 1} + }; + double[][] matrixB = { + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1}, + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1} + }; + double[][] expected = { + {6, 8, 10, 12, 6, 8, 10, 12}, + {12, 10, 8, 6, 12, 10, 8, 6}, + {8, 10, 12, 14, 8, 10, 12, 14}, + {14, 12, 10, 8, 14, 12, 10, 8}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {4, 4, 4, 4, 4, 4, 4, 4}, + {10, 14, 18, 22, 10, 14, 18, 22}, + {22, 18, 14, 10, 22, 18, 14, 10} + }; + + IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic(); + double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); + + assertArrayEquals( + expected, result, "The 8x8 matrix multiplication did not yield the correct result."); + } + + @Test + public void shouldMultiplyWithZeroMatrix() { + double[][] matrixA = { + {0, 0}, + {0, 0} + }; + double[][] matrixB = { + {1, 2}, + {3, 4} + }; + double[][] expected = { + {0, 0}, + {0, 0} + }; + + IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic(); + double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); + } + + @Test + public void shouldMultiplyWithIdentityMatrix() { + double[][] matrixA = { + {1, 0}, + {0, 1} + }; + double[][] matrixB = { + {5, 6}, + {7, 8} + }; + + IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic(); + double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); + + assertArrayEquals(matrixB, result); + } + + @Test + public void shouldMultiplySmallMatricesCorrectly() { + double[][] matrixA = { + {1, 2}, + {3, 4} + }; + double[][] matrixB = { + {2, 0}, + {1, 2} + }; + double[][] expected = { + {4, 4}, + {10, 8} + }; + + IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic(); + double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); + } +}