-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #94 from Samyssmile/testsadded
tests(): Implement Multi-Threaded Matrix Multiplication #74
- Loading branch information
Showing
2 changed files
with
191 additions
and
1 deletion.
There are no files selected for viewing
36 changes: 35 additions & 1 deletion
36
lib/src/main/java/de/edux/core/math/matrix/classic/MatrixArithmetic.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,44 @@ | ||
package de.edux.core.math.matrix.classic; | ||
|
||
import de.edux.core.math.IMatrixArithmetic; | ||
import java.util.Arrays; | ||
|
||
public class MatrixArithmetic implements IMatrixArithmetic { | ||
@Override | ||
public double[][] multiply(double[][] matrixA, double[][] matrixB) { | ||
return new double[0][]; | ||
int columnsInFirstMatrix = matrixA[0].length; | ||
int rowsInSecondMatrix = matrixB.length; | ||
if (columnsInFirstMatrix != rowsInSecondMatrix) { | ||
throw new RuntimeException( | ||
"\"The number of columns in the first matrix must be equal to the number of rows in the second matrix.\""); | ||
} | ||
|
||
int rowsInMatrixC = matrixA.length; | ||
int columnsInMatrixC = matrixB.length; | ||
double[][] matrixC = new double[rowsInMatrixC][columnsInMatrixC]; | ||
for (int row = 0; row < rowsInMatrixC; row++) { | ||
for (int column = 0; column < columnsInMatrixC; column++) { | ||
matrixC[row][column] = | ||
vectorDotProduct(getRowVector(matrixA, row), getColumnVector(matrixB, column)); | ||
} | ||
} | ||
return matrixC; | ||
} | ||
|
||
public double[] getRowVector(double[][] matrix, int rowIndex) { | ||
return matrix[rowIndex]; | ||
} | ||
|
||
public double[] getColumnVector(double[][] matrix, int columnIndex) { | ||
return Arrays.stream(matrix).mapToDouble((row) -> row[columnIndex]).toArray(); | ||
} | ||
|
||
public double vectorDotProduct(double[] rowVector, double[] columVector) { | ||
int vectorSize = rowVector.length; | ||
double vectorDotProduct = 0; | ||
for (int i = 0; i < vectorSize; i++) { | ||
vectorDotProduct += rowVector[i] * columVector[i]; | ||
} | ||
return vectorDotProduct; | ||
} | ||
} |
156 changes: 156 additions & 0 deletions
156
lib/src/test/java/de/edux/core/math/matrix/strassen/MatrixArithmeticTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
package de.edux.core.math.matrix.strassen; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
import de.edux.core.math.IMatrixArithmetic; | ||
import de.edux.core.math.matrix.classic.MatrixArithmetic; | ||
import org.junit.jupiter.api.Test; | ||
|
||
class MatrixArithmeticTest { | ||
|
||
@Test | ||
public void testLargeMatrixMultiplication() { | ||
int matrixSize = 200; | ||
double[][] matrixA = new double[matrixSize][matrixSize]; | ||
double[][] matrixB = new double[matrixSize][matrixSize]; | ||
|
||
for (int i = 0; i < matrixSize; i++) { | ||
for (int j = 0; j < matrixSize; j++) { | ||
matrixA[i][j] = 1; | ||
matrixB[i][j] = 1; | ||
} | ||
} | ||
|
||
IMatrixArithmetic classic = new MatrixArithmetic(); | ||
|
||
double[][] result = classic.multiply(matrixA, matrixB); | ||
for (int i = 0; i < matrixSize; i++) { | ||
for (int j = 0; j < matrixSize; j++) { | ||
assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct."); | ||
} | ||
} | ||
} | ||
|
||
@Test | ||
public void shouldThrowRuntimeExceptionForMatricesWithIncompatibleSizes() { | ||
double[][] matrixA = { | ||
{3, -5, 1}, | ||
{-2, 0, 4}, | ||
{-1, 6, 5}, | ||
}; | ||
double[][] matrixB = { | ||
{7, 2, 4}, | ||
{0, 1, -5}, | ||
}; | ||
double[][] matrixC = { | ||
{3, -5}, | ||
{-2, 0}, | ||
{-1, 6}, | ||
}; | ||
|
||
IMatrixArithmetic classic = new MatrixArithmetic(); | ||
|
||
assertAll( | ||
() -> assertThrows(RuntimeException.class, () -> classic.multiply(matrixA, matrixB)), | ||
() -> assertThrows(RuntimeException.class, () -> classic.multiply(matrixC, matrixA))); | ||
} | ||
|
||
@Test | ||
public void shouldMultiply8x8MatricesCorrectly() { | ||
double[][] matrixA = { | ||
{1, 2, 3, 4, 5, 6, 7, 8}, | ||
{8, 7, 6, 5, 4, 3, 2, 1}, | ||
{2, 3, 4, 5, 6, 7, 8, 9}, | ||
{9, 8, 7, 6, 5, 4, 3, 2}, | ||
{1, 1, 1, 1, 1, 1, 1, 1}, | ||
{2, 2, 2, 2, 2, 2, 2, 2}, | ||
{1, 3, 5, 7, 9, 11, 13, 15}, | ||
{15, 13, 11, 9, 7, 5, 3, 1} | ||
}; | ||
double[][] matrixB = { | ||
{1, 0, 0, 0, 1, 0, 0, 0}, | ||
{0, 1, 0, 0, 0, 1, 0, 0}, | ||
{0, 0, 1, 0, 0, 0, 1, 0}, | ||
{0, 0, 0, 1, 0, 0, 0, 1}, | ||
{1, 0, 0, 0, 1, 0, 0, 0}, | ||
{0, 1, 0, 0, 0, 1, 0, 0}, | ||
{0, 0, 1, 0, 0, 0, 1, 0}, | ||
{0, 0, 0, 1, 0, 0, 0, 1} | ||
}; | ||
double[][] expected = { | ||
{6, 8, 10, 12, 6, 8, 10, 12}, | ||
{12, 10, 8, 6, 12, 10, 8, 6}, | ||
{8, 10, 12, 14, 8, 10, 12, 14}, | ||
{14, 12, 10, 8, 14, 12, 10, 8}, | ||
{2, 2, 2, 2, 2, 2, 2, 2}, | ||
{4, 4, 4, 4, 4, 4, 4, 4}, | ||
{10, 14, 18, 22, 10, 14, 18, 22}, | ||
{22, 18, 14, 10, 22, 18, 14, 10} | ||
}; | ||
|
||
IMatrixArithmetic matrixArithmetic = new MatrixArithmetic(); | ||
double[][] result = matrixArithmetic.multiply(matrixA, matrixB); | ||
|
||
assertArrayEquals( | ||
expected, result, "The 8x8 matrix multiplication did not yield the correct result."); | ||
} | ||
|
||
@Test | ||
public void shouldMultiplyWithZeroMatrix() { | ||
double[][] matrixA = { | ||
{0, 0}, | ||
{0, 0} | ||
}; | ||
double[][] matrixB = { | ||
{1, 2}, | ||
{3, 4} | ||
}; | ||
double[][] expected = { | ||
{0, 0}, | ||
{0, 0} | ||
}; | ||
|
||
IMatrixArithmetic matrixArithmetic = new MatrixArithmetic(); | ||
double[][] result = matrixArithmetic.multiply(matrixA, matrixB); | ||
|
||
assertArrayEquals(expected, result); | ||
} | ||
|
||
@Test | ||
public void shouldMultiplyWithIdentityMatrix() { | ||
double[][] matrixA = { | ||
{1, 0}, | ||
{0, 1} | ||
}; | ||
double[][] matrixB = { | ||
{5, 6}, | ||
{7, 8} | ||
}; | ||
|
||
IMatrixArithmetic matrixArithmetic = new MatrixArithmetic(); | ||
double[][] result = matrixArithmetic.multiply(matrixA, matrixB); | ||
|
||
assertArrayEquals(matrixB, result); | ||
} | ||
|
||
@Test | ||
public void shouldMultiplySmallMatricesCorrectly() { | ||
double[][] matrixA = { | ||
{1, 2}, | ||
{3, 4} | ||
}; | ||
double[][] matrixB = { | ||
{2, 0}, | ||
{1, 2} | ||
}; | ||
double[][] expected = { | ||
{4, 4}, | ||
{10, 8} | ||
}; | ||
|
||
IMatrixArithmetic matrixArithmetic = new MatrixArithmetic(); | ||
double[][] result = matrixArithmetic.multiply(matrixA, matrixB); | ||
|
||
assertArrayEquals(expected, result); | ||
} | ||
} |