From e92014911df8c500f8ecfb80466126a7d530a7bd Mon Sep 17 00:00:00 2001 From: Samuel Abramov Date: Mon, 6 Nov 2023 02:04:15 +0100 Subject: [PATCH] feat():Implement Matrix Multiplication Using Strassen Algorithm #75 --- gradle.properties | 1 + .../strassen/StrassenMultiplicatIMatrix.java | 10 -- .../matrix/strassen/StrassenParallel.java | 2 +- .../strassen/StrassenParallelInplace.java | 19 --- .../strassen/StrassenParallelInplaceTest.java | 33 ----- .../matrix/strassen/StrassenParallelTest.java | 101 ++++++++++++++- .../math/matrix/strassen/StrassenTest.java | 122 +++++++++++++++--- 7 files changed, 209 insertions(+), 79 deletions(-) delete mode 100644 lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java delete mode 100644 lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java delete mode 100644 lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java diff --git a/gradle.properties b/gradle.properties index 4f996f1..950d4eb 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,2 +1,3 @@ org.gradle.parallel=true org.gradle.caching=true +org.gradle.jvmargs=-Xmx32g -Xms8g diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java deleted file mode 100644 index 9166756..0000000 --- a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java +++ /dev/null @@ -1,10 +0,0 @@ -package de.edux.core.math.matrix.strassen; - -import de.edux.core.math.IMatrixArithmetic; - -public class StrassenMultiplicatIMatrix implements IMatrixArithmetic { - @Override - public double[][] multiply(double[][] matrixA, double[][] matrixB) { - return new double[0][]; - } -} diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java index f3b5490..2a40dfa 100644 --- a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java +++ b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java @@ -103,7 +103,7 @@ private class StrassenTask extends RecursiveTask { protected double[][] compute() { int n = A.length; - if (n <= 64) { + if (n <= 512) { return conventionalMultiply(A, B); } else { int newSize = n / 2; diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java deleted file mode 100644 index db6403b..0000000 --- a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java +++ /dev/null @@ -1,19 +0,0 @@ -package de.edux.core.math.matrix.strassen; - -import de.edux.core.math.IMatrixArithmetic; -import java.util.concurrent.ForkJoinPool; - -public class StrassenParallelInplace implements IMatrixArithmetic { - private ForkJoinPool forkJoinPool = new ForkJoinPool(); - - @Override - public double[][] multiply(double[][] matrixA, double[][] matrixB) {} - - private int nextPowerOfTwo(int number) { - int power = 1; - while (power < number) { - power *= 2; - } - return power; - } -} diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java deleted file mode 100644 index bf42a2e..0000000 --- a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java +++ /dev/null @@ -1,33 +0,0 @@ -package de.edux.core.math.matrix.strassen; - -import static org.junit.jupiter.api.Assertions.*; - -import de.edux.core.math.IMatrixArithmetic; -import org.junit.jupiter.api.Test; - -class StrassenParallelInplaceTest { - - @Test - public void shouldMultiplyWithParallelStrassen() { - int matrixSize = 2000; - double[][] matrixA = new double[matrixSize][matrixSize]; - double[][] matrixB = new double[matrixSize][matrixSize]; - - for (int i = 0; i < matrixSize; i++) { - for (int j = 0; j < matrixSize; j++) { - matrixA[i][j] = 1; - matrixB[i][j] = 1; - } - } - - IMatrixArithmetic strassenParallel = new StrassenParallelInplace(); - - double[][] result = strassenParallel.multiply(matrixA, matrixB); - - for (int i = 0; i < matrixSize; i++) { - for (int j = 0; j < matrixSize; j++) { - assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); - } - } - } -} diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java index a14a83d..5fe6f77 100644 --- a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java +++ b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java @@ -9,7 +9,7 @@ class StrassenParallelTest { @Test public void shouldMultiplyWithParallelStrassen() { - int matrixSize = 2000; + int matrixSize = 2048; double[][] matrixA = new double[matrixSize][matrixSize]; double[][] matrixB = new double[matrixSize][matrixSize]; @@ -30,4 +30,103 @@ public void shouldMultiplyWithParallelStrassen() { } } } + + @Test + public void shouldMultiply8x8MatricesCorrectly() { + double[][] matrixA = { + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 4, 3, 2, 1}, + {2, 3, 4, 5, 6, 7, 8, 9}, + {9, 8, 7, 6, 5, 4, 3, 2}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {1, 3, 5, 7, 9, 11, 13, 15}, + {15, 13, 11, 9, 7, 5, 3, 1} + }; + double[][] matrixB = { + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1}, + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1} + }; + double[][] expected = { + {6, 8, 10, 12, 6, 8, 10, 12}, + {12, 10, 8, 6, 12, 10, 8, 6}, + {8, 10, 12, 14, 8, 10, 12, 14}, + {14, 12, 10, 8, 14, 12, 10, 8}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {4, 4, 4, 4, 4, 4, 4, 4}, + {10, 14, 18, 22, 10, 14, 18, 22}, + {22, 18, 14, 10, 22, 18, 14, 10} + }; + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + assertArrayEquals( + expected, result, "The 8x8 matrix multiplication did not yield the correct result."); + } + + @Test + public void shouldMultiplyWithZeroMatrix() { + double[][] matrixA = { + {0, 0}, + {0, 0} + }; + double[][] matrixB = { + {1, 2}, + {3, 4} + }; + double[][] expected = { + {0, 0}, + {0, 0} + }; + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); + } + + @Test + public void shouldMultiplyWithIdentityMatrix() { + double[][] matrixA = { + {1, 0}, + {0, 1} + }; + double[][] matrixB = { + {5, 6}, + {7, 8} + }; + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + assertArrayEquals(matrixB, result); + } + + @Test + public void shouldMultiplySmallMatricesCorrectly() { + double[][] matrixA = { + {1, 2}, + {3, 4} + }; + double[][] matrixB = { + {2, 0}, + {1, 2} + }; + double[][] expected = { + {4, 4}, + {10, 8} + }; + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); + } } diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java index 9c73917..1f8c2cd 100644 --- a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java +++ b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java @@ -8,12 +8,13 @@ class StrassenTest { @Test - public void testLargeMatrixMultiplication() { - double[][] matrixA = new double[1000][1000]; - double[][] matrixB = new double[1000][1000]; + public void shouldMultiplyWithParallelStrassen() { + int matrixSize = 128; + double[][] matrixA = new double[matrixSize][matrixSize]; + double[][] matrixB = new double[matrixSize][matrixSize]; - for (int i = 0; i < 1000; i++) { - for (int j = 0; j < 1000; j++) { + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { matrixA[i][j] = 1; matrixB[i][j] = 1; } @@ -22,19 +23,110 @@ public void testLargeMatrixMultiplication() { IMatrixArithmetic strassen = new Strassen(); double[][] result = strassen.multiply(matrixA, matrixB); - for (int i = 0; i < 1000; i++) { - for (int j = 0; j < 1000; j++) { - assertEquals(1000, result[i][j], "Result on [" + i + "][" + j + "] not correct."); + + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); } } } - private void printMatrix(double[][] result) { - for (int i = 0; i < result.length; i++) { - for (int j = 0; j < result.length; j++) { - System.out.print(result[i][j] + " "); - } - System.out.println(); - } + @Test + public void shouldMultiply8x8MatricesCorrectly() { + double[][] matrixA = { + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 4, 3, 2, 1}, + {2, 3, 4, 5, 6, 7, 8, 9}, + {9, 8, 7, 6, 5, 4, 3, 2}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {1, 3, 5, 7, 9, 11, 13, 15}, + {15, 13, 11, 9, 7, 5, 3, 1} + }; + double[][] matrixB = { + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1}, + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1} + }; + double[][] expected = { + {6, 8, 10, 12, 6, 8, 10, 12}, + {12, 10, 8, 6, 12, 10, 8, 6}, + {8, 10, 12, 14, 8, 10, 12, 14}, + {14, 12, 10, 8, 14, 12, 10, 8}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {4, 4, 4, 4, 4, 4, 4, 4}, + {10, 14, 18, 22, 10, 14, 18, 22}, + {22, 18, 14, 10, 22, 18, 14, 10} + }; + + IMatrixArithmetic strassen = new Strassen(); + double[][] result = strassen.multiply(matrixA, matrixB); + + assertArrayEquals( + expected, result, "The 8x8 matrix multiplication did not yield the correct result."); + } + + @Test + public void shouldMultiplyWithZeroMatrix() { + double[][] matrixA = { + {0, 0}, + {0, 0} + }; + double[][] matrixB = { + {1, 2}, + {3, 4} + }; + double[][] expected = { + {0, 0}, + {0, 0} + }; + + IMatrixArithmetic strassen = new Strassen(); + double[][] result = strassen.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); + } + + @Test + public void shouldMultiplyWithIdentityMatrix() { + double[][] matrixA = { + {1, 0}, + {0, 1} + }; + double[][] matrixB = { + {5, 6}, + {7, 8} + }; + + IMatrixArithmetic strassen = new Strassen(); + double[][] result = strassen.multiply(matrixA, matrixB); + + assertArrayEquals(matrixB, result); + } + + @Test + public void shouldMultiplySmallMatricesCorrectly() { + double[][] matrixA = { + {1, 2}, + {3, 4} + }; + double[][] matrixB = { + {2, 0}, + {1, 2} + }; + double[][] expected = { + {4, 4}, + {10, 8} + }; + + IMatrixArithmetic strassen = new Strassen(); + double[][] result = strassen.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); } }