Skip to content

Commit

Permalink
Added a lot more to the training metrics to make them give better det…
Browse files Browse the repository at this point in the history
…ails and allow for subclassing better
  • Loading branch information
dtracers committed Jul 5, 2016
1 parent ac307c5 commit 8824e1b
Show file tree
Hide file tree
Showing 12 changed files with 464 additions and 210 deletions.
77 changes: 0 additions & 77 deletions src/main/java/coursesketch/recognition/test/RecognitionMetric.java

This file was deleted.

209 changes: 133 additions & 76 deletions src/main/java/coursesketch/recognition/test/RecognitionTesting.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
import coursesketch.recognition.framework.TemplateDatabaseInterface;
import coursesketch.recognition.framework.exceptions.RecognitionException;
import coursesketch.recognition.framework.exceptions.TemplateException;
import coursesketch.recognition.test.converter.ScoreMetricsConverter;
import coursesketch.recognition.test.converter.ScoreMetricsConverterFactory;
import coursesketch.recognition.test.score.RecognitionScore;
import coursesketch.recognition.test.score.RecognitionScoreFactory;
import coursesketch.recognition.test.score.TrainingScore;
import coursesketch.recognition.test.score.TrainingScoreFactory;
import protobuf.srl.sketch.Sketch;

import java.util.ArrayList;
Expand Down Expand Up @@ -31,6 +37,9 @@ public class RecognitionTesting {
private int MAX_THREADS = 500;

ExecutorService executor;
protected RecognitionScoreFactory recognitionFactory = new DefaultRecognitionScoreFactory();
protected TrainingScoreFactory trainingFactory = new DefaultTrainingScoreFactory();
private ScoreMetricsConverterFactory converterFactory = new DefaultScoreMetricsConverterFactory();

/**
*
Expand All @@ -43,19 +52,31 @@ public RecognitionTesting(TemplateDatabaseInterface databaseInterface, Recogniti
this.recognitionSystems = recognitionSystems;
}

public List<RecognitionScoreMetrics> testAgainstAllTemplates() throws TemplateException {
public void setRecognitionScoreFactory(RecognitionScoreFactory recognitionFactory) {
this.recognitionFactory = recognitionFactory;
}

public void setTrainingScoreFactory(TrainingScoreFactory trainingFactory) {
this.trainingFactory = trainingFactory;
}

public void setScoreMetricsConverterFactory(ScoreMetricsConverterFactory converterFactory) {
this.converterFactory = converterFactory;
}

public List<ScoreMetricsConverter> testAgainstAllTemplates() throws TemplateException {
return testAgainstTemplates(databaseInterface.getAllTemplates());
}

public List<RecognitionScoreMetrics> testAgainstInterpretation(Sketch.SrlInterpretation interpretation)
public List<ScoreMetricsConverter> testAgainstInterpretation(Sketch.SrlInterpretation interpretation)
throws TemplateException {
return testAgainstTemplates(databaseInterface.getTemplate(interpretation));
}

/**
* This uses cross validation to test against templates.
*/
public List<RecognitionScoreMetrics> testAgainstTemplates(List<Sketch.RecognitionTemplate> allTemplates)
public List<ScoreMetricsConverter> testAgainstTemplates(List<Sketch.RecognitionTemplate> allTemplates)
throws TemplateException {

List<Sketch.RecognitionTemplate> testTemplates = splitTrainingAndTest(allTemplates);
Expand All @@ -65,14 +86,29 @@ public List<RecognitionScoreMetrics> testAgainstTemplates(List<Sketch.Recognitio
Map<RecognitionInterface, List<RecognitionScore>> recognitionScore =
recognizeAgainstTemplates(testTemplates);

List<RecognitionScoreMetrics> metrics = new ArrayList<>();
List<ScoreMetricsConverter> metrics = new ArrayList<>();
for (RecognitionInterface recognitionSystem : recognitionSystems) {
metrics.add(new RecognitionScoreMetrics(recognitionSystem.getClass().getSimpleName(), trainingScores.get(recognitionSystem),
recognitionScore.get(recognitionSystem)));
ScoreMetricsConverter scoreMetricsConverter = converterFactory.getScoreMetricsConverter(recognitionSystem,
trainingScores.get(recognitionSystem), recognitionScore.get(recognitionSystem));
scoreMetricsConverter.computeRecognitionMetrics();
metrics.add(scoreMetricsConverter);
}
return metrics;
}

protected List<Sketch.SrlInterpretation> testTemplate(Sketch.RecognitionTemplate testTemplate,
RecognitionInterface recognitionSystem,
RecognitionScore score) {
List<Sketch.SrlInterpretation> recognize = null;
try {
recognize = recognitionSystem.recognize(testTemplate.getTemplateId(), testTemplate);
} catch (Exception e) {
score.setNotRecognized(true);
score.setFailed(e);
}
return recognize;
}

public Map<RecognitionInterface, List<RecognitionScore>> recognizeAgainstTemplates(
List<Sketch.RecognitionTemplate> testTemplates) {
Map<RecognitionInterface, List<RecognitionScore>> scoreMap = new HashMap<>();
Expand All @@ -90,51 +126,54 @@ public Map<RecognitionInterface, List<RecognitionScore>> recognizeAgainstTemplat
List<Future> taskFutures = new ArrayList<>();
for (Sketch.RecognitionTemplate testTemplate : testTemplates) {
final int thisCount = counter;
taskFutures.add(executor.submit(new Callable(){
@Override
public Object call() throws Exception {
RecognitionScore score = new RecognitionScore(recognitionSystem, testTemplate.getTemplateId());
long startTime = System.nanoTime();
try {
List<Sketch.SrlInterpretation>
recognize = recognitionSystem.recognize(testTemplate.getTemplateId(), testTemplate);
long endTime = System.nanoTime();
score.setRecognitionTime(endTime - startTime);
if (recognize == null) {
score.setFailed(new NullPointerException("List of returned interpretations is null"));
recognitionScoreList.add(score);
return null;
}
generateScore(score, recognize, testTemplate.getInterpretation());
} catch (Exception e) {
score.setFailed(e);
}
recognitionScoreList.add(score);
if (thisCount % percent == 0) {
LOG.debug("gone through {} sketches, {} left", thisCount, testTemplates.size() - thisCount);
}
return null;
taskFutures.add(executor.submit((Callable) () -> {
RecognitionScore score = recognitionFactory.createRecognitionScore(recognitionSystem,
testTemplate.getTemplateId());
long startTime = System.nanoTime();
final List<Sketch.SrlInterpretation> interpretations =
testTemplate(testTemplate, recognitionSystem, score);
generateScore(score, interpretations, testTemplate.getInterpretation());
long endTime = System.nanoTime();
score.setRecognitionTime(endTime - startTime);
recognitionScoreList.add(score);
if (thisCount % percent == 0) {
LOG.debug("gone through {} sketches, {} left", thisCount, testTemplates.size() - thisCount);
}
return null;
}));
counter++;
}

LOG.debug("Waiting for all tasks to finish");
// Waits for the executor to finish
for (Future taskFuture : taskFutures) {
try {
taskFuture.get();
} catch (InterruptedException e) {
LOG.debug("INTERUPTIONS EXCEPTION", e);
} catch (ExecutionException e) {
LOG.debug("EXECUTION EXCEPTION", e);
}
}
waitForFutures(taskFutures);
LOG.debug("All recognition testing tasks have finished");
}
return scoreMap;
}

private void waitForFutures(List<Future> taskFutures) {
for (Future taskFuture : taskFutures) {
try {
taskFuture.get();
} catch (InterruptedException e) {
LOG.debug("INTERUPTIONS EXCEPTION", e);
} catch (ExecutionException e) {
LOG.debug("EXECUTION EXCEPTION", e);
}
}
}

protected void trainSystem(Sketch.RecognitionTemplate template, RecognitionInterface recognitionSystem,
TrainingScore score) {
try {
recognitionSystem.trainTemplate(template);
} catch (Exception e) {
score.addException(new RecognitionTestException("Error with training template " + template.getTemplateId(),
e, recognitionSystem));
}
}

public Map<RecognitionInterface, List<TrainingScore>> trainAgainstTemplates(List<Sketch.RecognitionTemplate> templates) {
Map<RecognitionInterface, List<TrainingScore>> scoreMap = new HashMap<>();

Expand All @@ -150,63 +189,57 @@ public Map<RecognitionInterface, List<TrainingScore>> trainAgainstTemplates(List
List<Future> taskFutures = new ArrayList<>();
for (Sketch.RecognitionTemplate template : templates) {
final int thisCount = counter;
taskFutures.add(executor.submit(new Callable() {
@Override
public Object call() throws Exception {
TrainingScore score = new TrainingScore();
long startTime = System.nanoTime();
try {
recognitionSystem.trainTemplate(template);
} catch (Exception e) {
score.addException(new RecognitionTestException("Error with training template " + template.getTemplateId(),
e, recognitionSystem));
}
long endTime = System.nanoTime();
score.setTrainingTime(endTime - startTime);
trainingScores.add(score);

if (thisCount % percent == 0) {
LOG.debug("gone through {} sketches, {} left", thisCount, templates.size() - thisCount);
}
return null;
taskFutures.add(executor.submit((Callable) () -> {
TrainingScore score = trainingFactory.createTrainingScore(recognitionSystem, template.getTemplateId());
long startTime = System.nanoTime();
trainSystem(template, recognitionSystem, score);
long endTime = System.nanoTime();
score.setTrainingTime(endTime - startTime);
trainingScores.add(score);
if (thisCount % percent == 0) {
LOG.debug("gone through {} sketches, {} left", thisCount, templates.size() - thisCount);
}
return null;
}));
counter++;
}

LOG.debug("Waiting for all tasks to finish");
// Waits for the executor to finish
for (Future taskFuture : taskFutures) {
try {
taskFuture.get();
} catch (InterruptedException e) {
LOG.debug("INTERUPTIONS EXCEPTION", e);
} catch (ExecutionException e) {
LOG.debug("EXECUTION EXCEPTION", e);
}
}
try {
recognitionSystem.finishTraining();
} catch (RecognitionException e) {
LOG.debug("EXCEPTION WHEN TRAINING", e);
}
waitForFutures(taskFutures);
finishTraining(recognitionSystem);
LOG.debug("All trainings tasks have finished");
}
return scoreMap;
}

private void generateScore(RecognitionScore score,
protected void finishTraining(RecognitionInterface recognitionSystem) {
try {
recognitionSystem.finishTraining();
} catch (RecognitionException e) {
LOG.debug("EXCEPTION WHEN TRAINING", e);
}
}

protected void generateScore(RecognitionScore score,
List<Sketch.SrlInterpretation> recognize, Sketch.SrlInterpretation interpretation) {
if (recognize == null) {
score.setNotRecognized(true);
score.setFailed(new NullPointerException("List of returned interpretations is null"));
return;
}
double scoreValue = 1;
int topGuesses = Math.min(5, recognize.size());
int subtractAmount = 1/topGuesses;
score.setRecognizedInterpretations(recognize);
score.setCorrectInterpretations(interpretation);
for (int i = 0; i < topGuesses; i++) {
if (recognize.get(i).getLabel().equals(interpretation.getLabel())) {
score.setRecognized(true);
// We won't consider it recognized if it has no confidence in its values
if (recognize.get(i).getLabel().equals(interpretation.getLabel())
&& recognize.get(i).getConfidence() > 0) {
score.setRecognized(i);
score.setScoreValue(scoreValue * recognize.get(i).getConfidence());
return ;
return;
}
if (i == 0) {
score.setPotentialMissRecognized(true);
Expand Down Expand Up @@ -234,4 +267,28 @@ private List<Sketch.RecognitionTemplate> splitTrainingAndTest(List<Sketch.Recogn
LOG.debug("TrainingSet: {}, TestingSet: {}", allTemplates.size(), testTemplates.size());
return testTemplates;
}

private static final class DefaultRecognitionScoreFactory implements RecognitionScoreFactory {
@Override
public RecognitionScore createRecognitionScore(RecognitionInterface recognitionSystem, String templateId) {
return new RecognitionScore(recognitionSystem, templateId);
}
}

private static final class DefaultTrainingScoreFactory implements TrainingScoreFactory {
@Override
public TrainingScore createTrainingScore(RecognitionInterface recognitionSystem, String templateId) {
return new TrainingScore(recognitionSystem, templateId);
}
}

private static final class DefaultScoreMetricsConverterFactory implements ScoreMetricsConverterFactory {

@Override
public ScoreMetricsConverter getScoreMetricsConverter(RecognitionInterface recognitionSystem,
List<TrainingScore> trainingScores, List<RecognitionScore> recognitionScores) {
return new ScoreMetricsConverter(recognitionSystem.getClass().getSimpleName(),
trainingScores, recognitionScores);
}
}
}
Loading

0 comments on commit 8824e1b

Please sign in to comment.