From 446c02de4ad9ebe0cbcb24e25db9c0e23dfb95bd Mon Sep 17 00:00:00 2001 From: Samuel Abramov Date: Fri, 3 Nov 2023 13:04:52 +0100 Subject: [PATCH 1/3] feat(#75): structure --- .../de/edux/core/math/IMatrixArithmetic.java | 5 + .../math/matrix/classic/MatrixArithmetic.java | 10 ++ .../parallel/MatrixParallelArithmetic.java | 10 ++ .../strassen/StrassenMultiplicatIMatrix.java | 10 ++ lib/src/main/java/de/edux/math/Entity.java | 12 -- lib/src/main/java/de/edux/math/MathUtil.java | 15 -- lib/src/main/java/de/edux/math/Matrix.java | 154 ------------------ .../main/java/de/edux/math/Validations.java | 16 -- lib/src/main/java/de/edux/math/Vector.java | 148 ----------------- .../java/de/edux/math/entity/MatrixTest.java | 75 --------- .../java/de/edux/math/entity/VectorTest.java | 45 ----- 11 files changed, 35 insertions(+), 465 deletions(-) create mode 100644 lib/src/main/java/de/edux/core/math/IMatrixArithmetic.java create mode 100644 lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java create mode 100644 lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java create mode 100644 lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java delete mode 100644 lib/src/main/java/de/edux/math/Entity.java delete mode 100644 lib/src/main/java/de/edux/math/MathUtil.java delete mode 100644 lib/src/main/java/de/edux/math/Matrix.java delete mode 100644 lib/src/main/java/de/edux/math/Validations.java delete mode 100644 lib/src/main/java/de/edux/math/Vector.java delete mode 100644 lib/src/test/java/de/edux/math/entity/MatrixTest.java delete mode 100644 lib/src/test/java/de/edux/math/entity/VectorTest.java diff --git a/lib/src/main/java/de/edux/core/math/IMatrixArithmetic.java b/lib/src/main/java/de/edux/core/math/IMatrixArithmetic.java new file mode 100644 index 0000000..14ce2a0 --- /dev/null +++ b/lib/src/main/java/de/edux/core/math/IMatrixArithmetic.java @@ -0,0 +1,5 @@ +package de.edux.core.math; + +public interface IMatrixArithmetic { + double[][] multiply(double[][] matrixA, double[][] matrixB); +} diff --git a/lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java b/lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java new file mode 100644 index 0000000..9bbf8ad --- /dev/null +++ b/lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java @@ -0,0 +1,10 @@ +package de.edux.core.math.matrix.classic; + +import de.edux.core.math.IMatrixArithmetic; + +public class MatrixArithmetic implements IMatrixArithmetic { + @Override + public double[][] multiply(double[][] matrixA, double[][] matrixB) { + return new double[0][]; + } +} diff --git a/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java b/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java new file mode 100644 index 0000000..2986ff3 --- /dev/null +++ b/lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java @@ -0,0 +1,10 @@ +package de.edux.core.math.matrix.parallel; + +import de.edux.core.math.IMatrixArithmetic; + +public class MatrixParallelArithmetic implements IMatrixArithmetic { + @Override + public double[][] multiply(double[][] matrixA, double[][] matrixB) { + return new double[0][]; + } +} diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java new file mode 100644 index 0000000..9166756 --- /dev/null +++ b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java @@ -0,0 +1,10 @@ +package de.edux.core.math.matrix.strassen; + +import de.edux.core.math.IMatrixArithmetic; + +public class StrassenMultiplicatIMatrix implements IMatrixArithmetic { + @Override + public double[][] multiply(double[][] matrixA, double[][] matrixB) { + return new double[0][]; + } +} diff --git a/lib/src/main/java/de/edux/math/Entity.java b/lib/src/main/java/de/edux/math/Entity.java deleted file mode 100644 index 04d3794..0000000 --- a/lib/src/main/java/de/edux/math/Entity.java +++ /dev/null @@ -1,12 +0,0 @@ -package de.edux.math; - -public interface Entity { - - T add(T another); - - T subtract(T another); - - T multiply(T another); - - T scalarMultiply(double n); -} diff --git a/lib/src/main/java/de/edux/math/MathUtil.java b/lib/src/main/java/de/edux/math/MathUtil.java deleted file mode 100644 index d7963d4..0000000 --- a/lib/src/main/java/de/edux/math/MathUtil.java +++ /dev/null @@ -1,15 +0,0 @@ -package de.edux.math; - -public final class MathUtil { - - public static double[] unwrap(double[][] matrix) { - double[] result = new double[matrix.length * matrix[0].length]; - int i = 0; - for (double[] arr : matrix) { - for (double val : arr) { - result[i++] = val; - } - } - return result; - } -} diff --git a/lib/src/main/java/de/edux/math/Matrix.java b/lib/src/main/java/de/edux/math/Matrix.java deleted file mode 100644 index c8f63ae..0000000 --- a/lib/src/main/java/de/edux/math/Matrix.java +++ /dev/null @@ -1,154 +0,0 @@ -package de.edux.math; - -import java.util.Iterator; -import java.util.NoSuchElementException; - -public class Matrix implements Entity, Iterable { - - private final double[][] raw; - - public Matrix(double[][] matrix) { - this.raw = matrix; - } - - @Override - public Matrix add(Matrix another) { - return add(another.raw()); - } - - public Matrix add(double[][] another) { - Validations.sizeMatrix(raw, another); - - double[][] result = new double[raw.length][raw[0].length]; - - for (int i = 0; i < result.length; i++) { - for (int a = 0; a < result[0].length; a++) { - result[i][a] = raw[i][a] + another[i][a]; - } - } - - return new Matrix(result); - } - - @Override - public Matrix subtract(Matrix another) { - return subtract(another.raw()); - } - - @Override - public Matrix multiply(Matrix another) { - return multiply(another.raw()); - } - - public Matrix multiply(double[][] another) { - return null; // TODO optimized algorithm for matrix multiplication - } - - @Override - public Matrix scalarMultiply(double n) { - double[][] result = new double[raw.length][raw[0].length]; - - for (int i = 0; i < result.length; i++) { - for (int a = 0; a < result[0].length; a++) { - result[i][a] = raw[i][a] * n; - } - } - - return new Matrix(result); - } - - public Matrix subtract(double[][] another) { - Validations.sizeMatrix(raw, another); - - double[][] result = new double[raw.length][raw[0].length]; - - for (int i = 0; i < result.length; i++) { - for (int a = 0; a < result[0].length; a++) { - result[i][a] = raw[i][a] - another[i][a]; - } - } - - return new Matrix(result); - } - - public boolean isSquare() { - return rows() == columns(); - } - - public int rows() { - return raw.length; - } - - public int columns() { - return raw[0].length; - } - - public double[][] raw() { - return raw; - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof Matrix matrix) { - if (matrix.rows() != rows() || matrix.columns() != columns()) { - return false; - } - for (int i = 0; i < raw.length; i++) { - for (int a = 0; a < raw[i].length; a++) { - if (matrix.raw()[i][a] != raw[i][a]) { - return false; - } - } - } - return true; - } - return false; - } - - @Override - public String toString() { - StringBuilder builder = new StringBuilder("[").append("\n"); - for (int i = 0; i < raw.length; i++) { - builder.append(" ").append("["); - for (int a = 0; a < raw[i].length; a++) { - builder.append(raw[i][a]); - if (a != raw[i].length - 1) { - builder.append(", "); - } - } - builder.append("]"); - if (i != raw.length - 1) { - builder.append(","); - } - builder.append("\n"); - } - return builder.append("]").toString(); - } - - @Override - public Iterator iterator() { - return new MatrixIterator(raw); - } - - public static class MatrixIterator implements Iterator { - - private final double[] data; - private int current; - - public MatrixIterator(double[][] data) { - this.data = MathUtil.unwrap(data); - this.current = 0; - } - - @Override - public boolean hasNext() { - return current < data.length; - } - - @Override - public Double next() { - if (!hasNext()) throw new NoSuchElementException(); - return data[current++]; - } - } -} diff --git a/lib/src/main/java/de/edux/math/Validations.java b/lib/src/main/java/de/edux/math/Validations.java deleted file mode 100644 index 076a2e9..0000000 --- a/lib/src/main/java/de/edux/math/Validations.java +++ /dev/null @@ -1,16 +0,0 @@ -package de.edux.math; - -public final class Validations { - - public static void size(double[] first, double[] second) { - if (first.length != second.length) { - throw new IllegalArgumentException("sizes mismatch"); - } - } - - public static void sizeMatrix(double[][] first, double[][] second) { - if (first.length != second.length || first[0].length != second[0].length) { - throw new IllegalArgumentException("sizes mismatch"); - } - } -} diff --git a/lib/src/main/java/de/edux/math/Vector.java b/lib/src/main/java/de/edux/math/Vector.java deleted file mode 100644 index 66730bf..0000000 --- a/lib/src/main/java/de/edux/math/Vector.java +++ /dev/null @@ -1,148 +0,0 @@ -package de.edux.math; - -import java.util.Arrays; -import java.util.Iterator; -import java.util.NoSuchElementException; - -public class Vector implements Entity, Iterable { - - private final double[] raw; - - public Vector(double[] vector) { - this.raw = vector; - } - - @Override - public Vector add(Vector another) { - return add(another.raw()); - } - - public Vector add(double[] another) { - Validations.size(raw, another); - - double[] result = new double[length()]; - for (int i = 0; i < result.length; i++) { - result[i] = raw[i] + another[i]; - } - - return new Vector(result); - } - - @Override - public Vector subtract(Vector another) { - return subtract(another.raw()); - } - - public Vector subtract(double[] another) { - Validations.size(raw, another); - - double[] result = new double[length()]; - for (int i = 0; i < result.length; i++) { - result[i] = raw[i] - another[i]; - } - - return new Vector(result); - } - - @Override - public Vector multiply(Vector another) { - return multiply(another.raw()); - } - - public Vector multiply(double[] another) { - Validations.size(raw, another); - - double[] result = new double[length()]; - for (int i = 0; i < result.length; i++) { - result[i] = raw[i] * another[i]; - if (result[i] == 0) { // Avoiding -0 result - result[i] = 0; - } - } - - return new Vector(result); - } - - @Override - public Vector scalarMultiply(double n) { - double[] result = new double[length()]; - for (int i = 0; i < result.length; i++) { - result[i] = raw[i] * n; - if (result[i] == 0) { // Avoiding -0 result - result[i] = 0; - } - } - - return new Vector(result); - } - - public double dot(Vector another) { - return dot(another.raw()); - } - - public double dot(double[] another) { - Validations.size(raw, another); - - double result = 0; - for (int i = 0; i < raw.length; i++) { - result += raw[i] * another[i]; - } - - return result; - } - - public int length() { - return raw.length; - } - - public double[] raw() { - return raw.clone(); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof Vector) { - return Arrays.equals(raw, ((Vector) obj).raw()); - } - return false; - } - - @Override - public String toString() { - StringBuilder builder = new StringBuilder("["); - for (int i = 0; i < raw.length; i++) { - builder.append(raw[i]); - if (i != raw.length - 1) { - builder.append(", "); - } - } - return builder.append("]").toString(); - } - - @Override - public Iterator iterator() { - return new VectorIterator(raw); - } - - public static class VectorIterator implements Iterator { - - private final double[] data; - private int current; - - public VectorIterator(double[] data) { - this.data = data; - this.current = 0; - } - - @Override - public boolean hasNext() { - return current < data.length; - } - - @Override - public Double next() { - if (!hasNext()) throw new NoSuchElementException(); - return data[current++]; - } - } -} diff --git a/lib/src/test/java/de/edux/math/entity/MatrixTest.java b/lib/src/test/java/de/edux/math/entity/MatrixTest.java deleted file mode 100644 index cedfa93..0000000 --- a/lib/src/test/java/de/edux/math/entity/MatrixTest.java +++ /dev/null @@ -1,75 +0,0 @@ -package de.edux.math.entity; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import de.edux.math.Matrix; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -public class MatrixTest { - - static Matrix first; - static Matrix second; - - @BeforeEach - public void init() { - first = - new Matrix( - new double[][] { - {5, 3, -1}, - {-2, 0, 6}, - {5, 1, -9} - }); - second = - new Matrix( - new double[][] { - {8, 7, 4}, - {1, -5, 2}, - {0, 3, 0} - }); - } - - @Test - public void testAdd() { - assertEquals( - new Matrix( - new double[][] { - {13, 10, 3}, - {-1, -5, 8}, - {5, 4, -9} - }), - first.add(second)); - } - - @Test - public void testSubtract() { - assertEquals( - new Matrix( - new double[][] { - {-3, -4, -5}, - {-3, 5, 4}, - {5, -2, -9} - }), - first.subtract(second)); - } - - @Test - public void testScalarMultiply() { - assertEquals( - new Matrix( - new double[][] { - {20, 12, -4}, - {-8, 0, 24}, - {20, 4, -36} - }), - first.scalarMultiply(4)); - assertEquals( - new Matrix( - new double[][] { - {-48, -42, -24}, - {-6, 30, -12}, - {0, -18, 0} - }), - second.scalarMultiply(-6)); - } -} diff --git a/lib/src/test/java/de/edux/math/entity/VectorTest.java b/lib/src/test/java/de/edux/math/entity/VectorTest.java deleted file mode 100644 index e7c90d5..0000000 --- a/lib/src/test/java/de/edux/math/entity/VectorTest.java +++ /dev/null @@ -1,45 +0,0 @@ -package de.edux.math.entity; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import de.edux.math.Vector; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -public class VectorTest { - - static Vector first; - static Vector second; - - @BeforeAll - public static void init() { - first = new Vector(new double[] {1, 5, 4}); - second = new Vector(new double[] {3, 8, 0}); - } - - @Test - public void testAdd() { - assertEquals(new Vector(new double[] {4, 13, 4}), first.add(second)); - } - - @Test - public void testSubtract() { - assertEquals(new Vector(new double[] {-2, -3, 4}), first.subtract(second)); - } - - @Test - public void testMultiply() { - assertEquals(new Vector(new double[] {3, 40, 0}), first.multiply(second)); - } - - @Test - public void testScalarMultiply() { - assertEquals(new Vector(new double[] {3, 15, 12}), first.scalarMultiply(3)); // first by 3 - assertEquals(new Vector(new double[] {-6, -16, 0}), second.scalarMultiply(-2)); // second by -2 - } - - @Test - public void testDot() { - assertEquals(43, first.dot(second)); - } -} From 193845e707c7bae890acd7cba49a0af706edc6d3 Mon Sep 17 00:00:00 2001 From: Samuel Abramov Date: Fri, 3 Nov 2023 13:04:52 +0100 Subject: [PATCH 2/3] feat(#75): structure --- .../core/math/matrix/strassen/Strassen.java | 123 +++++++++++++ .../matrix/strassen/StrassenParallel.java | 162 ++++++++++++++++++ .../strassen/StrassenParallelInplace.java | 19 ++ .../strassen/StrassenParallelInplaceTest.java | 33 ++++ .../matrix/strassen/StrassenParallelTest.java | 33 ++++ .../math/matrix/strassen/StrassenTest.java | 40 +++++ 6 files changed, 410 insertions(+) create mode 100644 lib/src/main/java/de/edux/core/math/matrix/strassen/Strassen.java create mode 100644 lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java create mode 100644 lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java create mode 100644 lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java create mode 100644 lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java create mode 100644 lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/Strassen.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/Strassen.java new file mode 100644 index 0000000..0d2880d --- /dev/null +++ b/lib/src/main/java/de/edux/core/math/matrix/strassen/Strassen.java @@ -0,0 +1,123 @@ +package de.edux.core.math.matrix.strassen; + +import de.edux.core.math.IMatrixArithmetic; + +public class Strassen implements IMatrixArithmetic { + + @Override + public double[][] multiply(double[][] matrixA, double[][] matrixB) { + int n = matrixA.length; + int m = nextPowerOfTwo(n); + double[][] extendedMatrixA = new double[m][m]; + double[][] extendedMatrixB = new double[m][m]; + + for (int i = 0; i < n; i++) { + System.arraycopy(matrixA[i], 0, extendedMatrixA[i], 0, matrixA[i].length); + System.arraycopy(matrixB[i], 0, extendedMatrixB[i], 0, matrixB[i].length); + } + + double[][] extendedResult = strassen(extendedMatrixA, extendedMatrixB); + + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + System.arraycopy(extendedResult[i], 0, result[i], 0, n); + } + + return result; + } + + private double[][] strassen(double[][] A, double[][] B) { + int n = A.length; + + double[][] result = new double[n][n]; + + if (n == 1) { + result[0][0] = A[0][0] * B[0][0]; + } else { + int newSize = n / 2; + double[][] a11 = new double[newSize][newSize]; + double[][] a12 = new double[newSize][newSize]; + double[][] a21 = new double[newSize][newSize]; + double[][] a22 = new double[newSize][newSize]; + double[][] b11 = new double[newSize][newSize]; + double[][] b12 = new double[newSize][newSize]; + double[][] b21 = new double[newSize][newSize]; + double[][] b22 = new double[newSize][newSize]; + + divideMatrix(A, a11, 0, 0); + divideMatrix(A, a12, 0, newSize); + divideMatrix(A, a21, newSize, 0); + divideMatrix(A, a22, newSize, newSize); + divideMatrix(B, b11, 0, 0); + divideMatrix(B, b12, 0, newSize); + divideMatrix(B, b21, newSize, 0); + divideMatrix(B, b22, newSize, newSize); + + double[][] m1 = strassen(addMatrices(a11, a22), addMatrices(b11, b22)); + double[][] m2 = strassen(addMatrices(a21, a22), b11); + double[][] m3 = strassen(a11, subtractMatrices(b12, b22)); + double[][] m4 = strassen(a22, subtractMatrices(b21, b11)); + double[][] m5 = strassen(addMatrices(a11, a12), b22); + double[][] m6 = strassen(subtractMatrices(a21, a11), addMatrices(b11, b12)); + double[][] m7 = strassen(subtractMatrices(a12, a22), addMatrices(b21, b22)); + + double[][] c11 = addMatrices(subtractMatrices(addMatrices(m1, m4), m5), m7); + double[][] c12 = addMatrices(m3, m5); + double[][] c21 = addMatrices(m2, m4); + double[][] c22 = addMatrices(subtractMatrices(addMatrices(m1, m3), m2), m6); + + combineMatrix(c11, result, 0, 0); + combineMatrix(c12, result, 0, newSize); + combineMatrix(c21, result, newSize, 0); + combineMatrix(c22, result, newSize, newSize); + } + + return result; + } + + private void divideMatrix(double[][] parent, double[][] child, int iB, int jB) { + for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { + for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { + child[i1][j1] = parent[i2][j2]; + } + } + } + + private void combineMatrix(double[][] child, double[][] parent, int iB, int jB) { + for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { + for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { + parent[i2][j2] = child[i1][j1]; + } + } + } + + private double[][] addMatrices(double[][] a, double[][] b) { + int n = a.length; + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + result[i][j] = a[i][j] + b[i][j]; + } + } + return result; + } + + private double[][] subtractMatrices(double[][] a, double[][] b) { + int n = a.length; + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + result[i][j] = a[i][j] - b[i][j]; + } + } + return result; + } + + private int nextPowerOfTwo(int number) { + int power = 1; + while (power < number) { + power *= 2; + } + return power; + } +} diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java new file mode 100644 index 0000000..f3b5490 --- /dev/null +++ b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java @@ -0,0 +1,162 @@ +package de.edux.core.math.matrix.strassen; + +import de.edux.core.math.IMatrixArithmetic; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.RecursiveTask; + +public class StrassenParallel implements IMatrixArithmetic { + + private ForkJoinPool forkJoinPool = new ForkJoinPool(4); + + @Override + public double[][] multiply(double[][] matrixA, double[][] matrixB) { + int n = matrixA.length; + int m = nextPowerOfTwo(n); + double[][] extendedMatrixA = new double[m][m]; + double[][] extendedMatrixB = new double[m][m]; + + for (int i = 0; i < n; i++) { + System.arraycopy(matrixA[i], 0, extendedMatrixA[i], 0, matrixA[i].length); + System.arraycopy(matrixB[i], 0, extendedMatrixB[i], 0, matrixB[i].length); + } + + double[][] extendedResult = + forkJoinPool.invoke(new StrassenTask(extendedMatrixA, extendedMatrixB)); + + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + System.arraycopy(extendedResult[i], 0, result[i], 0, n); + } + + return result; + } + + private double[][] conventionalMultiply(double[][] A, double[][] B) { + int n = A.length; + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + for (int k = 0; k < n; k++) { + result[i][j] += A[i][k] * B[k][j]; + } + } + } + return result; + } + + private int nextPowerOfTwo(int number) { + int power = 1; + while (power < number) { + power *= 2; + } + return power; + } + + private void divideMatrix(double[][] parent, double[][] child, int iB, int jB) { + for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { + for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { + child[i1][j1] = parent[i2][j2]; + } + } + } + + private void combineMatrix(double[][] child, double[][] parent, int iB, int jB) { + for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { + for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { + parent[i2][j2] = child[i1][j1]; + } + } + } + + private double[][] addMatrices(double[][] a, double[][] b) { + int n = a.length; + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + result[i][j] = a[i][j] + b[i][j]; + } + } + return result; + } + + private double[][] subtractMatrices(double[][] a, double[][] b) { + int n = a.length; + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + result[i][j] = a[i][j] - b[i][j]; + } + } + return result; + } + + private class StrassenTask extends RecursiveTask { + private double[][] A; + private double[][] B; + + StrassenTask(double[][] A, double[][] B) { + this.A = A; + this.B = B; + } + + @Override + protected double[][] compute() { + int n = A.length; + + if (n <= 64) { + return conventionalMultiply(A, B); + } else { + int newSize = n / 2; + double[][] a11 = new double[newSize][newSize]; + double[][] a12 = new double[newSize][newSize]; + double[][] a21 = new double[newSize][newSize]; + double[][] a22 = new double[newSize][newSize]; + double[][] b11 = new double[newSize][newSize]; + double[][] b12 = new double[newSize][newSize]; + double[][] b21 = new double[newSize][newSize]; + double[][] b22 = new double[newSize][newSize]; + + divideMatrix(A, a11, 0, 0); + divideMatrix(A, a12, 0, newSize); + divideMatrix(A, a21, newSize, 0); + divideMatrix(A, a22, newSize, newSize); + divideMatrix(B, b11, 0, 0); + divideMatrix(B, b12, 0, newSize); + divideMatrix(B, b21, newSize, 0); + divideMatrix(B, b22, newSize, newSize); + + StrassenTask task1 = new StrassenTask(addMatrices(a11, a22), addMatrices(b11, b22)); + StrassenTask task2 = new StrassenTask(addMatrices(a21, a22), b11); + StrassenTask task3 = new StrassenTask(a11, subtractMatrices(b12, b22)); + StrassenTask task4 = new StrassenTask(a22, subtractMatrices(b21, b11)); + StrassenTask task5 = new StrassenTask(addMatrices(a11, a12), b22); + StrassenTask task6 = new StrassenTask(subtractMatrices(a21, a11), addMatrices(b11, b12)); + StrassenTask task7 = new StrassenTask(subtractMatrices(a12, a22), addMatrices(b21, b22)); + + invokeAll(task1, task2, task3, task4, task5, task6, task7); + + double[][] m1 = task1.join(); + double[][] m2 = task2.join(); + double[][] m3 = task3.join(); + double[][] m4 = task4.join(); + double[][] m5 = task5.join(); + double[][] m6 = task6.join(); + double[][] m7 = task7.join(); + + double[][] c11 = addMatrices(subtractMatrices(addMatrices(m1, m4), m5), m7); + double[][] c12 = addMatrices(m3, m5); + double[][] c21 = addMatrices(m2, m4); + double[][] c22 = addMatrices(subtractMatrices(addMatrices(m1, m3), m2), m6); + + double[][] result = new double[n][n]; + + combineMatrix(c11, result, 0, 0); + combineMatrix(c12, result, 0, newSize); + combineMatrix(c21, result, newSize, 0); + combineMatrix(c22, result, newSize, newSize); + + return result; + } + } + } +} diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java new file mode 100644 index 0000000..db6403b --- /dev/null +++ b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java @@ -0,0 +1,19 @@ +package de.edux.core.math.matrix.strassen; + +import de.edux.core.math.IMatrixArithmetic; +import java.util.concurrent.ForkJoinPool; + +public class StrassenParallelInplace implements IMatrixArithmetic { + private ForkJoinPool forkJoinPool = new ForkJoinPool(); + + @Override + public double[][] multiply(double[][] matrixA, double[][] matrixB) {} + + private int nextPowerOfTwo(int number) { + int power = 1; + while (power < number) { + power *= 2; + } + return power; + } +} diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java new file mode 100644 index 0000000..bf42a2e --- /dev/null +++ b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java @@ -0,0 +1,33 @@ +package de.edux.core.math.matrix.strassen; + +import static org.junit.jupiter.api.Assertions.*; + +import de.edux.core.math.IMatrixArithmetic; +import org.junit.jupiter.api.Test; + +class StrassenParallelInplaceTest { + + @Test + public void shouldMultiplyWithParallelStrassen() { + int matrixSize = 2000; + double[][] matrixA = new double[matrixSize][matrixSize]; + double[][] matrixB = new double[matrixSize][matrixSize]; + + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + matrixA[i][j] = 1; + matrixB[i][j] = 1; + } + } + + IMatrixArithmetic strassenParallel = new StrassenParallelInplace(); + + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); + } + } + } +} diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java new file mode 100644 index 0000000..a14a83d --- /dev/null +++ b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java @@ -0,0 +1,33 @@ +package de.edux.core.math.matrix.strassen; + +import static org.junit.jupiter.api.Assertions.*; + +import de.edux.core.math.IMatrixArithmetic; +import org.junit.jupiter.api.Test; + +class StrassenParallelTest { + + @Test + public void shouldMultiplyWithParallelStrassen() { + int matrixSize = 2000; + double[][] matrixA = new double[matrixSize][matrixSize]; + double[][] matrixB = new double[matrixSize][matrixSize]; + + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + matrixA[i][j] = 1; + matrixB[i][j] = 1; + } + } + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); + } + } + } +} diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java new file mode 100644 index 0000000..9c73917 --- /dev/null +++ b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java @@ -0,0 +1,40 @@ +package de.edux.core.math.matrix.strassen; + +import static org.junit.jupiter.api.Assertions.*; + +import de.edux.core.math.IMatrixArithmetic; +import org.junit.jupiter.api.Test; + +class StrassenTest { + + @Test + public void testLargeMatrixMultiplication() { + double[][] matrixA = new double[1000][1000]; + double[][] matrixB = new double[1000][1000]; + + for (int i = 0; i < 1000; i++) { + for (int j = 0; j < 1000; j++) { + matrixA[i][j] = 1; + matrixB[i][j] = 1; + } + } + + IMatrixArithmetic strassen = new Strassen(); + + double[][] result = strassen.multiply(matrixA, matrixB); + for (int i = 0; i < 1000; i++) { + for (int j = 0; j < 1000; j++) { + assertEquals(1000, result[i][j], "Result on [" + i + "][" + j + "] not correct."); + } + } + } + + private void printMatrix(double[][] result) { + for (int i = 0; i < result.length; i++) { + for (int j = 0; j < result.length; j++) { + System.out.print(result[i][j] + " "); + } + System.out.println(); + } + } +} From e92014911df8c500f8ecfb80466126a7d530a7bd Mon Sep 17 00:00:00 2001 From: Samuel Abramov Date: Mon, 6 Nov 2023 02:04:15 +0100 Subject: [PATCH 3/3] feat():Implement Matrix Multiplication Using Strassen Algorithm #75 --- gradle.properties | 1 + .../strassen/StrassenMultiplicatIMatrix.java | 10 -- .../matrix/strassen/StrassenParallel.java | 2 +- .../strassen/StrassenParallelInplace.java | 19 --- .../strassen/StrassenParallelInplaceTest.java | 33 ----- .../matrix/strassen/StrassenParallelTest.java | 101 ++++++++++++++- .../math/matrix/strassen/StrassenTest.java | 122 +++++++++++++++--- 7 files changed, 209 insertions(+), 79 deletions(-) delete mode 100644 lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java delete mode 100644 lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java delete mode 100644 lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java diff --git a/gradle.properties b/gradle.properties index 4f996f1..950d4eb 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,2 +1,3 @@ org.gradle.parallel=true org.gradle.caching=true +org.gradle.jvmargs=-Xmx32g -Xms8g diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java deleted file mode 100644 index 9166756..0000000 --- a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenMultiplicatIMatrix.java +++ /dev/null @@ -1,10 +0,0 @@ -package de.edux.core.math.matrix.strassen; - -import de.edux.core.math.IMatrixArithmetic; - -public class StrassenMultiplicatIMatrix implements IMatrixArithmetic { - @Override - public double[][] multiply(double[][] matrixA, double[][] matrixB) { - return new double[0][]; - } -} diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java index f3b5490..2a40dfa 100644 --- a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java +++ b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java @@ -103,7 +103,7 @@ private class StrassenTask extends RecursiveTask { protected double[][] compute() { int n = A.length; - if (n <= 64) { + if (n <= 512) { return conventionalMultiply(A, B); } else { int newSize = n / 2; diff --git a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java b/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java deleted file mode 100644 index db6403b..0000000 --- a/lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallelInplace.java +++ /dev/null @@ -1,19 +0,0 @@ -package de.edux.core.math.matrix.strassen; - -import de.edux.core.math.IMatrixArithmetic; -import java.util.concurrent.ForkJoinPool; - -public class StrassenParallelInplace implements IMatrixArithmetic { - private ForkJoinPool forkJoinPool = new ForkJoinPool(); - - @Override - public double[][] multiply(double[][] matrixA, double[][] matrixB) {} - - private int nextPowerOfTwo(int number) { - int power = 1; - while (power < number) { - power *= 2; - } - return power; - } -} diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java deleted file mode 100644 index bf42a2e..0000000 --- a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelInplaceTest.java +++ /dev/null @@ -1,33 +0,0 @@ -package de.edux.core.math.matrix.strassen; - -import static org.junit.jupiter.api.Assertions.*; - -import de.edux.core.math.IMatrixArithmetic; -import org.junit.jupiter.api.Test; - -class StrassenParallelInplaceTest { - - @Test - public void shouldMultiplyWithParallelStrassen() { - int matrixSize = 2000; - double[][] matrixA = new double[matrixSize][matrixSize]; - double[][] matrixB = new double[matrixSize][matrixSize]; - - for (int i = 0; i < matrixSize; i++) { - for (int j = 0; j < matrixSize; j++) { - matrixA[i][j] = 1; - matrixB[i][j] = 1; - } - } - - IMatrixArithmetic strassenParallel = new StrassenParallelInplace(); - - double[][] result = strassenParallel.multiply(matrixA, matrixB); - - for (int i = 0; i < matrixSize; i++) { - for (int j = 0; j < matrixSize; j++) { - assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); - } - } - } -} diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java index a14a83d..5fe6f77 100644 --- a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java +++ b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenParallelTest.java @@ -9,7 +9,7 @@ class StrassenParallelTest { @Test public void shouldMultiplyWithParallelStrassen() { - int matrixSize = 2000; + int matrixSize = 2048; double[][] matrixA = new double[matrixSize][matrixSize]; double[][] matrixB = new double[matrixSize][matrixSize]; @@ -30,4 +30,103 @@ public void shouldMultiplyWithParallelStrassen() { } } } + + @Test + public void shouldMultiply8x8MatricesCorrectly() { + double[][] matrixA = { + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 4, 3, 2, 1}, + {2, 3, 4, 5, 6, 7, 8, 9}, + {9, 8, 7, 6, 5, 4, 3, 2}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {1, 3, 5, 7, 9, 11, 13, 15}, + {15, 13, 11, 9, 7, 5, 3, 1} + }; + double[][] matrixB = { + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1}, + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1} + }; + double[][] expected = { + {6, 8, 10, 12, 6, 8, 10, 12}, + {12, 10, 8, 6, 12, 10, 8, 6}, + {8, 10, 12, 14, 8, 10, 12, 14}, + {14, 12, 10, 8, 14, 12, 10, 8}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {4, 4, 4, 4, 4, 4, 4, 4}, + {10, 14, 18, 22, 10, 14, 18, 22}, + {22, 18, 14, 10, 22, 18, 14, 10} + }; + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + assertArrayEquals( + expected, result, "The 8x8 matrix multiplication did not yield the correct result."); + } + + @Test + public void shouldMultiplyWithZeroMatrix() { + double[][] matrixA = { + {0, 0}, + {0, 0} + }; + double[][] matrixB = { + {1, 2}, + {3, 4} + }; + double[][] expected = { + {0, 0}, + {0, 0} + }; + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); + } + + @Test + public void shouldMultiplyWithIdentityMatrix() { + double[][] matrixA = { + {1, 0}, + {0, 1} + }; + double[][] matrixB = { + {5, 6}, + {7, 8} + }; + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + assertArrayEquals(matrixB, result); + } + + @Test + public void shouldMultiplySmallMatricesCorrectly() { + double[][] matrixA = { + {1, 2}, + {3, 4} + }; + double[][] matrixB = { + {2, 0}, + {1, 2} + }; + double[][] expected = { + {4, 4}, + {10, 8} + }; + + IMatrixArithmetic strassenParallel = new StrassenParallel(); + double[][] result = strassenParallel.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); + } } diff --git a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java index 9c73917..1f8c2cd 100644 --- a/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java +++ b/lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java @@ -8,12 +8,13 @@ class StrassenTest { @Test - public void testLargeMatrixMultiplication() { - double[][] matrixA = new double[1000][1000]; - double[][] matrixB = new double[1000][1000]; + public void shouldMultiplyWithParallelStrassen() { + int matrixSize = 128; + double[][] matrixA = new double[matrixSize][matrixSize]; + double[][] matrixB = new double[matrixSize][matrixSize]; - for (int i = 0; i < 1000; i++) { - for (int j = 0; j < 1000; j++) { + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { matrixA[i][j] = 1; matrixB[i][j] = 1; } @@ -22,19 +23,110 @@ public void testLargeMatrixMultiplication() { IMatrixArithmetic strassen = new Strassen(); double[][] result = strassen.multiply(matrixA, matrixB); - for (int i = 0; i < 1000; i++) { - for (int j = 0; j < 1000; j++) { - assertEquals(1000, result[i][j], "Result on [" + i + "][" + j + "] not correct."); + + for (int i = 0; i < matrixSize; i++) { + for (int j = 0; j < matrixSize; j++) { + assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); } } } - private void printMatrix(double[][] result) { - for (int i = 0; i < result.length; i++) { - for (int j = 0; j < result.length; j++) { - System.out.print(result[i][j] + " "); - } - System.out.println(); - } + @Test + public void shouldMultiply8x8MatricesCorrectly() { + double[][] matrixA = { + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 4, 3, 2, 1}, + {2, 3, 4, 5, 6, 7, 8, 9}, + {9, 8, 7, 6, 5, 4, 3, 2}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {1, 3, 5, 7, 9, 11, 13, 15}, + {15, 13, 11, 9, 7, 5, 3, 1} + }; + double[][] matrixB = { + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1}, + {1, 0, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 1, 0, 0, 0, 1} + }; + double[][] expected = { + {6, 8, 10, 12, 6, 8, 10, 12}, + {12, 10, 8, 6, 12, 10, 8, 6}, + {8, 10, 12, 14, 8, 10, 12, 14}, + {14, 12, 10, 8, 14, 12, 10, 8}, + {2, 2, 2, 2, 2, 2, 2, 2}, + {4, 4, 4, 4, 4, 4, 4, 4}, + {10, 14, 18, 22, 10, 14, 18, 22}, + {22, 18, 14, 10, 22, 18, 14, 10} + }; + + IMatrixArithmetic strassen = new Strassen(); + double[][] result = strassen.multiply(matrixA, matrixB); + + assertArrayEquals( + expected, result, "The 8x8 matrix multiplication did not yield the correct result."); + } + + @Test + public void shouldMultiplyWithZeroMatrix() { + double[][] matrixA = { + {0, 0}, + {0, 0} + }; + double[][] matrixB = { + {1, 2}, + {3, 4} + }; + double[][] expected = { + {0, 0}, + {0, 0} + }; + + IMatrixArithmetic strassen = new Strassen(); + double[][] result = strassen.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); + } + + @Test + public void shouldMultiplyWithIdentityMatrix() { + double[][] matrixA = { + {1, 0}, + {0, 1} + }; + double[][] matrixB = { + {5, 6}, + {7, 8} + }; + + IMatrixArithmetic strassen = new Strassen(); + double[][] result = strassen.multiply(matrixA, matrixB); + + assertArrayEquals(matrixB, result); + } + + @Test + public void shouldMultiplySmallMatricesCorrectly() { + double[][] matrixA = { + {1, 2}, + {3, 4} + }; + double[][] matrixB = { + {2, 0}, + {1, 2} + }; + double[][] expected = { + {4, 4}, + {10, 8} + }; + + IMatrixArithmetic strassen = new Strassen(); + double[][] result = strassen.multiply(matrixA, matrixB); + + assertArrayEquals(expected, result); } }