Skip to content

Commit

Permalink
Added training metrics and multithreaded training metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
trunerd committed Jul 3, 2016
1 parent e32fcf6 commit 64f9365
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@ public class RecognitionMetric {
private final List<RecognitionScore> potentialMisRecognized;
private int totalTemplates;
private List<Exception> recognitionException;
private double averageRecognitionTime;

public RecognitionMetric(double averageScore, double averageScoreOfCorrect, int numberCorrect,
List<RecognitionScore> nonRecognizedIds, List<RecognitionScore> potentialMisRecognized,
int totalTemplates, List<Exception> recognitionException) {
int totalTemplates, List<Exception> recognitionException, double averageRecognitionTime) {
this.averageScore = averageScore;
this.averageScoreOfCorrect = averageScoreOfCorrect;
this.numberCorrect = numberCorrect;
this.nonRecognizedIds = nonRecognizedIds;
this.potentialMisRecognized = potentialMisRecognized;
this.totalTemplates = totalTemplates;
this.recognitionException = recognitionException;
this.averageRecognitionTime = averageRecognitionTime;
}

public double getAverageScore() {
Expand Down Expand Up @@ -59,7 +61,10 @@ public String toString() {
"\n\tFalse Positives:" +
"\n\t\tNumber of False Positives " + potentialMisRecognized.size() +
"\n\t\tNonRecognized Percentage: " + (((double) potentialMisRecognized.size())/ ((double) totalTemplates)) +
"\n\tNumber fo recognition exceptions: " + recognitionException.size();
"\n\tNumber fo recognition exceptions: " + recognitionException.size() +
"\n\tTime:" +
"\n\t\tRecognitionTimeNanos: " + averageRecognitionTime +
"\n\t\tRecognitionTimeMillis: " + (averageRecognitionTime / 1000000.);
}

public int getTotalTemplates() {
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/coursesketch/recognition/test/RecognitionScore.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ public class RecognitionScore {
private boolean notRecognized;
private List<Sketch.SrlInterpretation> recognizedInterpretations;
private Sketch.SrlInterpretation correctInterpretations;
/**
* The time the recognition took in nano seconds
*/
private long recognitionTime;

public RecognitionScore(RecognitionInterface recognitionSystem, String templateId) {
this.recognitionSystem = recognitionSystem;
Expand Down Expand Up @@ -91,4 +95,12 @@ public Sketch.SrlInterpretation getCorrectInterpretations() {
public List<Sketch.SrlInterpretation> getRecognizedInterpretations() {
return recognizedInterpretations;
}

public void setRecognitionTime(long recognitionTime) {
this.recognitionTime = recognitionTime;
}

public long getRecognitionTime() {
return recognitionTime;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,61 @@ public class RecognitionScoreMetrics {
*/
private static final Logger LOG = LoggerFactory.getLogger(RecognitionScoreMetrics.class);

private List<Exception> exceptions;
private final String simpleName;
private List<TrainingScore> trainingScores;
private List<RecognitionScore> recognitionScores;
private RecognitionMetric recognitionMetric;
private TrainingMetric trainingMetric;

public RecognitionScoreMetrics(List<Exception> exceptions, List<RecognitionScore> recognitionScores) {
this.exceptions = exceptions;
public RecognitionScoreMetrics(String simpleName, List<TrainingScore> trainingScores, List<RecognitionScore> recognitionScores) {
this.simpleName = simpleName;
this.trainingScores = trainingScores;
this.recognitionScores = recognitionScores;
}

public int getNumberOfTrainingException() {
return exceptions.size();
}

public List<Exception> getTrainingExceptions() {
return exceptions;
}

public double getAverageScore() {
double score = 0;
for (RecognitionScore recognitionScore : recognitionScores) {
score += (recognitionScore.getScoreValue() / ((double) recognitionScores.size()));
}
return score;
public List<TrainingScore> getTrainingScores() {
return trainingScores;
}

public List<RecognitionScore> getScores() {
return recognitionScores;
}

public void computeMetrics() {
computeMetrics(recognitionScores);
public void computeRecognitionMetrics() {
recognitionMetric = computeRecognitionMetrics(recognitionScores);
trainingMetric = computeTrainingMetrics(trainingScores);
}

public RecognitionMetric computeMetrics(List<RecognitionScore> recognitionScores) {
public TrainingMetric computeTrainingMetrics(List<TrainingScore> trainingScores) {
List<Exception> trainingException = new ArrayList<>();
double numTemplates = 0;
double averageTrainingTime = 0;
for (TrainingScore trainingScore : trainingScores) {
if (trainingScore == null) {
LOG.debug("RECOGNITION SCORE IS NULL");
continue;
}
numTemplates++;
if (trainingScore.hasException()) {
trainingException.add(trainingScore.getException());
}
}
for (TrainingScore trainingScore : trainingScores) {
if (trainingScore == null) {
LOG.debug("RECOGNITION SCORE IS NULL");
continue;
}
averageTrainingTime += ((double) trainingScore.getTrainingTime()) / numTemplates;
}
return new TrainingMetric((int) numTemplates, trainingException, averageTrainingTime);
}

public RecognitionMetric computeRecognitionMetrics(List<RecognitionScore> recognitionScores) {
double averageScore = 0;
double averageScoreOfCorrect = 0;
double numberCorrect = 0;
double numTemplates = 0;
double averageRecognitionTime = 0;
List<RecognitionScore> nonRecognizedIds = new ArrayList<>();
List<RecognitionScore> potentialMisRecognized = new ArrayList<>();
List<Exception> recognitionException = new ArrayList<>();
Expand All @@ -78,10 +96,30 @@ public RecognitionMetric computeMetrics(List<RecognitionScore> recognitionScores
recognitionException.add(recognitionScore.getException());
}
}
for (RecognitionScore recognitionScore : recognitionScores) {
if (recognitionScore == null) {
LOG.debug("RECOGNITION SCORE IS NULL");
continue;
}
averageRecognitionTime += ((double) recognitionScore.getRecognitionTime()) / numTemplates;
}
averageScore /= numTemplates;
averageScoreOfCorrect /= numberCorrect;
LOG.debug("Finished Computing metrics");
return new RecognitionMetric(averageScore, averageScoreOfCorrect,
(int) numberCorrect, nonRecognizedIds, potentialMisRecognized, (int) numTemplates, recognitionException);
(int) numberCorrect, nonRecognizedIds, potentialMisRecognized,
(int) numTemplates, recognitionException, averageRecognitionTime);
}

public String getSimpleName() {
return simpleName;
}

public RecognitionMetric getRecognitionMetric() {
return recognitionMetric;
}

public TrainingMetric getTrainingMetric() {
return trainingMetric;
}
}
67 changes: 49 additions & 18 deletions src/main/java/coursesketch/recognition/test/RecognitionTesting.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,19 @@ public List<RecognitionScoreMetrics> testAgainstInterpretation(Sketch.SrlInterpr
/**
* This uses cross validation to test against templates.
*/
private List<RecognitionScoreMetrics> testAgainstTemplates(List<Sketch.RecognitionTemplate> allTemplates)
public List<RecognitionScoreMetrics> testAgainstTemplates(List<Sketch.RecognitionTemplate> allTemplates)
throws TemplateException {

List<Sketch.RecognitionTemplate> testTemplates = splitTrainingAndTest(allTemplates);

Map<RecognitionInterface, List<Exception>> exceptions = trainAgainstTemplates(allTemplates);
Map<RecognitionInterface, List<TrainingScore>> trainingScores = trainAgainstTemplates(allTemplates);

Map<RecognitionInterface, List<RecognitionScore>> recognitionScore =
recognizeAgainstTemplates(testTemplates);

List<RecognitionScoreMetrics> metrics = new ArrayList<>();
for (RecognitionInterface recognitionSystem : recognitionSystems) {
metrics.add(new RecognitionScoreMetrics(exceptions.get(recognitionSystem),
metrics.add(new RecognitionScoreMetrics(recognitionSystem.getClass().getSimpleName(), trainingScores.get(recognitionSystem),
recognitionScore.get(recognitionSystem)));
}
return metrics;
Expand All @@ -77,7 +77,7 @@ private Map<RecognitionInterface, List<RecognitionScore>> recognizeAgainstTempla
Map<RecognitionInterface, List<RecognitionScore>> scoreMap = new HashMap<>();

// For the specific number of threads needed
executor = Executors.newFixedThreadPool(Math.min(MAX_THREADS, Math.max(1, testTemplates.size() / 10)));
executor = Executors.newFixedThreadPool(Math.min(MAX_THREADS, Math.max(1, testTemplates.size() / 20)));

LOG.debug("Running recognition test for {} templates", testTemplates.size());
int percent = (int) Math.round(Math.max(1.0, testTemplates.size() / 100.0));
Expand All @@ -93,9 +93,12 @@ private Map<RecognitionInterface, List<RecognitionScore>> recognizeAgainstTempla
@Override
public Object call() throws Exception {
RecognitionScore score = new RecognitionScore(recognitionSystem, testTemplate.getTemplateId());
long startTime = System.nanoTime();
try {
List<Sketch.SrlInterpretation>
recognize = recognitionSystem.recognize(testTemplate.getTemplateId(), testTemplate);
long endTime = System.nanoTime();
score.setRecognitionTime(endTime - startTime);
if (recognize == null) {
score.setFailed(new NullPointerException("List of returned interpretations is null"));
recognitionScoreList.add(score);
Expand Down Expand Up @@ -132,32 +135,60 @@ public Object call() throws Exception {
return scoreMap;
}

public Map<RecognitionInterface, List<Exception>> trainAgainstTemplates(List<Sketch.RecognitionTemplate> templates) {
Map<RecognitionInterface, List<Exception>> exceptionMap = new HashMap<>();
public Map<RecognitionInterface, List<TrainingScore>> trainAgainstTemplates(List<Sketch.RecognitionTemplate> templates) {
Map<RecognitionInterface, List<TrainingScore>> scoreMap = new HashMap<>();

executor = Executors.newFixedThreadPool(Math.min(MAX_THREADS, Math.max(1, templates.size() / 5)));

LOG.debug("Running recognition training for {} templates", templates.size());
int percent = (int) Math.round(Math.max(1.0, templates.size() / 10.0));
for (RecognitionInterface recognitionSystem : recognitionSystems) {
List<Exception> trainingExceptions = new ArrayList<>();
exceptionMap.put(recognitionSystem, trainingExceptions);
List<TrainingScore> trainingScores = new ArrayList<>();
scoreMap.put(recognitionSystem, trainingScores);
int counter = 0;
LOG.debug("training recognition system: {}", recognitionSystem.getClass().getSimpleName());
List<Future> taskFutures = new ArrayList<>();
for (Sketch.RecognitionTemplate template : templates) {
try {
recognitionSystem.trainTemplate(template);
} catch (Exception e) {
LOG.error("Exception occured while training", e);
trainingExceptions.add(
new RecognitionTestException("Error with training template " + template.getTemplateId(),
final int thisCount = counter;
taskFutures.add(executor.submit(new Callable() {
@Override
public Object call() throws Exception {
TrainingScore score = new TrainingScore();
long startTime = System.nanoTime();
try {
recognitionSystem.trainTemplate(template);
} catch (Exception e) {
LOG.error("Exception occured while training", e);
score.addException(new RecognitionTestException("Error with training template " + template.getTemplateId(),
e, recognitionSystem));
}
}
long endTime = System.nanoTime();
score.setTrainingTime(endTime - startTime);
trainingScores.add(score);

if (thisCount % percent == 0) {
LOG.debug("gone through {} sketches, {} left", thisCount, templates.size() - thisCount);
}
return null;
}
}));
counter++;
if (counter % percent == 0) {
LOG.debug("gone through {} sketches, {} left", counter, templates.size() - counter);
}

LOG.debug("Waiting for all tasks to finish");
// Waits for the executor to finish
for (Future taskFuture : taskFutures) {
try {
taskFuture.get();
} catch (InterruptedException e) {
LOG.debug("INTERUPTIONS EXCEPTION", e);
} catch (ExecutionException e) {
LOG.debug("EXECUTION EXCEPTION", e);
}
}
LOG.debug("All trainings tasks have finished");
}
return exceptionMap;
return scoreMap;
}

private void generateScore(RecognitionScore score,
Expand Down
40 changes: 40 additions & 0 deletions src/main/java/coursesketch/recognition/test/TrainingMetric.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package coursesketch.recognition.test;

import java.util.List;

/**
* Created by David Windows on 7/3/2016.
*/
public class TrainingMetric {
private int totalTemplates;
private final List<Exception> exceptionList;
private final double averageTrainingTime;

public TrainingMetric(int numTemplates, List<Exception> exceptionList, double averageTrainingTime) {
totalTemplates = numTemplates;

this.exceptionList = exceptionList;
this.averageTrainingTime = averageTrainingTime;
}

public int getNumberOfExceptions() {
return exceptionList.size();
}

public List<Exception> getExceptionList() {
return exceptionList;
}

public double getAverageTrainingTime() {
return averageTrainingTime;
}

public String toString() {
return "Metrics: " +
"\n\tTotal Number of templates:" + totalTemplates +
"\n\tNumber fo recognition exceptions: " + exceptionList.size() +
"\n\tTime:" +
"\n\t\tRecognitionTimeNanos: " + averageTrainingTime +
"\n\t\tRecognitionTimeMillis: " + (averageTrainingTime / 1000000.);
}
}
30 changes: 30 additions & 0 deletions src/main/java/coursesketch/recognition/test/TrainingScore.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package coursesketch.recognition.test;

/**
* Created by David Windows on 7/3/2016.
*/
public class TrainingScore {
private RecognitionTestException exception;
private long trainingTime;

public void addException(RecognitionTestException e) {

exception = e;
}

public void setTrainingTime(long trainingTime) {
this.trainingTime = trainingTime;
}

public long getTrainingTime() {
return trainingTime;
}

public boolean hasException() {
return exception != null;
}

public RecognitionTestException getException() {
return exception;
}
}

0 comments on commit 64f9365

Please sign in to comment.