Skip to content

Commit

Permalink
feat():Implement Matrix Multiplication Using Strassen Algorithm #75
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Abramov committed Nov 6, 2023
1 parent 193845e commit e920149
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 79 deletions.
1 change: 1 addition & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
org.gradle.parallel=true
org.gradle.caching=true
org.gradle.jvmargs=-Xmx32g -Xms8g

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ private class StrassenTask extends RecursiveTask<double[][]> {
protected double[][] compute() {
int n = A.length;

if (n <= 64) {
if (n <= 512) {
return conventionalMultiply(A, B);
} else {
int newSize = n / 2;
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class StrassenParallelTest {

@Test
public void shouldMultiplyWithParallelStrassen() {
int matrixSize = 2000;
int matrixSize = 2048;
double[][] matrixA = new double[matrixSize][matrixSize];
double[][] matrixB = new double[matrixSize][matrixSize];

Expand All @@ -30,4 +30,103 @@ public void shouldMultiplyWithParallelStrassen() {
}
}
}

@Test
public void shouldMultiply8x8MatricesCorrectly() {
double[][] matrixA = {
{1, 2, 3, 4, 5, 6, 7, 8},
{8, 7, 6, 5, 4, 3, 2, 1},
{2, 3, 4, 5, 6, 7, 8, 9},
{9, 8, 7, 6, 5, 4, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1},
{2, 2, 2, 2, 2, 2, 2, 2},
{1, 3, 5, 7, 9, 11, 13, 15},
{15, 13, 11, 9, 7, 5, 3, 1}
};
double[][] matrixB = {
{1, 0, 0, 0, 1, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 0, 0},
{0, 0, 1, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 0, 0, 1},
{1, 0, 0, 0, 1, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 0, 0},
{0, 0, 1, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 0, 0, 1}
};
double[][] expected = {
{6, 8, 10, 12, 6, 8, 10, 12},
{12, 10, 8, 6, 12, 10, 8, 6},
{8, 10, 12, 14, 8, 10, 12, 14},
{14, 12, 10, 8, 14, 12, 10, 8},
{2, 2, 2, 2, 2, 2, 2, 2},
{4, 4, 4, 4, 4, 4, 4, 4},
{10, 14, 18, 22, 10, 14, 18, 22},
{22, 18, 14, 10, 22, 18, 14, 10}
};

IMatrixArithmetic strassenParallel = new StrassenParallel();
double[][] result = strassenParallel.multiply(matrixA, matrixB);

assertArrayEquals(
expected, result, "The 8x8 matrix multiplication did not yield the correct result.");
}

@Test
public void shouldMultiplyWithZeroMatrix() {
double[][] matrixA = {
{0, 0},
{0, 0}
};
double[][] matrixB = {
{1, 2},
{3, 4}
};
double[][] expected = {
{0, 0},
{0, 0}
};

IMatrixArithmetic strassenParallel = new StrassenParallel();
double[][] result = strassenParallel.multiply(matrixA, matrixB);

assertArrayEquals(expected, result);
}

@Test
public void shouldMultiplyWithIdentityMatrix() {
double[][] matrixA = {
{1, 0},
{0, 1}
};
double[][] matrixB = {
{5, 6},
{7, 8}
};

IMatrixArithmetic strassenParallel = new StrassenParallel();
double[][] result = strassenParallel.multiply(matrixA, matrixB);

assertArrayEquals(matrixB, result);
}

@Test
public void shouldMultiplySmallMatricesCorrectly() {
double[][] matrixA = {
{1, 2},
{3, 4}
};
double[][] matrixB = {
{2, 0},
{1, 2}
};
double[][] expected = {
{4, 4},
{10, 8}
};

IMatrixArithmetic strassenParallel = new StrassenParallel();
double[][] result = strassenParallel.multiply(matrixA, matrixB);

assertArrayEquals(expected, result);
}
}
122 changes: 107 additions & 15 deletions lib/src/test/java/de/edux/core/math/matrix/strassen/StrassenTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
class StrassenTest {

@Test
public void testLargeMatrixMultiplication() {
double[][] matrixA = new double[1000][1000];
double[][] matrixB = new double[1000][1000];
public void shouldMultiplyWithParallelStrassen() {
int matrixSize = 128;
double[][] matrixA = new double[matrixSize][matrixSize];
double[][] matrixB = new double[matrixSize][matrixSize];

for (int i = 0; i < 1000; i++) {
for (int j = 0; j < 1000; j++) {
for (int i = 0; i < matrixSize; i++) {
for (int j = 0; j < matrixSize; j++) {
matrixA[i][j] = 1;
matrixB[i][j] = 1;
}
Expand All @@ -22,19 +23,110 @@ public void testLargeMatrixMultiplication() {
IMatrixArithmetic strassen = new Strassen();

double[][] result = strassen.multiply(matrixA, matrixB);
for (int i = 0; i < 1000; i++) {
for (int j = 0; j < 1000; j++) {
assertEquals(1000, result[i][j], "Result on [" + i + "][" + j + "] not correct.");

for (int i = 0; i < matrixSize; i++) {
for (int j = 0; j < matrixSize; j++) {
assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct.");
}
}
}

private void printMatrix(double[][] result) {
for (int i = 0; i < result.length; i++) {
for (int j = 0; j < result.length; j++) {
System.out.print(result[i][j] + " ");
}
System.out.println();
}
@Test
public void shouldMultiply8x8MatricesCorrectly() {
double[][] matrixA = {
{1, 2, 3, 4, 5, 6, 7, 8},
{8, 7, 6, 5, 4, 3, 2, 1},
{2, 3, 4, 5, 6, 7, 8, 9},
{9, 8, 7, 6, 5, 4, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1},
{2, 2, 2, 2, 2, 2, 2, 2},
{1, 3, 5, 7, 9, 11, 13, 15},
{15, 13, 11, 9, 7, 5, 3, 1}
};
double[][] matrixB = {
{1, 0, 0, 0, 1, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 0, 0},
{0, 0, 1, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 0, 0, 1},
{1, 0, 0, 0, 1, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 0, 0},
{0, 0, 1, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 0, 0, 1}
};
double[][] expected = {
{6, 8, 10, 12, 6, 8, 10, 12},
{12, 10, 8, 6, 12, 10, 8, 6},
{8, 10, 12, 14, 8, 10, 12, 14},
{14, 12, 10, 8, 14, 12, 10, 8},
{2, 2, 2, 2, 2, 2, 2, 2},
{4, 4, 4, 4, 4, 4, 4, 4},
{10, 14, 18, 22, 10, 14, 18, 22},
{22, 18, 14, 10, 22, 18, 14, 10}
};

IMatrixArithmetic strassen = new Strassen();
double[][] result = strassen.multiply(matrixA, matrixB);

assertArrayEquals(
expected, result, "The 8x8 matrix multiplication did not yield the correct result.");
}

@Test
public void shouldMultiplyWithZeroMatrix() {
double[][] matrixA = {
{0, 0},
{0, 0}
};
double[][] matrixB = {
{1, 2},
{3, 4}
};
double[][] expected = {
{0, 0},
{0, 0}
};

IMatrixArithmetic strassen = new Strassen();
double[][] result = strassen.multiply(matrixA, matrixB);

assertArrayEquals(expected, result);
}

@Test
public void shouldMultiplyWithIdentityMatrix() {
double[][] matrixA = {
{1, 0},
{0, 1}
};
double[][] matrixB = {
{5, 6},
{7, 8}
};

IMatrixArithmetic strassen = new Strassen();
double[][] result = strassen.multiply(matrixA, matrixB);

assertArrayEquals(matrixB, result);
}

@Test
public void shouldMultiplySmallMatricesCorrectly() {
double[][] matrixA = {
{1, 2},
{3, 4}
};
double[][] matrixB = {
{2, 0},
{1, 2}
};
double[][] expected = {
{4, 4},
{10, 8}
};

IMatrixArithmetic strassen = new Strassen();
double[][] result = strassen.multiply(matrixA, matrixB);

assertArrayEquals(expected, result);
}
}

0 comments on commit e920149

Please sign in to comment.