Skip to content

Commit

Permalink
Add PredictionLoggerEvaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
tachyonicClock committed Sep 21, 2023
1 parent 73cb667 commit 9a33c54
Show file tree
Hide file tree
Showing 16 changed files with 229 additions and 249 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
package moa.evaluation;

import com.yahoo.labs.samoa.instances.Instance;
import moa.MOAObject;
import moa.core.Example;
import moa.core.Measurement;

public interface ClassificationPerformanceEvaluator extends LearningPerformanceEvaluator<Example<Instance>> {

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject, CapabilitiesHandler {
public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject, CapabilitiesHandler, AutoCloseable {

/**
* Resets this evaluator. It must be similar to
Expand Down Expand Up @@ -66,4 +66,8 @@ default ImmutableCapabilities defineImmutableCapabilities() {
return new ImmutableCapabilities(Capability.VIEW_STANDARD);
}

@Override
default void close() throws Exception {
// By default an evaluator does nothing when closed.
}
}
160 changes: 160 additions & 0 deletions moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package moa.evaluation;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.Arrays;
import java.util.zip.GZIPOutputStream;

import com.github.javacliparser.FileOption;
import com.github.javacliparser.FlagOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Prediction;

import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.options.AbstractOptionHandler;
import moa.options.ClassOption;
import moa.tasks.TaskMonitor;

public class PredictionLoggerEvaluator extends AbstractOptionHandler
implements ClassificationPerformanceEvaluator {

private static final long serialVersionUID = 1L;

private OutputStreamWriter writer;
private int index = 0;

public FileOption csvFileOption = new FileOption("predictionLog", 'o',
"A file to write comma separated values to.", null, "csv.gzip", true);

public FlagOption overwrite = new FlagOption("overwrite", 'f', "Overwrite existing file.");

public ClassOption wrappedEvaluatorOption = new ClassOption("evaluator", 'e',
"Classification performance evaluation method.", ClassificationPerformanceEvaluator.class,
"BasicClassificationPerformanceEvaluator");

public FlagOption probabilities = new FlagOption("probabilities", 'p',
"Log probabilities instead of raw predictions.");

public FlagOption uncompressed = new FlagOption("uncompressed", 'u',
"The output file should be saved uncompressed.");

private ClassificationPerformanceEvaluator wrappedEvaluator;

@Override
public String getPurposeString() {
return "Log raw predictions and probabilities to a CSV file, and evaluate using a wrapped evaluator.";
}

@Override
public void addResult(Example<Instance> example, double[] classVotes) {
Instance instance = example.getData();
int predictedClass = Utils.maxIndex(classVotes);
double normalizingFactor = Arrays.stream(classVotes).sum();
int numClasses = instance.numClasses();

if (normalizingFactor == 0) {
normalizingFactor = 1;
}
try {
// If this is the first result, write the header to the top of the file
if (index == 0)
writeHeader(numClasses);


// Add row to CSV file
if (instance.classIsMissing() == true)
{
writer.write(String.format("?,%d,", predictedClass));
}
else
{
int trueClass = (int) instance.classValue();
writer.write(String.format("%d,%d,", trueClass, predictedClass));
}

if (probabilities.isSet()) {
for (int i = 0; i < numClasses; i++) {
double probability = 0.0;
if (i < classVotes.length){
probability = classVotes[i] / normalizingFactor;
}
writer.write(String.format("%.2f,", probability));
}
}

writer.write("\n");
} catch (Exception e) {
throw new RuntimeException(e);
}

// Pass result to wrapped evaluator
wrappedEvaluator.addResult(example, classVotes);
index ++;
}

@Override
public void addResult(Example<Instance> testInst, Prediction prediction) {
addResult(testInst, prediction.getVotes());
}

@Override
protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
wrappedEvaluator = (ClassificationPerformanceEvaluator) getPreparedClassOption(wrappedEvaluatorOption);
try {
File file = csvFileOption.getFile();
if (file.exists() && !overwrite.isSet()) {
throw new RuntimeException(
"File already exists: " + file.getAbsolutePath()
+ ". MOA doesn't want to overwrite it.");
}
if (uncompressed.isSet())
writer = new OutputStreamWriter(new FileOutputStream(file));
else
writer = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file)));
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private void writeHeader(int numClasses) throws IOException {
writer.write("true_class,class_prediction,");
if (probabilities.isSet()) {
for (int i = 0; i < numClasses; i++) {
writer.write(String.format("class_probability_%d,", i));
}
}
writer.write("\n");
}

@Override
public void close() throws Exception {
writer.close();
}

@Override
public void reset() {
wrappedEvaluator.reset();
}

@Override
public Measurement[] getPerformanceMeasurements() {
return wrappedEvaluator.getPerformanceMeasurements();
}

@Override
public void getDescription(StringBuilder sb, int indent) {
sb.append(getPurposeString());
}

@Override
public ImmutableCapabilities defineImmutableCapabilities() {
return new ImmutableCapabilities(Capability.VIEW_STANDARD);
}
}
5 changes: 5 additions & 0 deletions moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Example;
import moa.core.Measurement;
Expand All @@ -40,7 +39,6 @@
import com.github.javacliparser.IntOption;
import moa.streams.ExampleStream;
import moa.streams.InstanceStream;
import com.yahoo.labs.samoa.instances.Instance;

/**
* Task for evaluating a classifier on a stream by testing then training with each example in sequence.
Expand Down Expand Up @@ -217,6 +215,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}

Expand Down
36 changes: 4 additions & 32 deletions moa/src/main/java/moa/tasks/EvaluateModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
*/
package moa.tasks;

import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import com.github.javacliparser.FileOption;
import com.github.javacliparser.IntOption;
import moa.capabilities.CapabilitiesHandler;
Expand All @@ -32,15 +29,13 @@
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.evaluation.LearningEvaluation;
import moa.evaluation.LearningPerformanceEvaluator;
import moa.evaluation.preview.LearningCurve;
import moa.learners.Learner;
import moa.options.ClassOption;
import moa.streams.ExampleStream;
import moa.streams.InstanceStream;
import com.yahoo.labs.samoa.instances.Instance;

/**
* Task for evaluating a static model on a stream.
Expand Down Expand Up @@ -107,35 +102,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
long instancesProcessed = 0;
monitor.setCurrentActivity("Evaluating model...", -1.0);

//File for output predictions
File outputPredictionFile = this.outputPredictionFileOption.getFile();
PrintStream outputPredictionResultStream = null;
if (outputPredictionFile != null) {
try {
if (outputPredictionFile.exists()) {
outputPredictionResultStream = new PrintStream(
new FileOutputStream(outputPredictionFile, true), true);
} else {
outputPredictionResultStream = new PrintStream(
new FileOutputStream(outputPredictionFile), true);
}
} catch (Exception ex) {
throw new RuntimeException(
"Unable to open prediction result file: " + outputPredictionFile, ex);
}
}
while (stream.hasMoreInstances()
&& ((maxInstances < 0) || (instancesProcessed < maxInstances))) {
Example testInst = (Example) stream.nextInstance();//.copy();
int trueClass = (int) ((Instance) testInst.getData()).classValue();
//testInst.setClassMissing();
double[] prediction = model.getVotesForInstance(testInst);
//evaluator.addClassificationAttempt(trueClass, prediction, testInst
// .weight());
if (outputPredictionFile != null) {
outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," +(
((Instance) testInst.getData()).classIsMissing() == true ? " ? " : trueClass));
}
evaluator.addResult(testInst, prediction);
instancesProcessed++;

Expand Down Expand Up @@ -169,8 +139,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
}
}
}
if (outputPredictionResultStream != null) {
outputPredictionResultStream.close();
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}
Expand Down
53 changes: 5 additions & 48 deletions moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import com.github.javacliparser.IntOption;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Example;
import moa.core.Measurement;
Expand Down Expand Up @@ -140,12 +139,7 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
}
testStream = new CachedInstancesStream(testInstances);
} else {
//testStream = (InstanceStream) stream.copy();
testStream = stream;
/*monitor.setCurrentActivity("Skipping test examples...", -1.0);
for (int i = 0; i < testSize; i++) {
stream.nextInstance();
}*/
}
instancesProcessed = 0;
TimingUtils.enablePreciseTiming();
Expand Down Expand Up @@ -191,10 +185,7 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
break;
}
Example testInst = (Example) testStream.nextInstance(); //.copy();
double trueClass = ((Instance) testInst.getData()).classValue();
//testInst.setClassMissing();
double[] prediction = learner.getVotesForInstance(testInst);
//testInst.setClassValue(trueClass);
evaluator.addResult(testInst, prediction);
testInstancesProcessed++;
if (testInstancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
Expand Down Expand Up @@ -242,49 +233,15 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (monitor.resultPreviewRequested()) {
monitor.setLatestResultPreview(learningCurve.copy());
}
// if (learner instanceof HoeffdingTree
// || learner instanceof HoeffdingOptionTree) {
// int numActiveNodes = (int) Measurement.getMeasurementNamed(
// "active learning leaves",
// modelMeasurements).getValue();
// // exit if tree frozen
// if (numActiveNodes < 1) {
// break;
// }
// int numNodes = (int) Measurement.getMeasurementNamed(
// "tree size (nodes)", modelMeasurements)
// .getValue();
// if (numNodes == lastNumNodes) {
// noGrowthCount++;
// } else {
// noGrowthCount = 0;
// }
// lastNumNodes = numNodes;
// } else if (learner instanceof OzaBoost || learner instanceof
// OzaBag) {
// double numActiveNodes = Measurement.getMeasurementNamed(
// "[avg] active learning leaves",
// modelMeasurements).getValue();
// // exit if all trees frozen
// if (numActiveNodes == 0.0) {
// break;
// }
// int numNodes = (int) (Measurement.getMeasurementNamed(
// "[avg] tree size (nodes)",
// learner.getModelMeasurements()).getValue() * Measurement
// .getMeasurementNamed("ensemble size",
// modelMeasurements).getValue());
// if (numNodes == lastNumNodes) {
// noGrowthCount++;
// } else {
// noGrowthCount = 0;
// }
// lastNumNodes = numNodes;
// }
}
if (immediateResultStream != null) {
immediateResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}

Expand Down
Loading

0 comments on commit 9a33c54

Please sign in to comment.