Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(#23): Replace filterIncompleteRecords boolean with Imputation Enum for Enhanced Data Handling #61

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions example/src/main/java/de/example/benchmark/Benchmark.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package de.example.benchmark;

import de.edux.api.Classifier;
import de.edux.data.handler.EIncompleteRecordsHandlerStrategy;
import de.edux.functions.activation.ActivationFunction;
import de.edux.functions.initialization.Initialization;
import de.edux.functions.loss.LossFunction;
Expand All @@ -27,7 +28,7 @@
public class Benchmark {
private static final boolean SHUFFLE = true;
private static final boolean NORMALIZE = true;
private static final boolean FILTER_INCOMPLETE_RECORDS = true;
private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS;
private static final double TRAIN_TEST_SPLIT_RATIO = 0.75;
private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv");
private double[][] trainFeatures;
Expand Down Expand Up @@ -127,7 +128,7 @@ private void updateMLP(double[][] testFeatures, double[][] testLabels) {

private void initFeaturesAndLabels() {
var seabornDataProcessor = new SeabornDataProcessor();
List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY);
seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);

trainFeatures = seabornDataProcessor.getTrainFeatures();
Expand Down
6 changes: 3 additions & 3 deletions example/src/main/java/de/example/knn/KnnIrisExample.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package de.example.knn;

import de.edux.api.Classifier;
import de.edux.data.handler.EIncompleteRecordsHandlerStrategy;
import de.edux.ml.knn.KnnClassifier;
import de.edux.ml.nn.network.api.Dataset;
import de.example.data.iris.Iris;
import de.example.data.iris.IrisDataProcessor;

Expand All @@ -17,13 +17,13 @@
public class KnnIrisExample {
private static final boolean SHUFFLE = true;
private static final boolean NORMALIZE = true;
private static final boolean FILTER_INCOMPLETE_RECORDS = true;
private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS;
private static final double TRAIN_TEST_SPLIT_RATIO = 0.75;
private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv");

public static void main(String[] args) {
var irisDataProcessor = new IrisDataProcessor();
List<Iris> data = irisDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
List<Iris> data = irisDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY);
irisDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);

Classifier knn = new KnnClassifier(2);
Expand Down
5 changes: 3 additions & 2 deletions example/src/main/java/de/example/knn/KnnSeabornExample.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package de.example.knn;

import de.edux.api.Classifier;
import de.edux.data.handler.EIncompleteRecordsHandlerStrategy;
import de.edux.ml.knn.KnnClassifier;
import de.edux.ml.nn.network.api.Dataset;
import de.example.data.seaborn.Penguin;
Expand All @@ -17,13 +18,13 @@
public class KnnSeabornExample {
private static final boolean SHUFFLE = true;
private static final boolean NORMALIZE = true;
private static final boolean FILTER_INCOMPLETE_RECORDS = true;
private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS;
private static final double TRAIN_TEST_SPLIT_RATIO = 0.75;
private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv");

public static void main(String[] args) {
var seabornDataProcessor = new SeabornDataProcessor();
List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY);

Dataset<Penguin> dataset = seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);
var seabornProvider = new SeabornProvider(data, dataset.trainData(), dataset.testData());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


import de.edux.api.Classifier;
import de.edux.data.handler.EIncompleteRecordsHandlerStrategy;
import de.edux.functions.activation.ActivationFunction;
import de.edux.functions.initialization.Initialization;
import de.edux.functions.loss.LossFunction;
Expand All @@ -18,13 +19,13 @@
public class MultilayerPerceptronSeabornExample {
private static final boolean SHUFFLE = true;
private static final boolean NORMALIZE = true;
private static final boolean FILTER_INCOMPLETE_RECORDS = true;
private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS;
private static final double TRAIN_TEST_SPLIT_RATIO = 0.75;
private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv");

public static void main(String[] args) {
var seabornDataProcessor = new SeabornDataProcessor();
List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY);

Dataset<Penguin> dataset = seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);
var seabornProvider = new SeabornProvider(data, dataset.trainData(), dataset.testData());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package de.edux.data.handler;

import java.util.ArrayList;
import java.util.List;

/**
 * Imputation strategy that first drops every record with a blank categorical feature and then
 * replaces blank numerical features with the mean of the remaining values in that column.
 *
 * <p>Feature types are inferred from the first fully-populated record: a value matching a plain
 * decimal number is treated as numerical, anything else as categorical. Records are mutated in
 * place when fill values are written.
 */
public class AverageFillIncompleteRecordsHandler implements IIncompleteRecordsHandler {

  /**
   * Cleans the dataset in two passes: drop rows with missing categorical values, then mean-fill
   * missing numerical values.
   *
   * @param dataset raw records; may contain blank ("", whitespace-only) features
   * @return the cleaned dataset; numerical gaps are filled, categorical gaps are dropped
   * @throws RuntimeException if no fully-populated record exists, if more than half of the records
   *     would be dropped, or if fewer than half of the values in a numerical column are present
   */
  @Override
  public List<String[]> getCleanedDataset(List<String[]> dataset) {
    List<String> featureTypes = detectFeatureTypes(dataset);
    List<String[]> withoutBrokenCategoricals =
        dropRecordsWithIncompleteCategoricalFeature(dataset, featureTypes);

    return averageFillRecordsWithIncompleteNumericalFeature(withoutBrokenCategoricals, featureTypes);
  }

  /**
   * Mean-fills every blank value in each numerical column. Mutates the record arrays in place and
   * returns the same list instance.
   */
  private List<String[]> averageFillRecordsWithIncompleteNumericalFeature(
      List<String[]> dataset, List<String> featureTypes) {
    for (int columnIndex = 0; columnIndex < featureTypes.size(); columnIndex++) {
      if (!featureTypes.get(columnIndex).equals("numerical")) {
        continue;
      }

      // First pass over the column: accumulate the mean over the present values only.
      int presentCount = 0;
      double sum = 0;
      for (String[] record : dataset) {
        String value = record[columnIndex];
        if (!value.isBlank()) {
          presentCount++;
          sum += Double.parseDouble(value);
        }
      }

      // Refuse to impute from a minority of the column — the mean would be unreliable.
      if (presentCount < dataset.size() * 0.5) {
        throw new RuntimeException(
            "Less than 50% of the records will be used to calculate the fill values. "
                + "Consider using another IncompleteRecordsHandlerStrategy or handle this exception.");
      }

      // Second pass: write the mean into every blank slot of this column.
      String fillValue = String.valueOf(sum / presentCount);
      for (String[] record : dataset) {
        if (record[columnIndex].isBlank()) {
          record[columnIndex] = fillValue;
        }
      }
    }

    return dataset;
  }

  /**
   * Removes every record that has a blank value in any categorical column. Returns the input list
   * unchanged (same reference) when there are no categorical columns.
   */
  private List<String[]> dropRecordsWithIncompleteCategoricalFeature(
      List<String[]> dataset, List<String> featureTypes) {
    List<String[]> survivors = dataset;

    if (featureTypes.contains("categorical")) {
      survivors =
          dataset.stream()
              .filter(record -> hasAllCategoricalFeatures(record, featureTypes))
              .toList();
    }

    if (survivors.size() < dataset.size() * 0.5) {
      throw new RuntimeException(
          "More than 50% of the records will be dropped with this IncompleteRecordsHandlerStrategy. "
              + "Consider using another IncompleteRecordsHandlerStrategy or handle this exception.");
    }

    return survivors;
  }

  // True when every categorical column of the record holds a non-blank value.
  private boolean hasAllCategoricalFeatures(String[] record, List<String> featureTypes) {
    for (int columnIndex = 0; columnIndex < featureTypes.size(); columnIndex++) {
      if (featureTypes.get(columnIndex).equals("categorical") && record[columnIndex].isBlank()) {
        return false;
      }
    }
    return true;
  }

  /**
   * Classifies each column as "numerical" or "categorical" based on the first record that has no
   * blank features.
   *
   * @throws RuntimeException if no fully-populated record exists (or it has zero features)
   */
  private List<String> detectFeatureTypes(List<String[]> dataset) {
    List<String> featureTypes = new ArrayList<>();
    for (String[] candidate : dataset) {
      if (hasBlankFeature(candidate)) {
        continue;
      }
      for (String feature : candidate) {
        featureTypes.add(isNumeric(feature) ? "numerical" : "categorical");
      }
      break; // only the first complete record defines the schema
    }

    if (featureTypes.isEmpty()) {
      throw new RuntimeException("At least one full record needed with valid features");
    }
    return featureTypes;
  }

  // Plain decimal number, optional sign and fraction — no exponent or grouping support.
  private boolean isNumeric(String feature) {
    return feature.matches("-?\\d+(\\.\\d+)?");
  }

  // True when at least one feature of the record is blank.
  private boolean hasBlankFeature(String[] record) {
    for (String feature : record) {
      if (feature.isBlank()) {
        return true;
      }
    }
    return false;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package de.edux.data.handler;

import java.util.List;

/**
 * No-op imputation strategy: incomplete records are left in the dataset untouched.
 */
public class DoNotHandleIncompleteRecords implements IIncompleteRecordsHandler {
    /**
     * Returns the given dataset unchanged — same list instance, no copying, no filtering.
     *
     * @param dataset the raw records, possibly containing blank features
     * @return the identical list reference that was passed in
     */
    @Override
    public List<String[]> getCleanedDataset(List<String[]> dataset) {
        return dataset;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package de.edux.data.handler;

import java.util.ArrayList;
import java.util.List;

public class DropIncompleteRecordsHandler implements IIncompleteRecordsHandler {
@Override
public List<String[]> getCleanedDataset(List<String[]> dataset) {
List<String[]> filteredList =
dataset.stream().filter(this::containsOnlyCompletedFeatures).toList();

if (filteredList.size() < dataset.size() * 0.5) {
throw new RuntimeException(
"More than 50% of the records will be dropped with this IncompleteRecordsHandlerStrategy. "
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?

+ "Consider using another IncompleteRecordsHandlerStrategy or handle this exception.");
}

List<String[]> cleanedDataset = new ArrayList<>();
for (String[] item : filteredList) {
cleanedDataset.add(item);
}
return cleanedDataset;
}

private boolean containsOnlyCompletedFeatures(String[] record) {
for (String feature : record) {
if (feature.isBlank()) {
return false;
}
}
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package de.edux.data.handler;

public enum EIncompleteRecordsHandlerStrategy {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Java World wie never prefix enums with 'E'. As in isssue#23 described you need name it "Imputation" here.

Imputation .DROP_RECORDS....

DO_NOT_HANDLE(new DoNotHandleIncompleteRecords()),
DROP_RECORDS(new DropIncompleteRecordsHandler()),
FILL_RECORDS_WITH_AVERAGE(new AverageFillIncompleteRecordsHandler());

private final IIncompleteRecordsHandler incompleteRecordHandler;

EIncompleteRecordsHandlerStrategy(IIncompleteRecordsHandler incompleteRecordHandler) {
this.incompleteRecordHandler = incompleteRecordHandler;
}

public IIncompleteRecordsHandler getHandler() {
return this.incompleteRecordHandler;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package de.edux.data.handler;

import java.util.List;

/**
 * Cleans a dataset of raw CSV records according to an incomplete-records handling strategy.
 *
 * <p>Implementations decide what to do with records containing blank features: keep them, drop
 * them, or impute replacement values. See {@code EIncompleteRecordsHandlerStrategy} for the
 * available strategies.
 */
public interface IIncompleteRecordsHandler {

  /**
   * Produces a cleaned view of the given records.
   *
   * @param dataset the raw records, possibly containing blank features
   * @return the cleaned dataset; implementations may return the same list instance or a new one
   */
  List<String[]> getCleanedDataset(List<String[]> dataset);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package de.edux.data.provider;

import java.util.List;
import java.util.Optional;

public abstract class DataPostProcessor<T> {
public abstract void normalize(List<T> rowDataset);
Expand All @@ -21,4 +22,10 @@ public abstract class DataPostProcessor<T> {

public abstract double[][] getTestFeatures();

public abstract Optional<Integer> getIndexOfColumn(String columnName);

public abstract String[] getColumnDataOf(String columnName);

public abstract String[] getColumnNames();

}
Loading