Skip to content

Commit

Permalink
Merge pull request #110 from acsolle66/feature-matrix-vector-product
Browse files Browse the repository at this point in the history
feat(#101): implement MatrixVectorProduct
  • Loading branch information
Samyssmile authored Nov 9, 2023
2 parents 34562ac + ad93c9a commit 2e43d6f
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class DecisionTreeExampleOnPenguinsDatasetNoHead {
+ "seaborn-penguins"
+ File.separator
+ "penguins-no-head.csv");
private static final boolean SKIP_HEAD = true;
private static final boolean SKIP_HEAD = false;
private static final ImputationStrategy averageImputation = ImputationStrategy.AVERAGE;
private static final ImputationStrategy modeImputation = ImputationStrategy.MODE;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,30 @@ public MatrixArithmetic() {

@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {

if (matrixA == null || matrixB == null) {
throw new IllegalArgumentException("Matrices must not be null.");
}
if (matrixA.length == 0 || matrixB.length == 0) {
throw new IllegalArgumentException("Matrices must not be empty.");
}
if (matrixA[0].length != matrixB.length) {
throw new IllegalArgumentException("Matrix A columns must match Matrix B rows.");
}

return matrixProduct.multiply(matrixA, matrixB);
}

@Override
public double[] multiply(double[][] matrix, double[] vector) {

if (matrix.length == 0 || matrix[0].length == 0 || vector.length == 0) {
throw new IllegalArgumentException("Matrix and vector must not be empty.");
}
if (matrix[0].length != vector.length) {
throw new IllegalArgumentException("Matrix columns and vector size do not match.");
}

return matrixVectorProduct.multiply(matrix, vector);
}
}
Original file line number Diff line number Diff line change
@@ -1,30 +1,17 @@
package de.edux.core.math.matrix.parallel.operations;

import de.edux.core.math.IMatrixProduct;

import java.util.stream.IntStream;

public class MatrixProduct implements IMatrixProduct {

private void checkSizeForMultiplication(double[][] matrixA, double[][] matrixB) {
int m = matrixB.length;
int n = matrixA[0].length;
if (m != n) {
throw new RuntimeException(
"\"The number of columns in the first matrix must be equal to the number of rows in the second matrix.\"");
}
}

@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
checkSizeForMultiplication(matrixA, matrixB);

int aRows = matrixA.length;
int aColumns = matrixA[0].length;
int bColumns = matrixB[0].length;

double[][] result = new double[aRows][bColumns];

IntStream.range(0, aRows)
.parallel()
.forEach(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
package de.edux.core.math.matrix.parallel.operations;

import de.edux.core.math.IMatrixVectorProduct;
import java.util.stream.IntStream;

public class MatrixVectorProduct implements IMatrixVectorProduct {

@Override
public double[] multiply(double[][] matrix, double[] vector) {
return new double[0];
int matrixRows = matrix.length;
int matrixColumns = matrix[0].length;

double[] result = new double[matrixRows];
IntStream.range(0, matrixRows)
.parallel()
.forEach(
row -> {
for (int col = 0; col < matrixColumns; col++) {
result[row] += matrix[row][col] * vector[col];
}
});

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

import static org.junit.jupiter.api.Assertions.*;

import de.edux.core.math.IMatrixArithmetic;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

class MatrixArithmeticTest {

private static MatrixArithmetic matrixArithmetic;

@BeforeAll
static void setUp() {
matrixArithmetic = new MatrixArithmetic();
}

@Test
public void testLargeMatrixMultiplication() {
int matrixSize = 200;
Expand All @@ -19,8 +26,8 @@ public void testLargeMatrixMultiplication() {
matrixB[i][j] = 1;
}
}
IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

double[][] result = matrixArithmetic.multiply(matrixA, matrixB);
for (int i = 0; i < matrixSize; i++) {
for (int j = 0; j < matrixSize; j++) {
assertEquals(matrixSize, result[i][j], "Result on [" + i + "][" + j + "] not correct.");
Expand All @@ -45,15 +52,13 @@ public void shouldThrowRuntimeExceptionForMatricesWithIncompatibleSizes() {
{-1, 6},
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic();

assertAll(
() ->
assertThrows(
RuntimeException.class, () -> matrixParallelArithmetic.multiply(matrixA, matrixB)),
RuntimeException.class, () -> matrixArithmetic.multiply(matrixA, matrixB)),
() ->
assertThrows(
RuntimeException.class, () -> matrixParallelArithmetic.multiply(matrixC, matrixA)));
RuntimeException.class, () -> matrixArithmetic.multiply(matrixC, matrixA)));
}

@Test
Expand Down Expand Up @@ -89,9 +94,7 @@ public void shouldMultiply8x8MatricesCorrectly() {
{22, 18, 14, 10, 22, 18, 14, 10}
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

double[][] result = matrixArithmetic.multiply(matrixA, matrixB);
assertArrayEquals(
expected, result, "The 8x8 matrix multiplication did not yield the correct result.");
}
Expand All @@ -111,9 +114,7 @@ public void shouldMultiplyWithZeroMatrix() {
{0, 0}
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

double[][] result = matrixArithmetic.multiply(matrixA, matrixB);
assertArrayEquals(expected, result);
}

Expand All @@ -128,9 +129,7 @@ public void shouldMultiplyWithIdentityMatrix() {
{7, 8}
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

double[][] result = matrixArithmetic.multiply(matrixA, matrixB);
assertArrayEquals(matrixB, result);
}

Expand All @@ -149,9 +148,56 @@ public void shouldMultiplySmallMatricesCorrectly() {
{10, 8}
};

IMatrixArithmetic matrixParallelArithmetic = new MatrixArithmetic();
double[][] result = matrixParallelArithmetic.multiply(matrixA, matrixB);

double[][] result = matrixArithmetic.multiply(matrixA, matrixB);
assertArrayEquals(expected, result);
}

@Test
void shouldSolveMatrixVectorProduct() {
int matrixSize = 2048;
double[][] matrixA = new double[matrixSize][matrixSize];
double[] vector = new double[matrixSize];

for (int i = 0; i < matrixSize; i++) {
for (int j = 0; j < matrixSize; j++) {
matrixA[i][j] = 1;
vector[i] = 1;
}
}

double[] resultVector = matrixArithmetic.multiply(matrixA, vector);
for (int i = 0; i < matrixSize; i++) {
assertEquals(matrixSize, resultVector[i], "Result on [" + i + "][" + i + "] not correct.");
}
}

@Test
void shouldHandleEmptyMatrix() {
double[][] matrix = new double[0][0];
double[] vector = new double[0];
assertThrows(IllegalArgumentException.class, () -> matrixArithmetic.multiply(matrix, vector));
}

@Test
void shouldHandleMismatchedSizes() {
double[][] matrix = {{1, 2, 3}, {4, 5, 6}};
double[] vector = {1, 2};
assertThrows(IllegalArgumentException.class, () -> matrixArithmetic.multiply(matrix, vector));
}

@Test
void shouldHandleNullMatrix() {
double[] vector = {1, 2, 3};
assertThrows(NullPointerException.class, () -> matrixArithmetic.multiply(null, vector));
}

@Test
void shouldMultiplyVectorWithIdentityMatrix() {
double[][] matrix = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}};
double[] vector = {1, 2, 3};
double[] expected = {1, 2, 3};
double[] result = matrixArithmetic.multiply(matrix, vector);
assertArrayEquals(
expected, result, "Multiplying with identity matrix should return the original vector.");
}
}

0 comments on commit 2e43d6f

Please sign in to comment.