diff --git a/example/src/main/java/de/example/benchmark/Benchmark.java b/example/src/main/java/de/example/benchmark/Benchmark.java index fb93600..4df94cc 100644 --- a/example/src/main/java/de/example/benchmark/Benchmark.java +++ b/example/src/main/java/de/example/benchmark/Benchmark.java @@ -1,6 +1,7 @@ package de.example.benchmark; import de.edux.api.Classifier; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.functions.activation.ActivationFunction; import de.edux.functions.initialization.Initialization; import de.edux.functions.loss.LossFunction; @@ -27,7 +28,7 @@ public class Benchmark { private static final boolean SHUFFLE = true; private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS; private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv"); private double[][] trainFeatures; @@ -127,7 +128,7 @@ private void updateMLP(double[][] testFeatures, double[][] testLabels) { private void initFeaturesAndLabels() { var seabornDataProcessor = new SeabornDataProcessor(); - List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY); seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); trainFeatures = seabornDataProcessor.getTrainFeatures(); diff --git a/example/src/main/java/de/example/knn/KnnIrisExample.java b/example/src/main/java/de/example/knn/KnnIrisExample.java index e4d1850..1cba308 100644 --- a/example/src/main/java/de/example/knn/KnnIrisExample.java +++ 
b/example/src/main/java/de/example/knn/KnnIrisExample.java @@ -1,8 +1,8 @@ package de.example.knn; import de.edux.api.Classifier; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.ml.knn.KnnClassifier; -import de.edux.ml.nn.network.api.Dataset; import de.example.data.iris.Iris; import de.example.data.iris.IrisDataProcessor; @@ -17,13 +17,13 @@ public class KnnIrisExample { private static final boolean SHUFFLE = true; private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS; private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); public static void main(String[] args) { var irisDataProcessor = new IrisDataProcessor(); - List data = irisDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + List data = irisDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY); irisDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); Classifier knn = new KnnClassifier(2); diff --git a/example/src/main/java/de/example/knn/KnnSeabornExample.java b/example/src/main/java/de/example/knn/KnnSeabornExample.java index 324d438..34d7be6 100644 --- a/example/src/main/java/de/example/knn/KnnSeabornExample.java +++ b/example/src/main/java/de/example/knn/KnnSeabornExample.java @@ -1,6 +1,7 @@ package de.example.knn; import de.edux.api.Classifier; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.ml.knn.KnnClassifier; import de.edux.ml.nn.network.api.Dataset; import de.example.data.seaborn.Penguin; @@ -17,13 +18,13 @@ public class KnnSeabornExample { private static final boolean SHUFFLE = true; private static final 
boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS; private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv"); public static void main(String[] args) { var seabornDataProcessor = new SeabornDataProcessor(); - List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY); Dataset dataset = seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); var seabornProvider = new SeabornProvider(data, dataset.trainData(), dataset.testData()); diff --git a/example/src/main/java/de/example/nn/MultilayerPerceptronSeabornExample.java b/example/src/main/java/de/example/nn/MultilayerPerceptronSeabornExample.java index 9afcb6f..b42e466 100644 --- a/example/src/main/java/de/example/nn/MultilayerPerceptronSeabornExample.java +++ b/example/src/main/java/de/example/nn/MultilayerPerceptronSeabornExample.java @@ -2,6 +2,7 @@ import de.edux.api.Classifier; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.functions.activation.ActivationFunction; import de.edux.functions.initialization.Initialization; import de.edux.functions.loss.LossFunction; @@ -18,13 +19,13 @@ public class MultilayerPerceptronSeabornExample { private static final boolean SHUFFLE = true; private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS; private static final double 
TRAIN_TEST_SPLIT_RATIO = 0.75; private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv"); public static void main(String[] args) { var seabornDataProcessor = new SeabornDataProcessor(); - List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY); Dataset dataset = seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); var seabornProvider = new SeabornProvider(data, dataset.trainData(), dataset.testData()); diff --git a/lib/src/main/java/de/edux/data/handler/AverageFillIncompleteRecordsHandler.java b/lib/src/main/java/de/edux/data/handler/AverageFillIncompleteRecordsHandler.java new file mode 100644 index 0000000..8234e4d --- /dev/null +++ b/lib/src/main/java/de/edux/data/handler/AverageFillIncompleteRecordsHandler.java @@ -0,0 +1,110 @@ +package de.edux.data.handler; + +import java.util.ArrayList; +import java.util.List; + +public class AverageFillIncompleteRecordsHandler implements IIncompleteRecordsHandler { + @Override + public List getCleanedDataset(List dataset) { + List typeOfFeatures = getFeatureTypes(dataset); + List cleanedDataset = + dropRecordsWithIncompleteCategoricalFeature(dataset, typeOfFeatures); + + return averageFillRecordsWithIncompleteNumericalFeature(cleanedDataset, typeOfFeatures); + } + + private List averageFillRecordsWithIncompleteNumericalFeature( + List dataset, List typeOfFeatures) { + for (int columnIndex = 0; columnIndex < typeOfFeatures.size(); columnIndex++) { + int validFeatureCount = 0; + double sum = 0; + double average; + + if (typeOfFeatures.get(columnIndex).equals("numerical")) { + for (String[] record : dataset) { + if (isCompleteFeature(record[columnIndex])) { + validFeatureCount++; + sum += Double.parseDouble(record[columnIndex]); + } + } + + 
if (validFeatureCount < dataset.size() * 0.5) { + throw new RuntimeException( + "Less than 50% of the records will be used to calculate the fill values. " + + "Consider using another IncompleteRecordsHandlerStrategy or handle this exception."); + } + + average = sum / validFeatureCount; + for (String[] record : dataset) { + if (!isCompleteFeature(record[columnIndex])) { + record[columnIndex] = String.valueOf(average); + } + } + } + } + + return dataset; + } + + private List dropRecordsWithIncompleteCategoricalFeature( + List dataset, List typeOfFeatures) { + List cleanedDataset = dataset; + + for (int columnIndex = 0; columnIndex < typeOfFeatures.size(); columnIndex++) { + if (typeOfFeatures.get(columnIndex).equals("categorical")) { + int columnIndexFin = columnIndex; + cleanedDataset = + cleanedDataset.stream() + .filter(record -> isCompleteFeature(record[columnIndexFin])) + .toList(); + } + } + + if (cleanedDataset.size() < dataset.size() * 0.5) { + throw new RuntimeException( + "More than 50% of the records will be dropped with this IncompleteRecordsHandlerStrategy. 
" + + "Consider using another IncompleteRecordsHandlerStrategy or handle this exception."); + } + + return cleanedDataset; + } + + private List getFeatureTypes(List dataset) { + List featureTypes = new ArrayList<>(); + for (String[] record : dataset) { + if (containsIncompleteFeature(record)) { + continue; + } + for (String feature : record) { + if (isNumeric(feature)) { + featureTypes.add("numerical"); + } else { + featureTypes.add("categorical"); + } + } + break; + } + + if (featureTypes.isEmpty()) { + throw new RuntimeException("At least one full record needed with valid features"); + } + return featureTypes; + } + + private boolean isNumeric(String feature) { + return feature.matches("-?\\d+(\\.\\d+)?"); + } + + private boolean isCompleteFeature(String feature) { + return !feature.isBlank(); + } + + private boolean containsIncompleteFeature(String[] record) { + for (String feature : record) { + if (feature.isBlank()) { + return true; + } + } + return false; + } +} diff --git a/lib/src/main/java/de/edux/data/handler/DoNotHandleIncompleteRecords.java b/lib/src/main/java/de/edux/data/handler/DoNotHandleIncompleteRecords.java new file mode 100644 index 0000000..3a54f06 --- /dev/null +++ b/lib/src/main/java/de/edux/data/handler/DoNotHandleIncompleteRecords.java @@ -0,0 +1,10 @@ +package de.edux.data.handler; + +import java.util.List; + +public class DoNotHandleIncompleteRecords implements IIncompleteRecordsHandler { + @Override + public List getCleanedDataset(List dataset) { + return dataset; + } +} diff --git a/lib/src/main/java/de/edux/data/handler/DropIncompleteRecordsHandler.java b/lib/src/main/java/de/edux/data/handler/DropIncompleteRecordsHandler.java new file mode 100644 index 0000000..92a6592 --- /dev/null +++ b/lib/src/main/java/de/edux/data/handler/DropIncompleteRecordsHandler.java @@ -0,0 +1,33 @@ +package de.edux.data.handler; + +import java.util.ArrayList; +import java.util.List; + +public class DropIncompleteRecordsHandler implements 
IIncompleteRecordsHandler { + @Override + public List getCleanedDataset(List dataset) { + List filteredList = + dataset.stream().filter(this::containsOnlyCompletedFeatures).toList(); + + if (filteredList.size() < dataset.size() * 0.5) { + throw new RuntimeException( + "More than 50% of the records will be dropped with this IncompleteRecordsHandlerStrategy. " + + "Consider using another IncompleteRecordsHandlerStrategy or handle this exception."); + } + + List cleanedDataset = new ArrayList<>(); + for (String[] item : filteredList) { + cleanedDataset.add(item); + } + return cleanedDataset; + } + + private boolean containsOnlyCompletedFeatures(String[] record) { + for (String feature : record) { + if (feature.isBlank()) { + return false; + } + } + return true; + } +} diff --git a/lib/src/main/java/de/edux/data/handler/EIncompleteRecordsHandlerStrategy.java b/lib/src/main/java/de/edux/data/handler/EIncompleteRecordsHandlerStrategy.java new file mode 100644 index 0000000..1a1269b --- /dev/null +++ b/lib/src/main/java/de/edux/data/handler/EIncompleteRecordsHandlerStrategy.java @@ -0,0 +1,17 @@ +package de.edux.data.handler; + +public enum EIncompleteRecordsHandlerStrategy { + DO_NOT_HANDLE(new DoNotHandleIncompleteRecords()), + DROP_RECORDS(new DropIncompleteRecordsHandler()), + FILL_RECORDS_WITH_AVERAGE(new AverageFillIncompleteRecordsHandler()); + + private final IIncompleteRecordsHandler incompleteRecordHandler; + + EIncompleteRecordsHandlerStrategy(IIncompleteRecordsHandler incompleteRecordHandler) { + this.incompleteRecordHandler = incompleteRecordHandler; + } + + public IIncompleteRecordsHandler getHandler() { + return this.incompleteRecordHandler; + } +} diff --git a/lib/src/main/java/de/edux/data/handler/IIncompleteRecordsHandler.java b/lib/src/main/java/de/edux/data/handler/IIncompleteRecordsHandler.java new file mode 100644 index 0000000..2c5c367 --- /dev/null +++ b/lib/src/main/java/de/edux/data/handler/IIncompleteRecordsHandler.java @@ -0,0 +1,7 @@ +package 
de.edux.data.handler; + +import java.util.List; + +public interface IIncompleteRecordsHandler { + List getCleanedDataset(List dataset); +} diff --git a/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java b/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java index 263276c..dfc4031 100644 --- a/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java +++ b/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java @@ -1,6 +1,7 @@ package de.edux.data.provider; import java.util.List; +import java.util.Optional; public abstract class DataPostProcessor { public abstract void normalize(List rowDataset); @@ -21,4 +22,10 @@ public abstract class DataPostProcessor { public abstract double[][] getTestFeatures(); + public abstract Optional getIndexOfColumn(String columnName); + + public abstract String[] getColumnDataOf(String columnName); + + public abstract String[] getColumnNames(); + } diff --git a/lib/src/main/java/de/edux/data/provider/DataProcessor.java b/lib/src/main/java/de/edux/data/provider/DataProcessor.java index e1f13bc..7cb29c4 100644 --- a/lib/src/main/java/de/edux/data/provider/DataProcessor.java +++ b/lib/src/main/java/de/edux/data/provider/DataProcessor.java @@ -1,5 +1,6 @@ package de.edux.data.provider; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.data.reader.CSVIDataReader; import de.edux.data.reader.IDataReader; import de.edux.ml.nn.network.api.Dataset; @@ -10,12 +11,15 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; public abstract class DataProcessor extends DataPostProcessor implements IDataUtil { private static final Logger LOG = LoggerFactory.getLogger(DataProcessor.class); private final IDataReader csvDataReader; private ArrayList dataset; private Dataset splitedDataset; + private String[] columnNames; + private List rawDataset; public DataProcessor() { this.csvDataReader = new CSVIDataReader(); @@ -26,22 +30,24 @@ 
public DataProcessor(IDataReader csvDataReader) { } @Override - public List loadDataSetFromCSV(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords) { - List x = csvDataReader.readFile(csvFile, csvSeparator); - List unmodifiableDataset = csvDataReader.readFile(csvFile, csvSeparator) + public List loadDataSetFromCSV(File csvFile, char csvSeparator, boolean skipHeadline, boolean shuffle, EIncompleteRecordsHandlerStrategy incompleteRecordHandlerStrategy) { + rawDataset = csvDataReader.readFile(csvFile, csvSeparator); + + if (skipHeadline) { + columnNames = rawDataset.remove(0); + } else { + columnNames = rawDataset.get(0); + } + List csvDataset = incompleteRecordHandlerStrategy.getHandler().getCleanedDataset(rawDataset); + + List unmodifiableDataset = csvDataset .stream() .map(this::mapToDataRecord) - .filter(record -> !filterIncompleteRecords || record != null) .toList(); dataset = new ArrayList<>(unmodifiableDataset); LOG.info("Dataset loaded"); - if (normalize) { - normalize(dataset); - LOG.info("Dataset normalized"); - } - if (shuffle) { Collections.shuffle(dataset); LOG.info("Dataset shuffled"); @@ -52,7 +58,7 @@ public List loadDataSetFromCSV(File csvFile, char csvSeparator, boolean norma /** * Split data into train and test data * - * @param data data to split + * @param data data to split * @param trainTestSplitRatio ratio of train data * @return list of train and test data. First element is train data, second element is test data. 
*/ @@ -70,6 +76,7 @@ public Dataset split(List data, double trainTestSplitRatio) { splitedDataset = new Dataset<>(trainDataset, testDataset); return splitedDataset; } + public ArrayList getDataset() { return dataset; } @@ -77,5 +84,34 @@ public ArrayList getDataset() { public Dataset getSplitedDataset() { return splitedDataset; } + + @Override + public Optional getIndexOfColumn(String columnName) { + for (int i = 0; i < columnNames.length; i++) { + if (columnNames[i].equals(columnName)) { + return Optional.of(i); + } + } + return Optional.empty(); + } + + @Override + public String[] getColumnDataOf(String columnName) { + Optional index = getIndexOfColumn(columnName); + if (index.isEmpty()) { + throw new IllegalArgumentException("Column name not found: " + columnName); + } + int columnIndex = index.get(); + String[] columnData = new String[rawDataset.size()]; + for (int i = 0; i < rawDataset.size(); i++) { + columnData[i] = rawDataset.get(i)[columnIndex]; + } + return columnData; + } + + @Override + public String[] getColumnNames() { + return columnNames; + } } diff --git a/lib/src/main/java/de/edux/data/provider/IDataUtil.java b/lib/src/main/java/de/edux/data/provider/IDataUtil.java index 1f85516..0c5df28 100644 --- a/lib/src/main/java/de/edux/data/provider/IDataUtil.java +++ b/lib/src/main/java/de/edux/data/provider/IDataUtil.java @@ -1,17 +1,17 @@ package de.edux.data.provider; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.ml.nn.network.api.Dataset; import java.io.File; import java.util.List; public interface IDataUtil { - List loadDataSetFromCSV(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords); + List loadDataSetFromCSV(File csvFile, char csvSeparator, boolean skipHeadline, boolean shuffle, EIncompleteRecordsHandlerStrategy incompleteRecordHandlerStrategy); Dataset split(List dataset, double trainTestSplitRatio); double[][] getInputs(List dataset); double[][] getTargets(List dataset); - } diff --git 
a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java index 6b14067..2c16dc7 100644 --- a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java +++ b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java @@ -13,13 +13,12 @@ public class CSVIDataReader implements IDataReader { public List readFile(File file, char separator) { CSVParser customCSVParser = new CSVParserBuilder().withSeparator(separator).build(); List result; try(CSVReader reader = new CSVReaderBuilder( new FileReader(file)) .withCSVParser(customCSVParser) - .withSkipLines(1) .build()){ result = reader.readAll(); } catch (CsvException | IOException e) { diff --git a/lib/src/test/java/de/edux/data/handler/AverageFillIncompleteRecordHandlerTest.java b/lib/src/test/java/de/edux/data/handler/AverageFillIncompleteRecordHandlerTest.java new file mode 100644 index 0000000..eb9681d --- /dev/null +++ b/lib/src/test/java/de/edux/data/handler/AverageFillIncompleteRecordHandlerTest.java @@ -0,0 +1,76 @@ +package de.edux.data.handler; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +class AverageFillIncompleteRecordHandlerTest { + + private List dataset; + + private IIncompleteRecordsHandler incompleteRecordHandler; + + @BeforeEach + void initializeList() { + dataset = new ArrayList<>(); + incompleteRecordHandler = + EIncompleteRecordsHandlerStrategy.FILL_RECORDS_WITH_AVERAGE.getHandler(); + } + + @Test + void dropRecordsWithIncompleteCategoricalFeature() { + + this.dataset.add(new String[] {"A", "1", "A"}); + this.dataset.add(new String[] {"", "2", ""}); + this.dataset.add(new String[] {"C", "", "C"}); + this.dataset.add(new String[] {"D", "3", ""}); + this.dataset.add(new String[] {"E", "4", "E"}); + + assertAll( + () -> 
assertEquals(3, incompleteRecordHandler.getCleanedDataset(dataset).size()), + () -> + assertEquals( + 2.5, Double.valueOf(incompleteRecordHandler.getCleanedDataset(dataset).get(1)[1]))); + } + + @Test + void testThrowRuntimeExceptionForDroppingMoreThanHalfOfOriginalDataset() { + + this.dataset.add(new String[] {"", "1", "A"}); + this.dataset.add(new String[] {"B", "2", "B"}); + this.dataset.add(new String[] {"C", "3", "C"}); + this.dataset.add(new String[] {"D", "4", ""}); + this.dataset.add(new String[] {"", "5", "E"}); + + assertThrows(RuntimeException.class, () -> incompleteRecordHandler.getCleanedDataset(dataset)); + } + + @Test + void testThrowRuntimeExceptionForZeroValidNumericalFeatures() { + + this.dataset.add(new String[] {"A", "", "A"}); + this.dataset.add(new String[] {"B", "", "B"}); + this.dataset.add(new String[] {"C", "1", "C"}); + this.dataset.add(new String[] {"D", "", "D"}); + this.dataset.add(new String[] {"E", "", "E"}); + + assertThrows(RuntimeException.class, () -> incompleteRecordHandler.getCleanedDataset(dataset)); + } + + @Test + void testThrowRuntimeExceptionForAtLeastOneFullValidRecord() { + + this.dataset.add(new String[] {"", "1", "A"}); + this.dataset.add(new String[] {"B", "2", ""}); + this.dataset.add(new String[] {"", "", "C"}); + this.dataset.add(new String[] {"D", "3", ""}); + this.dataset.add(new String[] {"", "4", "E"}); + + assertThrows(RuntimeException.class, () -> incompleteRecordHandler.getCleanedDataset(dataset)); + } +} diff --git a/lib/src/test/java/de/edux/data/handler/DropIncompleteRecordHandlerTest.java b/lib/src/test/java/de/edux/data/handler/DropIncompleteRecordHandlerTest.java new file mode 100644 index 0000000..92e0072 --- /dev/null +++ b/lib/src/test/java/de/edux/data/handler/DropIncompleteRecordHandlerTest.java @@ -0,0 +1,70 @@ +package de.edux.data.handler; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class DropIncompleteRecordHandlerTest { + private List dataset; + + private IIncompleteRecordsHandler incompleteRecordHandler; + + @BeforeEach + void initializeList() { + dataset = new ArrayList<>(); + incompleteRecordHandler = EIncompleteRecordsHandlerStrategy.DROP_RECORDS.getHandler(); + } + + @Test + void testDropZeroIncompleteResults() { + + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + + assertEquals(5, incompleteRecordHandler.getCleanedDataset(dataset).size()); + } + + @Test + void testDropOneIncompleteResult() { + + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + + assertEquals(4, incompleteRecordHandler.getCleanedDataset(dataset).size()); + } + + @Test + void testDropTwoIncompleteResult() { + + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "", "C"}); + this.dataset.add(new String[] {"A", "", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"A", "B", "C"}); + + assertEquals(3, incompleteRecordHandler.getCleanedDataset(dataset).size()); + } + + @Test + void testThrowRuntimeExceptionForDroppingMoreThanHalfOfOriginalDataset() { + + this.dataset.add(new String[] {"A", "B", "C"}); + this.dataset.add(new String[] {"", "B", "C"}); + this.dataset.add(new String[] {"A", "", "C"}); + this.dataset.add(new String[] {"A", "B", ""}); + this.dataset.add(new String[] {"A", "B", "C"}); + + assertThrows(RuntimeException.class, () -> incompleteRecordHandler.getCleanedDataset(dataset)); + } 
+} diff --git a/lib/src/test/java/de/edux/data/handler/DropIncompleteRecordsHandlerTest.java b/lib/src/test/java/de/edux/data/handler/DropIncompleteRecordsHandlerTest.java new file mode 100644 index 0000000..f062957 --- /dev/null +++ b/lib/src/test/java/de/edux/data/handler/DropIncompleteRecordsHandlerTest.java @@ -0,0 +1,35 @@ +package de.edux.data.handler; + +import de.edux.data.provider.SeabornDataProcessor; +import de.edux.data.provider.SeabornProvider; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.net.URL; +import java.util.Optional; + +class DropIncompleteRecordsHandlerTest { + private static final boolean SHUFFLE = true; + private static final boolean SKIP_HEADLINE = true; + private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS; + private static final double TRAIN_TEST_SPLIT_RATIO = 0.7; + private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv"; + private SeabornProvider seabornProvider; + + @Test + void shouldReturnColumnData() { + URL url = DropIncompleteRecordsHandlerTest.class.getClassLoader().getResource(CSV_FILE_PATH); + if (url == null) { + throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH); + } + File csvFile = new File(url.getPath()); + var seabornDataProcessor = new SeabornDataProcessor(); + var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', true, true, INCOMPLETE_RECORD_HANDLER_STRATEGY); + seabornDataProcessor.normalize(dataset); + Optional indexOfSpecies = seabornDataProcessor.getIndexOfColumn("species"); + String[] speciesData = seabornDataProcessor.getColumnDataOf("species"); + + assert indexOfSpecies.isPresent(); + assert speciesData.length > 0; + } +} \ No newline at end of file diff --git a/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java index a0b3af0..55acaa0 100644 --- 
a/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java +++ b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java @@ -1,5 +1,6 @@ package de.edux.data.provider; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.data.reader.CSVIDataReader; import de.edux.ml.nn.network.api.Dataset; import org.junit.jupiter.api.BeforeEach; @@ -13,6 +14,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Optional; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -95,8 +97,7 @@ void testLoadTDataSetWithoutNormalizationAndShuffling() { when(csvDataReader.readFile(any(), anyChar())).thenReturn(csvLine); - List result = dataProcessor.loadDataSetFromCSV(dummyFile, separator, false, false, false); - + List result = dataProcessor.loadDataSetFromCSV(dummyFile, separator, false, false, EIncompleteRecordsHandlerStrategy.DO_NOT_HANDLE); assertEquals(2, result.size(), "Dataset size should be 2"); } @@ -146,6 +147,11 @@ public double[][] getTestLabels() { public double[][] getTestFeatures() { return new double[0][]; } + + @Override + public Optional getIndexOfColumn(String columnName) { + return Optional.empty(); + } }; } diff --git a/lib/src/test/java/de/edux/data/provider/SeabornProvider.java b/lib/src/test/java/de/edux/data/provider/SeabornProvider.java index de98c57..a9927c3 100644 --- a/lib/src/test/java/de/edux/data/provider/SeabornProvider.java +++ b/lib/src/test/java/de/edux/data/provider/SeabornProvider.java @@ -54,6 +54,10 @@ private double[][] featuresOf(List data) { for (int i = 0; i < data.size(); i++) { Penguin p = data.get(i); + if (p == null) { + continue; + // NOTE(review): a null record is skipped silently, leaving an all-zero feature row at index i — confirm intended, or throw / use an imputation strategy instead + } features[i][0] = p.billLengthMm(); features[i][1] = p.billDepthMm(); features[i][2] = p.flipperLengthMm(); diff --git 
a/lib/src/test/java/de/edux/ml/RandomForestTest.java b/lib/src/test/java/de/edux/ml/RandomForestTest.java index a2a6c4c..e530fc0 100644 --- a/lib/src/test/java/de/edux/ml/RandomForestTest.java +++ b/lib/src/test/java/de/edux/ml/RandomForestTest.java @@ -1,7 +1,7 @@ package de.edux.ml; import de.edux.api.Classifier; -import de.edux.data.provider.Penguin; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.data.provider.SeabornDataProcessor; import de.edux.data.provider.SeabornProvider; import de.edux.ml.randomforest.RandomForest; @@ -10,14 +10,13 @@ import java.io.File; import java.net.URL; -import java.util.List; import static org.junit.jupiter.api.Assertions.assertTrue; class RandomForestTest { private static final boolean SHUFFLE = true; private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS; private static final double TRAIN_TEST_SPLIT_RATIO = 0.7; private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv"; private static SeabornProvider seabornProvider; @@ -29,7 +28,7 @@ static void setup() { } File csvFile = new File(url.getPath()); var seabornDataProcessor = new SeabornDataProcessor(); - var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY); var splitedDataset = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); seabornProvider = new SeabornProvider(dataset, splitedDataset.trainData(), splitedDataset.testData()); } diff --git a/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java b/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java index 537f19c..ce9c1cb 100644 --- 
a/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java +++ b/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java @@ -1,5 +1,6 @@ package de.edux.ml.decisiontree; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.data.provider.Penguin; import de.edux.data.provider.SeabornDataProcessor; import de.edux.data.provider.SeabornProvider; @@ -10,14 +11,13 @@ import java.io.File; import java.net.URL; -import java.util.List; import static org.junit.jupiter.api.Assertions.assertTrue; class DecisionTreeTest { private static final boolean SHUFFLE = true; private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS; private static final double TRAIN_TEST_SPLIT_RATIO = 0.7; private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv"; private static SeabornProvider seabornProvider; @@ -29,7 +29,7 @@ static void setup() { } File csvFile = new File(url.getPath()); var seabornDataProcessor = new SeabornDataProcessor(); - var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY); Dataset splitedDataset = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); seabornProvider = new SeabornProvider(dataset, splitedDataset.trainData(), splitedDataset.testData()); } diff --git a/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java b/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java index b4d4763..8d2c82f 100644 --- a/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java +++ b/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java @@ -1,5 +1,6 @@ package 
de.edux.ml.nn.network; +import de.edux.data.handler.EIncompleteRecordsHandlerStrategy; import de.edux.data.provider.SeabornDataProcessor; import de.edux.data.provider.SeabornProvider; import de.edux.functions.activation.ActivationFunction; @@ -17,8 +18,8 @@ class MultilayerPerceptronTest { private static final boolean SHUFFLE = true; - private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final boolean SKIP_HEADLINE = true; + private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS; private static final double TRAIN_TEST_SPLIT_RATIO = 0.7; private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv"; private SeabornProvider seabornProvider; @@ -31,10 +32,11 @@ void setUp() { } File csvFile = new File(url.getPath()); var seabornDataProcessor = new SeabornDataProcessor(); - var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', true, true, INCOMPLETE_RECORD_HANDLER_STRATEGY); + seabornDataProcessor.normalize(dataset); var trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.trainData(), trainTestSplittedList.testData()); - + System.out.println("SeabornProvider loaded"); } @RepeatedTest(3)