Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Matrix Multiplication Using Strassen Algorithm #75 #93

Merged
merged 3 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
org.gradle.parallel=true
org.gradle.caching=true
org.gradle.jvmargs=-Xmx32g -Xms8g
5 changes: 5 additions & 0 deletions lib/src/main/java/de/edux/core/math/IMatrixArithmetic.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package de.edux.core.math;

public interface IMatrixArithmetic {
double[][] multiply(double[][] matrixA, double[][] matrixB);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package de.edux.core.math.matrix.classic;

import de.edux.core.math.IMatrixArithmetic;

public class MatrixArithmetic implements IMatrixArithmetic {
@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
return new double[0][];
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package de.edux.core.math.matrix.parallel;

import de.edux.core.math.IMatrixArithmetic;

public class MatrixParallelArithmetic implements IMatrixArithmetic {
@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
return new double[0][];
}
}
123 changes: 123 additions & 0 deletions lib/src/main/java/de/edux/core/math/matrix/strassen/Strassen.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package de.edux.core.math.matrix.strassen;

import de.edux.core.math.IMatrixArithmetic;

public class Strassen implements IMatrixArithmetic {

@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
int n = matrixA.length;
int m = nextPowerOfTwo(n);
double[][] extendedMatrixA = new double[m][m];
double[][] extendedMatrixB = new double[m][m];

for (int i = 0; i < n; i++) {
System.arraycopy(matrixA[i], 0, extendedMatrixA[i], 0, matrixA[i].length);
System.arraycopy(matrixB[i], 0, extendedMatrixB[i], 0, matrixB[i].length);
}

double[][] extendedResult = strassen(extendedMatrixA, extendedMatrixB);

double[][] result = new double[n][n];
for (int i = 0; i < n; i++) {
System.arraycopy(extendedResult[i], 0, result[i], 0, n);
}

return result;
}

private double[][] strassen(double[][] A, double[][] B) {
int n = A.length;

double[][] result = new double[n][n];

if (n == 1) {
result[0][0] = A[0][0] * B[0][0];
} else {
int newSize = n / 2;
double[][] a11 = new double[newSize][newSize];
double[][] a12 = new double[newSize][newSize];
double[][] a21 = new double[newSize][newSize];
double[][] a22 = new double[newSize][newSize];
double[][] b11 = new double[newSize][newSize];
double[][] b12 = new double[newSize][newSize];
double[][] b21 = new double[newSize][newSize];
double[][] b22 = new double[newSize][newSize];

divideMatrix(A, a11, 0, 0);
divideMatrix(A, a12, 0, newSize);
divideMatrix(A, a21, newSize, 0);
divideMatrix(A, a22, newSize, newSize);
divideMatrix(B, b11, 0, 0);
divideMatrix(B, b12, 0, newSize);
divideMatrix(B, b21, newSize, 0);
divideMatrix(B, b22, newSize, newSize);

double[][] m1 = strassen(addMatrices(a11, a22), addMatrices(b11, b22));
double[][] m2 = strassen(addMatrices(a21, a22), b11);
double[][] m3 = strassen(a11, subtractMatrices(b12, b22));
double[][] m4 = strassen(a22, subtractMatrices(b21, b11));
double[][] m5 = strassen(addMatrices(a11, a12), b22);
double[][] m6 = strassen(subtractMatrices(a21, a11), addMatrices(b11, b12));
double[][] m7 = strassen(subtractMatrices(a12, a22), addMatrices(b21, b22));

double[][] c11 = addMatrices(subtractMatrices(addMatrices(m1, m4), m5), m7);
double[][] c12 = addMatrices(m3, m5);
double[][] c21 = addMatrices(m2, m4);
double[][] c22 = addMatrices(subtractMatrices(addMatrices(m1, m3), m2), m6);

combineMatrix(c11, result, 0, 0);
combineMatrix(c12, result, 0, newSize);
combineMatrix(c21, result, newSize, 0);
combineMatrix(c22, result, newSize, newSize);
}

return result;
}

private void divideMatrix(double[][] parent, double[][] child, int iB, int jB) {
for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) {
for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) {
child[i1][j1] = parent[i2][j2];
}
}
}

private void combineMatrix(double[][] child, double[][] parent, int iB, int jB) {
for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) {
for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) {
parent[i2][j2] = child[i1][j1];
}
}
}

private double[][] addMatrices(double[][] a, double[][] b) {
int n = a.length;
double[][] result = new double[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
result[i][j] = a[i][j] + b[i][j];
}
}
return result;
}

private double[][] subtractMatrices(double[][] a, double[][] b) {
int n = a.length;
double[][] result = new double[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
result[i][j] = a[i][j] - b[i][j];
}
}
return result;
}

private int nextPowerOfTwo(int number) {
int power = 1;
while (power < number) {
power *= 2;
}
return power;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package de.edux.core.math.matrix.strassen;

import de.edux.core.math.IMatrixArithmetic;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class StrassenParallel implements IMatrixArithmetic {

private ForkJoinPool forkJoinPool = new ForkJoinPool(4);

@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
int n = matrixA.length;
int m = nextPowerOfTwo(n);
double[][] extendedMatrixA = new double[m][m];
double[][] extendedMatrixB = new double[m][m];

for (int i = 0; i < n; i++) {
System.arraycopy(matrixA[i], 0, extendedMatrixA[i], 0, matrixA[i].length);
System.arraycopy(matrixB[i], 0, extendedMatrixB[i], 0, matrixB[i].length);
}

double[][] extendedResult =
forkJoinPool.invoke(new StrassenTask(extendedMatrixA, extendedMatrixB));

double[][] result = new double[n][n];
for (int i = 0; i < n; i++) {
System.arraycopy(extendedResult[i], 0, result[i], 0, n);
}

return result;
}

private double[][] conventionalMultiply(double[][] A, double[][] B) {
int n = A.length;
double[][] result = new double[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
for (int k = 0; k < n; k++) {
result[i][j] += A[i][k] * B[k][j];
}
}
}
return result;
}

private int nextPowerOfTwo(int number) {
int power = 1;
while (power < number) {
power *= 2;
}
return power;
}

private void divideMatrix(double[][] parent, double[][] child, int iB, int jB) {
for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) {
for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) {
child[i1][j1] = parent[i2][j2];
}
}
}

private void combineMatrix(double[][] child, double[][] parent, int iB, int jB) {
for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) {
for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) {
parent[i2][j2] = child[i1][j1];
}
}
}

private double[][] addMatrices(double[][] a, double[][] b) {
int n = a.length;
double[][] result = new double[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
result[i][j] = a[i][j] + b[i][j];
}
}
return result;
}

private double[][] subtractMatrices(double[][] a, double[][] b) {
int n = a.length;
double[][] result = new double[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
result[i][j] = a[i][j] - b[i][j];
}
}
return result;
}

private class StrassenTask extends RecursiveTask<double[][]> {
private double[][] A;
private double[][] B;

StrassenTask(double[][] A, double[][] B) {
this.A = A;
this.B = B;
}

@Override
protected double[][] compute() {
int n = A.length;

if (n <= 512) {
return conventionalMultiply(A, B);
} else {
int newSize = n / 2;
double[][] a11 = new double[newSize][newSize];
double[][] a12 = new double[newSize][newSize];
double[][] a21 = new double[newSize][newSize];
double[][] a22 = new double[newSize][newSize];
double[][] b11 = new double[newSize][newSize];
double[][] b12 = new double[newSize][newSize];
double[][] b21 = new double[newSize][newSize];
double[][] b22 = new double[newSize][newSize];

divideMatrix(A, a11, 0, 0);
divideMatrix(A, a12, 0, newSize);
divideMatrix(A, a21, newSize, 0);
divideMatrix(A, a22, newSize, newSize);
divideMatrix(B, b11, 0, 0);
divideMatrix(B, b12, 0, newSize);
divideMatrix(B, b21, newSize, 0);
divideMatrix(B, b22, newSize, newSize);

StrassenTask task1 = new StrassenTask(addMatrices(a11, a22), addMatrices(b11, b22));
StrassenTask task2 = new StrassenTask(addMatrices(a21, a22), b11);
StrassenTask task3 = new StrassenTask(a11, subtractMatrices(b12, b22));
StrassenTask task4 = new StrassenTask(a22, subtractMatrices(b21, b11));
StrassenTask task5 = new StrassenTask(addMatrices(a11, a12), b22);
StrassenTask task6 = new StrassenTask(subtractMatrices(a21, a11), addMatrices(b11, b12));
StrassenTask task7 = new StrassenTask(subtractMatrices(a12, a22), addMatrices(b21, b22));

invokeAll(task1, task2, task3, task4, task5, task6, task7);

double[][] m1 = task1.join();
double[][] m2 = task2.join();
double[][] m3 = task3.join();
double[][] m4 = task4.join();
double[][] m5 = task5.join();
double[][] m6 = task6.join();
double[][] m7 = task7.join();

double[][] c11 = addMatrices(subtractMatrices(addMatrices(m1, m4), m5), m7);
double[][] c12 = addMatrices(m3, m5);
double[][] c21 = addMatrices(m2, m4);
double[][] c22 = addMatrices(subtractMatrices(addMatrices(m1, m3), m2), m6);

double[][] result = new double[n][n];

combineMatrix(c11, result, 0, 0);
combineMatrix(c12, result, 0, newSize);
combineMatrix(c21, result, newSize, 0);
combineMatrix(c22, result, newSize, newSize);

return result;
}
}
}
}
12 changes: 0 additions & 12 deletions lib/src/main/java/de/edux/math/Entity.java

This file was deleted.

15 changes: 0 additions & 15 deletions lib/src/main/java/de/edux/math/MathUtil.java

This file was deleted.

Loading