Skip to content

Commit

Permalink
Merge pull request #94 from Samyssmile/testsadded
Browse files Browse the repository at this point in the history
tests(): Implement Multi-Threaded Matrix Multiplication #74
  • Loading branch information
Samyssmile authored Nov 6, 2023
2 parents e06c18f + 7a143c7 commit 990acf8
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,10 +1,44 @@
package de.edux.core.math.matrix.classic;

import de.edux.core.math.IMatrixArithmetic;
import java.util.Arrays;

public class MatrixArithmetic implements IMatrixArithmetic {
@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
return new double[0][];
int columnsInFirstMatrix = matrixA[0].length;
int rowsInSecondMatrix = matrixB.length;
if (columnsInFirstMatrix != rowsInSecondMatrix) {
throw new RuntimeException(
"\"The number of columns in the first matrix must be equal to the number of rows in the second matrix.\"");
}

int rowsInMatrixC = matrixA.length;
int columnsInMatrixC = matrixB.length;
double[][] matrixC = new double[rowsInMatrixC][columnsInMatrixC];
for (int row = 0; row < rowsInMatrixC; row++) {
for (int column = 0; column < columnsInMatrixC; column++) {
matrixC[row][column] =
vectorDotProduct(getRowVector(matrixA, row), getColumnVector(matrixB, column));
}
}
return matrixC;
}

public double[] getRowVector(double[][] matrix, int rowIndex) {
return matrix[rowIndex];
}

public double[] getColumnVector(double[][] matrix, int columnIndex) {
return Arrays.stream(matrix).mapToDouble((row) -> row[columnIndex]).toArray();
}

public double vectorDotProduct(double[] rowVector, double[] columVector) {
int vectorSize = rowVector.length;
double vectorDotProduct = 0;
for (int i = 0; i < vectorSize; i++) {
vectorDotProduct += rowVector[i] * columVector[i];
}
return vectorDotProduct;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package de.edux.core.math.matrix.strassen;

import static org.junit.jupiter.api.Assertions.*;

import de.edux.core.math.IMatrixArithmetic;
import de.edux.core.math.matrix.classic.MatrixArithmetic;
import org.junit.jupiter.api.Test;

class MatrixArithmeticTest {

@Test
public void testLargeMatrixMultiplication() {
int matrixSize = 200;
double[][] matrixA = new double[matrixSize][matrixSize];
double[][] matrixB = new double[matrixSize][matrixSize];

for (int i = 0; i < matrixSize; i++) {
for (int j = 0; j < matrixSize; j++) {
matrixA[i][j] = 1;
matrixB[i][j] = 1;
}
}

IMatrixArithmetic classic = new MatrixArithmetic();

double[][] result = classic.multiply(matrixA, matrixB);
for (int i = 0; i < matrixSize; i++) {
for (int j = 0; j < matrixSize; j++) {
assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct.");
}
}
}

@Test
public void shouldThrowRuntimeExceptionForMatricesWithIncompatibleSizes() {
double[][] matrixA = {
{3, -5, 1},
{-2, 0, 4},
{-1, 6, 5},
};
double[][] matrixB = {
{7, 2, 4},
{0, 1, -5},
};
double[][] matrixC = {
{3, -5},
{-2, 0},
{-1, 6},
};

IMatrixArithmetic classic = new MatrixArithmetic();

assertAll(
() -> assertThrows(RuntimeException.class, () -> classic.multiply(matrixA, matrixB)),
() -> assertThrows(RuntimeException.class, () -> classic.multiply(matrixC, matrixA)));
}

@Test
public void shouldMultiply8x8MatricesCorrectly() {
double[][] matrixA = {
{1, 2, 3, 4, 5, 6, 7, 8},
{8, 7, 6, 5, 4, 3, 2, 1},
{2, 3, 4, 5, 6, 7, 8, 9},
{9, 8, 7, 6, 5, 4, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1},
{2, 2, 2, 2, 2, 2, 2, 2},
{1, 3, 5, 7, 9, 11, 13, 15},
{15, 13, 11, 9, 7, 5, 3, 1}
};
double[][] matrixB = {
{1, 0, 0, 0, 1, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 0, 0},
{0, 0, 1, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 0, 0, 1},
{1, 0, 0, 0, 1, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 0, 0},
{0, 0, 1, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 0, 0, 1}
};
double[][] expected = {
{6, 8, 10, 12, 6, 8, 10, 12},
{12, 10, 8, 6, 12, 10, 8, 6},
{8, 10, 12, 14, 8, 10, 12, 14},
{14, 12, 10, 8, 14, 12, 10, 8},
{2, 2, 2, 2, 2, 2, 2, 2},
{4, 4, 4, 4, 4, 4, 4, 4},
{10, 14, 18, 22, 10, 14, 18, 22},
{22, 18, 14, 10, 22, 18, 14, 10}
};

IMatrixArithmetic matrixArithmetic = new MatrixArithmetic();
double[][] result = matrixArithmetic.multiply(matrixA, matrixB);

assertArrayEquals(
expected, result, "The 8x8 matrix multiplication did not yield the correct result.");
}

@Test
public void shouldMultiplyWithZeroMatrix() {
double[][] matrixA = {
{0, 0},
{0, 0}
};
double[][] matrixB = {
{1, 2},
{3, 4}
};
double[][] expected = {
{0, 0},
{0, 0}
};

IMatrixArithmetic matrixArithmetic = new MatrixArithmetic();
double[][] result = matrixArithmetic.multiply(matrixA, matrixB);

assertArrayEquals(expected, result);
}

@Test
public void shouldMultiplyWithIdentityMatrix() {
double[][] matrixA = {
{1, 0},
{0, 1}
};
double[][] matrixB = {
{5, 6},
{7, 8}
};

IMatrixArithmetic matrixArithmetic = new MatrixArithmetic();
double[][] result = matrixArithmetic.multiply(matrixA, matrixB);

assertArrayEquals(matrixB, result);
}

@Test
public void shouldMultiplySmallMatricesCorrectly() {
double[][] matrixA = {
{1, 2},
{3, 4}
};
double[][] matrixB = {
{2, 0},
{1, 2}
};
double[][] expected = {
{4, 4},
{10, 8}
};

IMatrixArithmetic matrixArithmetic = new MatrixArithmetic();
double[][] result = matrixArithmetic.multiply(matrixA, matrixB);

assertArrayEquals(expected, result);
}
}

0 comments on commit 990acf8

Please sign in to comment.