Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PredictionLoggerEvaluator #285

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject, CapabilitiesHandler {
public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject, CapabilitiesHandler, AutoCloseable {

/**
* Resets this evaluator. It must be similar to
Expand Down Expand Up @@ -66,4 +66,8 @@ default ImmutableCapabilities defineImmutableCapabilities() {
return new ImmutableCapabilities(Capability.VIEW_STANDARD);
}

@Override
default void close() throws Exception {
// By default an evaluator does nothing when closed.
}
}
160 changes: 160 additions & 0 deletions moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package moa.evaluation;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.Arrays;
import java.util.zip.GZIPOutputStream;

import com.github.javacliparser.FileOption;
import com.github.javacliparser.FlagOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Prediction;

import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.options.AbstractOptionHandler;
import moa.options.ClassOption;
import moa.tasks.TaskMonitor;

public class PredictionLoggerEvaluator extends AbstractOptionHandler
implements ClassificationPerformanceEvaluator {

private static final long serialVersionUID = 1L;

private OutputStreamWriter writer;
private int index = 0;

public FileOption outputPredictionFileOption = new FileOption("output", 'o',
"A file to write comma separated predictions to.", null, "csv.gzip", true);

public FlagOption overwrite = new FlagOption("overwrite", 'f', "Overwrite existing file.");

public ClassOption wrappedEvaluatorOption = new ClassOption("evaluator", 'e',
"Classification performance evaluation method.", ClassificationPerformanceEvaluator.class,
"BasicClassificationPerformanceEvaluator");

public FlagOption probabilities = new FlagOption("probabilities", 'p',
"Log probabilities instead of raw predictions.");

public FlagOption uncompressed = new FlagOption("uncompressed", 'u',
"The output file should be saved uncompressed.");

private ClassificationPerformanceEvaluator wrappedEvaluator;

@Override
public String getPurposeString() {
return "Log raw predictions and probabilities to a CSV file, and evaluate using a wrapped evaluator.";
}

@Override
public void addResult(Example<Instance> example, double[] classVotes) {
Instance instance = example.getData();
int predictedClass = Utils.maxIndex(classVotes);
double normalizingFactor = Arrays.stream(classVotes).sum();
int numClasses = instance.numClasses();

if (normalizingFactor == 0) {
normalizingFactor = 1;
}
try {
// If this is the first result, write the header to the top of the file
if (index == 0)
writeHeader(numClasses);


// Add row to CSV file
if (instance.classIsMissing() == true)
{
writer.write(String.format("?,%d,", predictedClass));
}
else
{
int trueClass = (int) instance.classValue();
writer.write(String.format("%d,%d,", trueClass, predictedClass));
}

if (probabilities.isSet()) {
for (int i = 0; i < numClasses; i++) {
double probability = 0.0;
if (i < classVotes.length){
probability = classVotes[i] / normalizingFactor;
}
writer.write(String.format("%.2f,", probability));
}
}

writer.write("\n");
} catch (Exception e) {
throw new RuntimeException(e);
}

// Pass result to wrapped evaluator
wrappedEvaluator.addResult(example, classVotes);
index ++;
}

@Override
public void addResult(Example<Instance> testInst, Prediction prediction) {
throw new UnsupportedOperationException("Not implemented");
}

@Override
protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
wrappedEvaluator = (ClassificationPerformanceEvaluator) getPreparedClassOption(wrappedEvaluatorOption);
try {
File file = outputPredictionFileOption.getFile();
if (file.exists() && !overwrite.isSet()) {
throw new RuntimeException(
"File already exists: " + file.getAbsolutePath()
+ ". MOA doesn't want to overwrite it.");
}
if (uncompressed.isSet())
writer = new OutputStreamWriter(new FileOutputStream(file));
else
writer = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file)));
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private void writeHeader(int numClasses) throws IOException {
writer.write("true_class,class_prediction,");
if (probabilities.isSet()) {
for (int i = 0; i < numClasses; i++) {
writer.write(String.format("class_probability_%d,", i));
}
}
writer.write("\n");
}

@Override
public void close() throws Exception {
writer.close();
}

@Override
public void reset() {
wrappedEvaluator.reset();
}

@Override
public Measurement[] getPerformanceMeasurements() {
return wrappedEvaluator.getPerformanceMeasurements();
}

@Override
public void getDescription(StringBuilder sb, int indent) {
sb.append(getPurposeString());
}

@Override
public ImmutableCapabilities defineImmutableCapabilities() {
return new ImmutableCapabilities(Capability.VIEW_STANDARD);
}
}
8 changes: 5 additions & 3 deletions moa/src/main/java/moa/tasks/EvaluateConceptDrift.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
/* if (outputPredictionResultStream != null) {
outputPredictionResultStream.close();
}*/
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}
}
5 changes: 5 additions & 0 deletions moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}

Expand Down
31 changes: 4 additions & 27 deletions moa/src/main/java/moa/tasks/EvaluateModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,35 +107,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
long instancesProcessed = 0;
monitor.setCurrentActivity("Evaluating model...", -1.0);

//File for output predictions
File outputPredictionFile = this.outputPredictionFileOption.getFile();
PrintStream outputPredictionResultStream = null;
if (outputPredictionFile != null) {
try {
if (outputPredictionFile.exists()) {
outputPredictionResultStream = new PrintStream(
new FileOutputStream(outputPredictionFile, true), true);
} else {
outputPredictionResultStream = new PrintStream(
new FileOutputStream(outputPredictionFile), true);
}
} catch (Exception ex) {
throw new RuntimeException(
"Unable to open prediction result file: " + outputPredictionFile, ex);
}
}
while (stream.hasMoreInstances()
&& ((maxInstances < 0) || (instancesProcessed < maxInstances))) {
Example testInst = (Example) stream.nextInstance();//.copy();
int trueClass = (int) ((Instance) testInst.getData()).classValue();
//testInst.setClassMissing();
double[] prediction = model.getVotesForInstance(testInst);
//evaluator.addClassificationAttempt(trueClass, prediction, testInst
// .weight());
if (outputPredictionFile != null) {
outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," +(
((Instance) testInst.getData()).classIsMissing() == true ? " ? " : trueClass));
}
evaluator.addResult(testInst, prediction);
instancesProcessed++;

Expand Down Expand Up @@ -169,8 +144,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
}
}
}
if (outputPredictionResultStream != null) {
outputPredictionResultStream.close();
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}
Expand Down
5 changes: 5 additions & 0 deletions moa/src/main/java/moa/tasks/EvaluateModelMultiLabel.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (outputPredictionResultStream != null) {
outputPredictionResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return new LearningEvaluation(evaluator, model);
}
}
5 changes: 5 additions & 0 deletions moa/src/main/java/moa/tasks/EvaluateModelMultiTarget.java
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (outputPredictionResultStream != null) {
outputPredictionResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return new LearningEvaluation(evaluator, model);
}
}
5 changes: 5 additions & 0 deletions moa/src/main/java/moa/tasks/EvaluateModelRegression.java
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (outputPredictionResultStream != null) {
outputPredictionResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return new LearningEvaluation(evaluator, model);
}
}
5 changes: 5 additions & 0 deletions moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}

Expand Down
44 changes: 11 additions & 33 deletions moa/src/main/java/moa/tasks/EvaluatePrequential.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ public class EvaluatePrequential extends ClassificationMainTask implements Capab

@Override
public String getPurposeString() {
return "Evaluates a classifier on a stream by testing then training with each example in sequence.";
return
"Evaluates a classifier on a stream by testing then training with each example in sequence."
+ "\n`outputPredictionFile` has been replaced with the `PredictionLoggerEvaluator`";
}

private static final long serialVersionUID = 1L;
Expand Down Expand Up @@ -97,9 +99,6 @@ public String getPurposeString() {
public FileOption dumpFileOption = new FileOption("dumpFile", 'd',
"File to append intermediate csv results to.", null, "csv", true);

public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o',
"File to append output predictions to.", null, "pred", true);

//New for prequential method DEPRECATED
public IntOption widthOption = new IntOption("width",
'w', "Size of Window", 1000);
Expand Down Expand Up @@ -168,23 +167,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
"Unable to open immediate result file: " + dumpFile, ex);
}
}
//File for output predictions
File outputPredictionFile = this.outputPredictionFileOption.getFile();
PrintStream outputPredictionResultStream = null;
if (outputPredictionFile != null) {
try {
if (outputPredictionFile.exists()) {
outputPredictionResultStream = new PrintStream(
new FileOutputStream(outputPredictionFile, true), true);
} else {
outputPredictionResultStream = new PrintStream(
new FileOutputStream(outputPredictionFile), true);
}
} catch (Exception ex) {
throw new RuntimeException(
"Unable to open prediction result file: " + outputPredictionFile, ex);
}
}
boolean firstDump = true;
boolean preciseCPUTiming = TimingUtils.enablePreciseTiming();
long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
Expand All @@ -194,20 +176,14 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
&& ((maxInstances < 0) || (instancesProcessed < maxInstances))
&& ((maxSeconds < 0) || (secondsElapsed < maxSeconds))) {
Example trainInst = stream.nextInstance();
Example testInst = (Example) trainInst; //.copy();
//testInst.setClassMissing();
double[] prediction = learner.getVotesForInstance(testInst);
// Output prediction
if (outputPredictionFile != null) {
int trueClass = (int) ((Instance) trainInst.getData()).classValue();
outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," + (
((Instance) testInst.getData()).classIsMissing() == true ? " ? " : trueClass));
}
Example testInst = (Example) trainInst;

//evaluator.addClassificationAttempt(trueClass, prediction, testInst.weight());
double[] prediction = learner.getVotesForInstance(testInst);
evaluator.addResult(testInst, prediction);

learner.trainOnInstance(trainInst);
instancesProcessed++;

if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0
|| stream.hasMoreInstances() == false) {
long evaluateTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
Expand Down Expand Up @@ -267,8 +243,10 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
if (outputPredictionResultStream != null) {
outputPredictionResultStream.close();
try {
evaluator.close();
} catch (Exception ex) {
throw new RuntimeException("Exception closing evaluator", ex);
}
return learningCurve;
}
Expand Down
Loading