Skip to content

Commit

Permalink
feat#23 add test to imputation methods and fix bug in AverageImputation
Browse files Browse the repository at this point in the history
Average imputation throwed RuntimeException beacuse the condition in the isDigit() method
was returning false for blank values.Fixed with an or gate.
  • Loading branch information
acsolle66 committed Oct 30, 2023
1 parent bfafba4 commit db48b94
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package de.edux.functions.imputation;

import java.util.Arrays;
import java.util.List;

public class AverageImputation implements IImputationStrategy {
@Override
Expand All @@ -23,25 +24,23 @@ public String[] performImputation(String[] datasetColumn) {

private void checkIfColumnContainsCategoricalValues(String[] datasetColumn) {
for (String value : datasetColumn) {
if (!isNumeric(value)) {
if(!isNumeric(value)){
throw new RuntimeException(
"AVERAGE imputation strategy can not be used on categorical features. "
+ "Use MODE imputation strategy or perform a list wise deletion on the features.");
"AVERAGE imputation strategy can not be used on categorical features. "
+ "Use MODE imputation strategy or perform a list wise deletion on the features.");
}
}
}

private boolean isNumeric(String value) {
return value.matches("-?\\d+(\\.\\d+)?");
return value.matches("-?\\d+(\\.\\d+)?") || value.isBlank();
}

private double calculateAverage(String[] datasetColumn) {
String[] filteredDatasetColumn =
(String[]) Arrays.stream(datasetColumn).filter((value) -> !value.isBlank()).toArray();
int validValueCount = filteredDatasetColumn.length;
List<String> filteredDatasetColumn = Arrays.stream(datasetColumn).filter((value) -> !value.isBlank()).toList();
int validValueCount = filteredDatasetColumn.size();
double sumOfValidValues =
Arrays.stream(filteredDatasetColumn).map(Double::parseDouble).reduce(0.0, Double::sum);

filteredDatasetColumn.stream().map(Double::parseDouble).reduce(0.0, Double::sum);
return sumOfValidValues / validValueCount;
}
}
38 changes: 38 additions & 0 deletions lib/src/test/java/de/edux/data/provider/DataProcessorTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package de.edux.data.provider;

import de.edux.data.reader.IDataReader;
import de.edux.functions.imputation.IImputationStrategy;
import de.edux.functions.imputation.ImputationStrategy;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -10,6 +12,7 @@

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

Expand All @@ -23,6 +26,7 @@ class DataProcessorTest {

private static final boolean SKIP_HEAD = true;
private List<String[]> dummyDataset;
List<String[]> dummyDatasetForImputationTest;

private DataProcessor dataProcessor;

Expand All @@ -39,6 +43,18 @@ void setUp() {
dummyDataset.add(new String[]{"16", "17", "18", "Isabela", "20"});
when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDataset);
dataProcessor = new DataProcessor(dataReader);

dummyDatasetForImputationTest = new ArrayList<>();
dummyDatasetForImputationTest.add(new String[] {"Fruit", "Quantity", "Price"});
dummyDatasetForImputationTest.add(new String[] {"Apple", "", "8"});
dummyDatasetForImputationTest.add(new String[] {"Apple", "2", "9"});
dummyDatasetForImputationTest.add(new String[] {"", "3", "10"});
dummyDatasetForImputationTest.add(new String[] {"Peach", "3", ""});
dummyDatasetForImputationTest.add(new String[] {"Kiwi", "5", ""});
dummyDatasetForImputationTest.add(new String[] {"", "3", "11"});
dummyDatasetForImputationTest.add(new String[] {"Banana", "7", "12"});
when(dataReader.readFile(new File("imputation"), ',')).thenReturn(dummyDatasetForImputationTest);
dataProcessor = new DataProcessor(dataReader);
}

@Test
Expand Down Expand Up @@ -156,4 +172,26 @@ void shouldReturnTrainTestDataset() {

}

@Test
void shouldPerformImputationOnDataset() {
dataProcessor.loadDataSetFromCSV(new File("imputation"), ',', true, new int[]{0, 1}, 2);

ImputationStrategy modeImputter = ImputationStrategy.MODE;
ImputationStrategy averageImputter = ImputationStrategy.AVERAGE;

dataProcessor.imputation("Fruit",modeImputter);
dataProcessor.imputation("Quantity",modeImputter);
dataProcessor.imputation("Price",averageImputter);
var imputted_dataset = dataProcessor.getDataset();

assertAll(
() -> assertArrayEquals(new String[] {"Apple", "3", "8"}, imputted_dataset.get(0)),
() -> assertArrayEquals(new String[] {"Apple", "2", "9"}, imputted_dataset.get(1)),
() -> assertArrayEquals(new String[] {"Apple", "3", "10"}, imputted_dataset.get(2)),
() -> assertArrayEquals(new String[] {"Peach", "3", "10.0"}, imputted_dataset.get(3)),
() -> assertArrayEquals(new String[] {"Kiwi", "5", "10.0"}, imputted_dataset.get(4)),
() -> assertArrayEquals(new String[] {"Apple", "3", "11"}, imputted_dataset.get(5)),
() -> assertArrayEquals(new String[] {"Banana", "7", "12"}, imputted_dataset.get(6))
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package de.edux.functions.imputation;

import static org.junit.jupiter.api.Assertions.*;

import org.junit.jupiter.api.Test;

class AverageImputationTest {
@Test
void performImputationWithCategoricalValuesShouldThrowRuntimeException() {
String[] categorical_features = {"A", "B", "C"};
assertThrows(RuntimeException.class,() -> new AverageImputation().performImputation(categorical_features));
}

@Test
void performImputationWithNumericalValuesTest() {
String[] numerical_features_with_missing_values = {"1", "","2", "3", "", "4"};
AverageImputation imputter = new AverageImputation();
String[] numerical_features_with_imputted_values = imputter.performImputation(numerical_features_with_missing_values);
assertAll(
() -> assertEquals("2.5", numerical_features_with_imputted_values[1]),
() -> assertEquals("2.5", numerical_features_with_imputted_values[4])
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package de.edux.functions.imputation;

import static org.junit.jupiter.api.Assertions.*;

import org.junit.jupiter.api.Test;

class ModeImputationTest {

@Test
void performImputationWithNumericalValuesTest() {
String[] numerical_features_with_missing_values = {"1", "","1", "2", "", "3"};
ModeImputation imputter = new ModeImputation();
String[] numerical_features_with_imputted_values = imputter.performImputation(numerical_features_with_missing_values);
assertAll(
() -> assertEquals("1", numerical_features_with_imputted_values[1]),
() -> assertEquals("1", numerical_features_with_imputted_values[4])
);
}

@Test
void performImputationWithCategoricalValuesTest() {
String[] numerical_features_with_missing_values = {"A", "","A", "B", "", "C"};
ModeImputation imputter = new ModeImputation();
String[] numerical_features_with_imputted_values = imputter.performImputation(numerical_features_with_missing_values);
assertAll(
() -> assertEquals("A", numerical_features_with_imputted_values[1]),
() -> assertEquals("A", numerical_features_with_imputted_values[4])
);
}
}

0 comments on commit db48b94

Please sign in to comment.