-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #93 from Samyssmile/matrix
Implement Matrix Multiplication Using Strassen Algorithm #75
- Loading branch information
Showing
15 changed files
with
575 additions
and
465 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
org.gradle.parallel=true | ||
org.gradle.caching=true | ||
org.gradle.jvmargs=-Xmx32g -Xms8g |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package de.edux.core.math; | ||
|
||
public interface IMatrixArithmetic { | ||
double[][] multiply(double[][] matrixA, double[][] matrixB); | ||
} |
10 changes: 10 additions & 0 deletions
10
lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package de.edux.core.math.matrix.classic; | ||
|
||
import de.edux.core.math.IMatrixArithmetic; | ||
|
||
public class MatrixArithmetic implements IMatrixArithmetic { | ||
@Override | ||
public double[][] multiply(double[][] matrixA, double[][] matrixB) { | ||
return new double[0][]; | ||
} | ||
} |
10 changes: 10 additions & 0 deletions
10
lib/src/main/java/de/edux/core/math/matrix/parallel/MatrixParallelArithmetic.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package de.edux.core.math.matrix.parallel; | ||
|
||
import de.edux.core.math.IMatrixArithmetic; | ||
|
||
public class MatrixParallelArithmetic implements IMatrixArithmetic { | ||
@Override | ||
public double[][] multiply(double[][] matrixA, double[][] matrixB) { | ||
return new double[0][]; | ||
} | ||
} |
123 changes: 123 additions & 0 deletions
123
lib/src/main/java/de/edux/core/math/matrix/strassen/Strassen.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
package de.edux.core.math.matrix.strassen; | ||
|
||
import de.edux.core.math.IMatrixArithmetic; | ||
|
||
public class Strassen implements IMatrixArithmetic { | ||
|
||
@Override | ||
public double[][] multiply(double[][] matrixA, double[][] matrixB) { | ||
int n = matrixA.length; | ||
int m = nextPowerOfTwo(n); | ||
double[][] extendedMatrixA = new double[m][m]; | ||
double[][] extendedMatrixB = new double[m][m]; | ||
|
||
for (int i = 0; i < n; i++) { | ||
System.arraycopy(matrixA[i], 0, extendedMatrixA[i], 0, matrixA[i].length); | ||
System.arraycopy(matrixB[i], 0, extendedMatrixB[i], 0, matrixB[i].length); | ||
} | ||
|
||
double[][] extendedResult = strassen(extendedMatrixA, extendedMatrixB); | ||
|
||
double[][] result = new double[n][n]; | ||
for (int i = 0; i < n; i++) { | ||
System.arraycopy(extendedResult[i], 0, result[i], 0, n); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
private double[][] strassen(double[][] A, double[][] B) { | ||
int n = A.length; | ||
|
||
double[][] result = new double[n][n]; | ||
|
||
if (n == 1) { | ||
result[0][0] = A[0][0] * B[0][0]; | ||
} else { | ||
int newSize = n / 2; | ||
double[][] a11 = new double[newSize][newSize]; | ||
double[][] a12 = new double[newSize][newSize]; | ||
double[][] a21 = new double[newSize][newSize]; | ||
double[][] a22 = new double[newSize][newSize]; | ||
double[][] b11 = new double[newSize][newSize]; | ||
double[][] b12 = new double[newSize][newSize]; | ||
double[][] b21 = new double[newSize][newSize]; | ||
double[][] b22 = new double[newSize][newSize]; | ||
|
||
divideMatrix(A, a11, 0, 0); | ||
divideMatrix(A, a12, 0, newSize); | ||
divideMatrix(A, a21, newSize, 0); | ||
divideMatrix(A, a22, newSize, newSize); | ||
divideMatrix(B, b11, 0, 0); | ||
divideMatrix(B, b12, 0, newSize); | ||
divideMatrix(B, b21, newSize, 0); | ||
divideMatrix(B, b22, newSize, newSize); | ||
|
||
double[][] m1 = strassen(addMatrices(a11, a22), addMatrices(b11, b22)); | ||
double[][] m2 = strassen(addMatrices(a21, a22), b11); | ||
double[][] m3 = strassen(a11, subtractMatrices(b12, b22)); | ||
double[][] m4 = strassen(a22, subtractMatrices(b21, b11)); | ||
double[][] m5 = strassen(addMatrices(a11, a12), b22); | ||
double[][] m6 = strassen(subtractMatrices(a21, a11), addMatrices(b11, b12)); | ||
double[][] m7 = strassen(subtractMatrices(a12, a22), addMatrices(b21, b22)); | ||
|
||
double[][] c11 = addMatrices(subtractMatrices(addMatrices(m1, m4), m5), m7); | ||
double[][] c12 = addMatrices(m3, m5); | ||
double[][] c21 = addMatrices(m2, m4); | ||
double[][] c22 = addMatrices(subtractMatrices(addMatrices(m1, m3), m2), m6); | ||
|
||
combineMatrix(c11, result, 0, 0); | ||
combineMatrix(c12, result, 0, newSize); | ||
combineMatrix(c21, result, newSize, 0); | ||
combineMatrix(c22, result, newSize, newSize); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
private void divideMatrix(double[][] parent, double[][] child, int iB, int jB) { | ||
for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { | ||
for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { | ||
child[i1][j1] = parent[i2][j2]; | ||
} | ||
} | ||
} | ||
|
||
private void combineMatrix(double[][] child, double[][] parent, int iB, int jB) { | ||
for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { | ||
for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { | ||
parent[i2][j2] = child[i1][j1]; | ||
} | ||
} | ||
} | ||
|
||
private double[][] addMatrices(double[][] a, double[][] b) { | ||
int n = a.length; | ||
double[][] result = new double[n][n]; | ||
for (int i = 0; i < n; i++) { | ||
for (int j = 0; j < n; j++) { | ||
result[i][j] = a[i][j] + b[i][j]; | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
private double[][] subtractMatrices(double[][] a, double[][] b) { | ||
int n = a.length; | ||
double[][] result = new double[n][n]; | ||
for (int i = 0; i < n; i++) { | ||
for (int j = 0; j < n; j++) { | ||
result[i][j] = a[i][j] - b[i][j]; | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
private int nextPowerOfTwo(int number) { | ||
int power = 1; | ||
while (power < number) { | ||
power *= 2; | ||
} | ||
return power; | ||
} | ||
} |
162 changes: 162 additions & 0 deletions
162
lib/src/main/java/de/edux/core/math/matrix/strassen/StrassenParallel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
package de.edux.core.math.matrix.strassen; | ||
|
||
import de.edux.core.math.IMatrixArithmetic; | ||
import java.util.concurrent.ForkJoinPool; | ||
import java.util.concurrent.RecursiveTask; | ||
|
||
public class StrassenParallel implements IMatrixArithmetic { | ||
|
||
private ForkJoinPool forkJoinPool = new ForkJoinPool(4); | ||
|
||
@Override | ||
public double[][] multiply(double[][] matrixA, double[][] matrixB) { | ||
int n = matrixA.length; | ||
int m = nextPowerOfTwo(n); | ||
double[][] extendedMatrixA = new double[m][m]; | ||
double[][] extendedMatrixB = new double[m][m]; | ||
|
||
for (int i = 0; i < n; i++) { | ||
System.arraycopy(matrixA[i], 0, extendedMatrixA[i], 0, matrixA[i].length); | ||
System.arraycopy(matrixB[i], 0, extendedMatrixB[i], 0, matrixB[i].length); | ||
} | ||
|
||
double[][] extendedResult = | ||
forkJoinPool.invoke(new StrassenTask(extendedMatrixA, extendedMatrixB)); | ||
|
||
double[][] result = new double[n][n]; | ||
for (int i = 0; i < n; i++) { | ||
System.arraycopy(extendedResult[i], 0, result[i], 0, n); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
private double[][] conventionalMultiply(double[][] A, double[][] B) { | ||
int n = A.length; | ||
double[][] result = new double[n][n]; | ||
for (int i = 0; i < n; i++) { | ||
for (int j = 0; j < n; j++) { | ||
for (int k = 0; k < n; k++) { | ||
result[i][j] += A[i][k] * B[k][j]; | ||
} | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
private int nextPowerOfTwo(int number) { | ||
int power = 1; | ||
while (power < number) { | ||
power *= 2; | ||
} | ||
return power; | ||
} | ||
|
||
private void divideMatrix(double[][] parent, double[][] child, int iB, int jB) { | ||
for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { | ||
for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { | ||
child[i1][j1] = parent[i2][j2]; | ||
} | ||
} | ||
} | ||
|
||
private void combineMatrix(double[][] child, double[][] parent, int iB, int jB) { | ||
for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { | ||
for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { | ||
parent[i2][j2] = child[i1][j1]; | ||
} | ||
} | ||
} | ||
|
||
private double[][] addMatrices(double[][] a, double[][] b) { | ||
int n = a.length; | ||
double[][] result = new double[n][n]; | ||
for (int i = 0; i < n; i++) { | ||
for (int j = 0; j < n; j++) { | ||
result[i][j] = a[i][j] + b[i][j]; | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
private double[][] subtractMatrices(double[][] a, double[][] b) { | ||
int n = a.length; | ||
double[][] result = new double[n][n]; | ||
for (int i = 0; i < n; i++) { | ||
for (int j = 0; j < n; j++) { | ||
result[i][j] = a[i][j] - b[i][j]; | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
private class StrassenTask extends RecursiveTask<double[][]> { | ||
private double[][] A; | ||
private double[][] B; | ||
|
||
StrassenTask(double[][] A, double[][] B) { | ||
this.A = A; | ||
this.B = B; | ||
} | ||
|
||
@Override | ||
protected double[][] compute() { | ||
int n = A.length; | ||
|
||
if (n <= 512) { | ||
return conventionalMultiply(A, B); | ||
} else { | ||
int newSize = n / 2; | ||
double[][] a11 = new double[newSize][newSize]; | ||
double[][] a12 = new double[newSize][newSize]; | ||
double[][] a21 = new double[newSize][newSize]; | ||
double[][] a22 = new double[newSize][newSize]; | ||
double[][] b11 = new double[newSize][newSize]; | ||
double[][] b12 = new double[newSize][newSize]; | ||
double[][] b21 = new double[newSize][newSize]; | ||
double[][] b22 = new double[newSize][newSize]; | ||
|
||
divideMatrix(A, a11, 0, 0); | ||
divideMatrix(A, a12, 0, newSize); | ||
divideMatrix(A, a21, newSize, 0); | ||
divideMatrix(A, a22, newSize, newSize); | ||
divideMatrix(B, b11, 0, 0); | ||
divideMatrix(B, b12, 0, newSize); | ||
divideMatrix(B, b21, newSize, 0); | ||
divideMatrix(B, b22, newSize, newSize); | ||
|
||
StrassenTask task1 = new StrassenTask(addMatrices(a11, a22), addMatrices(b11, b22)); | ||
StrassenTask task2 = new StrassenTask(addMatrices(a21, a22), b11); | ||
StrassenTask task3 = new StrassenTask(a11, subtractMatrices(b12, b22)); | ||
StrassenTask task4 = new StrassenTask(a22, subtractMatrices(b21, b11)); | ||
StrassenTask task5 = new StrassenTask(addMatrices(a11, a12), b22); | ||
StrassenTask task6 = new StrassenTask(subtractMatrices(a21, a11), addMatrices(b11, b12)); | ||
StrassenTask task7 = new StrassenTask(subtractMatrices(a12, a22), addMatrices(b21, b22)); | ||
|
||
invokeAll(task1, task2, task3, task4, task5, task6, task7); | ||
|
||
double[][] m1 = task1.join(); | ||
double[][] m2 = task2.join(); | ||
double[][] m3 = task3.join(); | ||
double[][] m4 = task4.join(); | ||
double[][] m5 = task5.join(); | ||
double[][] m6 = task6.join(); | ||
double[][] m7 = task7.join(); | ||
|
||
double[][] c11 = addMatrices(subtractMatrices(addMatrices(m1, m4), m5), m7); | ||
double[][] c12 = addMatrices(m3, m5); | ||
double[][] c21 = addMatrices(m2, m4); | ||
double[][] c22 = addMatrices(subtractMatrices(addMatrices(m1, m3), m2), m6); | ||
|
||
double[][] result = new double[n][n]; | ||
|
||
combineMatrix(c11, result, 0, 0); | ||
combineMatrix(c12, result, 0, newSize); | ||
combineMatrix(c21, result, newSize, 0); | ||
combineMatrix(c22, result, newSize, newSize); | ||
|
||
return result; | ||
} | ||
} | ||
} | ||
} |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.