Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refracor(): refractor code in classic MatrixArithmetic #99

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,44 +1,39 @@
package de.edux.core.math.matrix.classic;

import de.edux.core.math.IMatrixArithmetic;
import java.util.Arrays;
import java.util.stream.IntStream;

public class MatrixArithmetic implements IMatrixArithmetic {
@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
int columnsInFirstMatrix = matrixA[0].length;
int rowsInSecondMatrix = matrixB.length;
if (columnsInFirstMatrix != rowsInSecondMatrix) {

private void checkSizeForMultiplication(double[][] matrixA, double[][] matrixB) {
int m = matrixB.length;
int n = matrixA[0].length;
if (m != n) {
throw new RuntimeException(
"\"The number of columns in the first matrix must be equal to the number of rows in the second matrix.\"");
}

int rowsInMatrixC = matrixA.length;
int columnsInMatrixC = matrixB.length;
double[][] matrixC = new double[rowsInMatrixC][columnsInMatrixC];
for (int row = 0; row < rowsInMatrixC; row++) {
for (int column = 0; column < columnsInMatrixC; column++) {
matrixC[row][column] =
vectorDotProduct(getRowVector(matrixA, row), getColumnVector(matrixB, column));
}
}
return matrixC;
}

public double[] getRowVector(double[][] matrix, int rowIndex) {
return matrix[rowIndex];
}
@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
checkSizeForMultiplication(matrixA, matrixB);

public double[] getColumnVector(double[][] matrix, int columnIndex) {
return Arrays.stream(matrix).mapToDouble((row) -> row[columnIndex]).toArray();
}
int aRows = matrixA.length;
int aColumns = matrixA[0].length;
int bColumns = matrixB[0].length;

public double vectorDotProduct(double[] rowVector, double[] columVector) {
int vectorSize = rowVector.length;
double vectorDotProduct = 0;
for (int i = 0; i < vectorSize; i++) {
vectorDotProduct += rowVector[i] * columVector[i];
}
return vectorDotProduct;
double[][] result = new double[aRows][bColumns];

IntStream.range(0, aRows)
.forEach(
row -> {
for (int col = 0; col < bColumns; col++) {
for (int i = 0; i < aColumns; i++) {
result[row][col] += matrixA[row][i] * matrixB[i][col];
}
}
});

return result;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,40 @@
package de.edux.core.math.matrix.parallel;

import de.edux.core.math.IMatrixArithmetic;
import java.util.stream.IntStream;

public class MatrixParallelArithmetic implements IMatrixArithmetic {

private void checkSizeForMultiplication(double[][] matrixA, double[][] matrixB) {
int m = matrixB.length;
int n = matrixA[0].length;
if (m != n) {
throw new RuntimeException(
"\"The number of columns in the first matrix must be equal to the number of rows in the second matrix.\"");
}
}

@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
return new double[0][];
checkSizeForMultiplication(matrixA, matrixB);

int aRows = matrixA.length;
int aColumns = matrixA[0].length;
int bColumns = matrixB[0].length;

double[][] result = new double[aRows][bColumns];

IntStream.range(0, aRows)
.parallel()
.forEach(
row -> {
for (int col = 0; col < bColumns; col++) {
for (int i = 0; i < aColumns; i++) {
result[row][col] += matrixA[row][i] * matrixB[i][col];
}
}
});

return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package de.edux.core.math.matrix.parallel;

import static org.junit.jupiter.api.Assertions.*;

import de.edux.core.math.IMatrixArithmetic;
import org.junit.jupiter.api.Test;

class MatrixParallelArithmeticTest {

@Test
public void testLargeMatrixMultiplication() {
int matrixSize = 200;
double[][] matrixA = new double[matrixSize][matrixSize];
double[][] matrixB = new double[matrixSize][matrixSize];

for (int i = 0; i < matrixSize; i++) {
for (int j = 0; j < matrixSize; j++) {
matrixA[i][j] = 1;
matrixB[i][j] = 1;
}
}
IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);
for (int i = 0; i < matrixSize; i++) {
for (int j = 0; j < matrixSize; j++) {
assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct.");
}
}
}

@Test
public void shouldThrowRuntimeExceptionForMatricesWithIncompatibleSizes() {
double[][] matrixA = {
{3, -5, 1},
{-2, 0, 4},
{-1, 6, 5},
};
double[][] matrixB = {
{7, 2, 4},
{0, 1, -5},
};
double[][] matrixC = {
{3, -5},
{-2, 0},
{-1, 6},
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic();

assertAll(
() ->
assertThrows(
RuntimeException.class, () -> matrixParallelArithmetic.multiply(matrixA, matrixB)),
() ->
assertThrows(
RuntimeException.class, () -> matrixParallelArithmetic.multiply(matrixC, matrixA)));
}

@Test
public void shouldMultiply8x8MatricesCorrectly() {
double[][] matrixA = {
{1, 2, 3, 4, 5, 6, 7, 8},
{8, 7, 6, 5, 4, 3, 2, 1},
{2, 3, 4, 5, 6, 7, 8, 9},
{9, 8, 7, 6, 5, 4, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1},
{2, 2, 2, 2, 2, 2, 2, 2},
{1, 3, 5, 7, 9, 11, 13, 15},
{15, 13, 11, 9, 7, 5, 3, 1}
};
double[][] matrixB = {
{1, 0, 0, 0, 1, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 0, 0},
{0, 0, 1, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 0, 0, 1},
{1, 0, 0, 0, 1, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 0, 0},
{0, 0, 1, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 0, 0, 1}
};
double[][] expected = {
{6, 8, 10, 12, 6, 8, 10, 12},
{12, 10, 8, 6, 12, 10, 8, 6},
{8, 10, 12, 14, 8, 10, 12, 14},
{14, 12, 10, 8, 14, 12, 10, 8},
{2, 2, 2, 2, 2, 2, 2, 2},
{4, 4, 4, 4, 4, 4, 4, 4},
{10, 14, 18, 22, 10, 14, 18, 22},
{22, 18, 14, 10, 22, 18, 14, 10}
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

assertArrayEquals(
expected, result, "The 8x8 matrix multiplication did not yield the correct result.");
}

@Test
public void shouldMultiplyWithZeroMatrix() {
double[][] matrixA = {
{0, 0},
{0, 0}
};
double[][] matrixB = {
{1, 2},
{3, 4}
};
double[][] expected = {
{0, 0},
{0, 0}
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

assertArrayEquals(expected, result);
}

@Test
public void shouldMultiplyWithIdentityMatrix() {
double[][] matrixA = {
{1, 0},
{0, 1}
};
double[][] matrixB = {
{5, 6},
{7, 8}
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

assertArrayEquals(matrixB, result);
}

@Test
public void shouldMultiplySmallMatricesCorrectly() {
double[][] matrixA = {
{1, 2},
{3, 4}
};
double[][] matrixB = {
{2, 0},
{1, 2}
};
double[][] expected = {
{4, 4},
{10, 8}
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixParallelArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

assertArrayEquals(expected, result);
}
}