diff --git a/src/main/java/coursesketch/recognition/test/RecognitionMetric.java b/src/main/java/coursesketch/recognition/test/RecognitionMetric.java deleted file mode 100644 index ae77df8..0000000 --- a/src/main/java/coursesketch/recognition/test/RecognitionMetric.java +++ /dev/null @@ -1,77 +0,0 @@ -package coursesketch.recognition.test; - -import java.util.List; - -/** - * Created by David Windows on 7/2/2016. - */ -public class RecognitionMetric { - private final double averageScore; - private final double averageScoreOfCorrect; - private final int numberCorrect; - private final List nonRecognizedIds; - private final List potentialMisRecognized; - private int totalTemplates; - private List recognitionException; - private double averageRecognitionTime; - - public RecognitionMetric(double averageScore, double averageScoreOfCorrect, int numberCorrect, - List nonRecognizedIds, List potentialMisRecognized, - int totalTemplates, List recognitionException, double averageRecognitionTime) { - this.averageScore = averageScore; - this.averageScoreOfCorrect = averageScoreOfCorrect; - this.numberCorrect = numberCorrect; - this.nonRecognizedIds = nonRecognizedIds; - this.potentialMisRecognized = potentialMisRecognized; - this.totalTemplates = totalTemplates; - this.recognitionException = recognitionException; - this.averageRecognitionTime = averageRecognitionTime; - } - - public double getAverageScore() { - return averageScore; - } - - public double getAverageScoreOfCorrect() { - return averageScoreOfCorrect; - } - - public int getNumberCorrect() { - return numberCorrect; - } - - public List getNonRecognizedIds() { - return nonRecognizedIds; - } - - public List getPotentialMisRecognized() { - return potentialMisRecognized; - } - - public String toString() { - return "Metrics: " + - "\n\tTotal Number of templates:" + totalTemplates + - "\n\tTotal Average Score:" + averageScore + - "\n\tCorrectness:" + - "\n\t\tNumber Correct: " + numberCorrect + - "\n\t\tCorrect Percentage: " + (((double) numberCorrect) / ((double) totalTemplates)) + - "\n\tWrongness:" + - "\n\t\tNumber Incorrect " + nonRecognizedIds.size() + - "\n\t\tNonRecognized Percentage: " + (((double) nonRecognizedIds.size())/ ((double) totalTemplates)) + - "\n\tFalse Positives:" + - "\n\t\tNumber of False Positives " + potentialMisRecognized.size() + - "\n\t\tNonRecognized Percentage: " + (((double) potentialMisRecognized.size())/ ((double) totalTemplates)) + - "\n\tNumber fo recognition exceptions: " + recognitionException.size() + - "\n\tTime:" + - "\n\t\tRecognitionTimeNanos: " + averageRecognitionTime + - "\n\t\tRecognitionTimeMillis: " + (averageRecognitionTime / 1000000.); - } - - public int getTotalTemplates() { - return totalTemplates; - } - - public List getRecognitionExceptions() { - return recognitionException; - } -} diff --git a/src/main/java/coursesketch/recognition/test/RecognitionTesting.java b/src/main/java/coursesketch/recognition/test/RecognitionTesting.java index 8a4b64e..c87fbf9 100644 --- a/src/main/java/coursesketch/recognition/test/RecognitionTesting.java +++ b/src/main/java/coursesketch/recognition/test/RecognitionTesting.java @@ -4,6 +4,12 @@ import coursesketch.recognition.framework.TemplateDatabaseInterface; import coursesketch.recognition.framework.exceptions.RecognitionException; import coursesketch.recognition.framework.exceptions.TemplateException; +import coursesketch.recognition.test.converter.ScoreMetricsConverter; +import coursesketch.recognition.test.converter.ScoreMetricsConverterFactory; +import coursesketch.recognition.test.score.RecognitionScore; +import coursesketch.recognition.test.score.RecognitionScoreFactory; +import coursesketch.recognition.test.score.TrainingScore; +import coursesketch.recognition.test.score.TrainingScoreFactory; import protobuf.srl.sketch.Sketch; import java.util.ArrayList; @@ -31,6 +37,9 @@ public class RecognitionTesting { private int MAX_THREADS = 500; ExecutorService executor; + protected RecognitionScoreFactory recognitionFactory = new DefaultRecognitionScoreFactory(); + protected TrainingScoreFactory trainingFactory = new DefaultTrainingScoreFactory(); + private ScoreMetricsConverterFactory converterFactory = new DefaultScoreMetricsConverterFactory(); /** * @@ -43,11 +52,23 @@ public RecognitionTesting(TemplateDatabaseInterface databaseInterface, Recogniti this.recognitionSystems = recognitionSystems; } - public List testAgainstAllTemplates() throws TemplateException { + public void setRecognitionScoreFactory(RecognitionScoreFactory recognitionFactory) { + this.recognitionFactory = recognitionFactory; + } + + public void setTrainingScoreFactory(TrainingScoreFactory trainingFactory) { + this.trainingFactory = trainingFactory; + } + + public void setScoreMetricsConverterFactory(ScoreMetricsConverterFactory converterFactory) { + this.converterFactory = converterFactory; + } + + public List testAgainstAllTemplates() throws TemplateException { return testAgainstTemplates(databaseInterface.getAllTemplates()); } - public List testAgainstInterpretation(Sketch.SrlInterpretation interpretation) + public List testAgainstInterpretation(Sketch.SrlInterpretation interpretation) throws TemplateException { return testAgainstTemplates(databaseInterface.getTemplate(interpretation)); } @@ -55,7 +76,7 @@ public List testAgainstInterpretation(Sketch.SrlInterpr /** * This uses cross validation to test against templates. */ - public List testAgainstTemplates(List allTemplates) + public List testAgainstTemplates(List allTemplates) throws TemplateException { List testTemplates = splitTrainingAndTest(allTemplates); @@ -65,14 +86,29 @@ public List testAgainstTemplates(List> recognitionScore = recognizeAgainstTemplates(testTemplates); - List metrics = new ArrayList<>(); + List metrics = new ArrayList<>(); for (RecognitionInterface recognitionSystem : recognitionSystems) { - metrics.add(new RecognitionScoreMetrics(recognitionSystem.getClass().getSimpleName(), trainingScores.get(recognitionSystem), - recognitionScore.get(recognitionSystem))); + ScoreMetricsConverter scoreMetricsConverter = converterFactory.getScoreMetricsConverter(recognitionSystem, + trainingScores.get(recognitionSystem), recognitionScore.get(recognitionSystem)); + scoreMetricsConverter.computeRecognitionMetrics(); + metrics.add(scoreMetricsConverter); } return metrics; } + protected List testTemplate(Sketch.RecognitionTemplate testTemplate, + RecognitionInterface recognitionSystem, + RecognitionScore score) { + List recognize = null; + try { + recognize = recognitionSystem.recognize(testTemplate.getTemplateId(), testTemplate); + } catch (Exception e) { + score.setNotRecognized(true); + score.setFailed(e); + } + return recognize; + } + public Map> recognizeAgainstTemplates( List testTemplates) { Map> scoreMap = new HashMap<>(); @@ -90,51 +126,54 @@ public Map> recognizeAgainstTemplat List taskFutures = new ArrayList<>(); for (Sketch.RecognitionTemplate testTemplate : testTemplates) { final int thisCount = counter; - taskFutures.add(executor.submit(new Callable(){ - @Override - public Object call() throws Exception { - RecognitionScore score = new RecognitionScore(recognitionSystem, testTemplate.getTemplateId()); - long startTime = System.nanoTime(); - try { - List - recognize = recognitionSystem.recognize(testTemplate.getTemplateId(), testTemplate); - long endTime = System.nanoTime(); - score.setRecognitionTime(endTime - startTime); - if (recognize == null) { - score.setFailed(new NullPointerException("List of returned interpretations is null")); - recognitionScoreList.add(score); - return null; - } - generateScore(score, recognize, testTemplate.getInterpretation()); - } catch (Exception e) { - score.setFailed(e); - } - recognitionScoreList.add(score); - if (thisCount % percent == 0) { - LOG.debug("gone through {} sketches, {} left", thisCount, testTemplates.size() - thisCount); - } - return null; + taskFutures.add(executor.submit((Callable) () -> { + RecognitionScore score = recognitionFactory.createRecognitionScore(recognitionSystem, + testTemplate.getTemplateId()); + long startTime = System.nanoTime(); + final List interpretations = + testTemplate(testTemplate, recognitionSystem, score); + generateScore(score, interpretations, testTemplate.getInterpretation()); + long endTime = System.nanoTime(); + score.setRecognitionTime(endTime - startTime); + recognitionScoreList.add(score); + if (thisCount % percent == 0) { + LOG.debug("gone through {} sketches, {} left", thisCount, testTemplates.size() - thisCount); } + return null; })); counter++; } LOG.debug("Waiting for all tasks to finish"); // Waits for the executor to finish - for (Future taskFuture : taskFutures) { - try { - taskFuture.get(); - } catch (InterruptedException e) { - LOG.debug("INTERUPTIONS EXCEPTION", e); - } catch (ExecutionException e) { - LOG.debug("EXECUTION EXCEPTION", e); - } - } + waitForFutures(taskFutures); LOG.debug("All recognition testing tasks have finished"); } return scoreMap; } + private void waitForFutures(List taskFutures) { + for (Future taskFuture : taskFutures) { + try { + taskFuture.get(); + } catch (InterruptedException e) { + LOG.debug("INTERUPTIONS EXCEPTION", e); + } catch (ExecutionException e) { + LOG.debug("EXECUTION EXCEPTION", e); + } + } + } + + protected void trainSystem(Sketch.RecognitionTemplate template, RecognitionInterface recognitionSystem, + TrainingScore score) { + try { + recognitionSystem.trainTemplate(template); + } catch (Exception e) { + score.addException(new RecognitionTestException("Error with training template " + template.getTemplateId(), + e, recognitionSystem)); + } + } + public Map> trainAgainstTemplates(List templates) { Map> scoreMap = new HashMap<>(); @@ -150,63 +189,57 @@ public Map> trainAgainstTemplates(List List taskFutures = new ArrayList<>(); for (Sketch.RecognitionTemplate template : templates) { final int thisCount = counter; - taskFutures.add(executor.submit(new Callable() { - @Override - public Object call() throws Exception { - TrainingScore score = new TrainingScore(); - long startTime = System.nanoTime(); - try { - recognitionSystem.trainTemplate(template); - } catch (Exception e) { - score.addException(new RecognitionTestException("Error with training template " + template.getTemplateId(), - e, recognitionSystem)); - } - long endTime = System.nanoTime(); - score.setTrainingTime(endTime - startTime); - trainingScores.add(score); - - if (thisCount % percent == 0) { - LOG.debug("gone through {} sketches, {} left", thisCount, templates.size() - thisCount); - } - return null; + taskFutures.add(executor.submit((Callable) () -> { + TrainingScore score = trainingFactory.createTrainingScore(recognitionSystem, template.getTemplateId()); + long startTime = System.nanoTime(); + trainSystem(template, recognitionSystem, score); + long endTime = System.nanoTime(); + score.setTrainingTime(endTime - startTime); + trainingScores.add(score); + if (thisCount % percent == 0) { + LOG.debug("gone through {} sketches, {} left", thisCount, templates.size() - thisCount); } + return null; })); counter++; } LOG.debug("Waiting for all tasks to finish"); // Waits for the executor to finish - for (Future taskFuture : taskFutures) { - try { - taskFuture.get(); - } catch (InterruptedException e) { - LOG.debug("INTERUPTIONS EXCEPTION", e); - } catch (ExecutionException e) { - LOG.debug("EXECUTION EXCEPTION", e); - } - } - try { - recognitionSystem.finishTraining(); - } catch (RecognitionException e) { - LOG.debug("EXCEPTION WHEN TRAINING", e); - } + waitForFutures(taskFutures); + finishTraining(recognitionSystem); LOG.debug("All trainings tasks have finished"); } return scoreMap; } - private void generateScore(RecognitionScore score, + protected void finishTraining(RecognitionInterface recognitionSystem) { + try { + recognitionSystem.finishTraining(); + } catch (RecognitionException e) { + LOG.debug("EXCEPTION WHEN TRAINING", e); + } + } + + protected void generateScore(RecognitionScore score, List recognize, Sketch.SrlInterpretation interpretation) { + if (recognize == null) { + score.setNotRecognized(true); + score.setFailed(new NullPointerException("List of returned interpretations is null")); + return; + } double scoreValue = 1; int topGuesses = Math.min(5, recognize.size()); int subtractAmount = 1/topGuesses; score.setRecognizedInterpretations(recognize); score.setCorrectInterpretations(interpretation); for (int i = 0; i < topGuesses; i++) { - if (recognize.get(i).getLabel().equals(interpretation.getLabel())) { - score.setRecognized(true); + // We won't consider it recognized if it has no confidence in its values + if (recognize.get(i).getLabel().equals(interpretation.getLabel()) + && recognize.get(i).getConfidence() > 0) { + score.setRecognized(i); score.setScoreValue(scoreValue * recognize.get(i).getConfidence()); - return ; + return; } if (i == 0) { score.setPotentialMissRecognized(true); @@ -234,4 +267,28 @@ private List splitTrainingAndTest(List trainingScores, List recognitionScores) { + return new ScoreMetricsConverter(recognitionSystem.getClass().getSimpleName(), + trainingScores, recognitionScores); + } + } } diff --git a/src/main/java/coursesketch/recognition/test/TrainingMetric.java b/src/main/java/coursesketch/recognition/test/TrainingMetric.java deleted file mode 100644 index dd48104..0000000 --- a/src/main/java/coursesketch/recognition/test/TrainingMetric.java +++ /dev/null @@ -1,40 +0,0 @@ -package coursesketch.recognition.test; - -import java.util.List; - -/** - * Created by David Windows on 7/3/2016. - */ -public class TrainingMetric { - private int totalTemplates; - private final List exceptionList; - private final double averageTrainingTime; - - public TrainingMetric(int numTemplates, List exceptionList, double averageTrainingTime) { - totalTemplates = numTemplates; - - this.exceptionList = exceptionList; - this.averageTrainingTime = averageTrainingTime; - } - - public int getNumberOfExceptions() { - return exceptionList.size(); - } - - public List getExceptionList() { - return exceptionList; - } - - public double getAverageTrainingTime() { - return averageTrainingTime; - } - - public String toString() { - return "Metrics: " + - "\n\tTotal Number of templates:" + totalTemplates + - "\n\tNumber fo recognition exceptions: " + exceptionList.size() + - "\n\tTime:" + - "\n\t\tRecognitionTimeNanos: " + averageTrainingTime + - "\n\t\tRecognitionTimeMillis: " + (averageTrainingTime / 1000000.); - } -} diff --git a/src/main/java/coursesketch/recognition/test/RecognitionScoreMetrics.java b/src/main/java/coursesketch/recognition/test/converter/ScoreMetricsConverter.java similarity index 64% rename from src/main/java/coursesketch/recognition/test/RecognitionScoreMetrics.java rename to src/main/java/coursesketch/recognition/test/converter/ScoreMetricsConverter.java index 4ae4ae8..20d84a2 100644 --- a/src/main/java/coursesketch/recognition/test/RecognitionScoreMetrics.java +++ b/src/main/java/coursesketch/recognition/test/converter/ScoreMetricsConverter.java @@ -1,5 +1,9 @@ -package coursesketch.recognition.test; +package coursesketch.recognition.test.converter; +import coursesketch.recognition.test.metric.RecognitionMetric; +import coursesketch.recognition.test.metric.TrainingMetric; +import coursesketch.recognition.test.score.RecognitionScore; +import coursesketch.recognition.test.score.TrainingScore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -9,12 +13,12 @@ /** * Created by turnerd on 7/1/16. */ -public class RecognitionScoreMetrics { +public class ScoreMetricsConverter { /** * Declaration and Definition of Logger. */ - private static final Logger LOG = LoggerFactory.getLogger(RecognitionScoreMetrics.class); + private static final Logger LOG = LoggerFactory.getLogger(ScoreMetricsConverter.class); private final String simpleName; private List trainingScores; @@ -22,7 +26,7 @@ public class RecognitionScoreMetrics { private RecognitionMetric recognitionMetric; private TrainingMetric trainingMetric; - public RecognitionScoreMetrics(String simpleName, List trainingScores, List recognitionScores) { + public ScoreMetricsConverter(String simpleName, List trainingScores, List recognitionScores) { this.simpleName = simpleName; this.trainingScores = trainingScores; this.recognitionScores = recognitionScores; @@ -45,6 +49,7 @@ public TrainingMetric computeTrainingMetrics(List trainingScores) List trainingException = new ArrayList<>(); double numTemplates = 0; double averageTrainingTime = 0; + long maxRecognitionTime = 0; for (TrainingScore trainingScore : trainingScores) { if (trainingScore == null) { LOG.debug("RECOGNITION SCORE IS NULL"); @@ -60,17 +65,22 @@ public TrainingMetric computeTrainingMetrics(List trainingScores) LOG.debug("RECOGNITION SCORE IS NULL"); continue; } + maxRecognitionTime = Math.max(maxRecognitionTime, trainingScore.getTrainingTime()); averageTrainingTime += ((double) trainingScore.getTrainingTime()) / numTemplates; } - return new TrainingMetric((int) numTemplates, trainingException, averageTrainingTime); + TrainingMetric metric = new TrainingMetric((int) numTemplates, trainingException, averageTrainingTime); + metric.setMaxTime(maxRecognitionTime); + return metric; } public RecognitionMetric computeRecognitionMetrics(List recognitionScores) { double averageScore = 0; double averageScoreOfCorrect = 0; - double numberCorrect = 0; + double numberCorrectTop5 = 0; double numTemplates = 0; double averageRecognitionTime = 0; + long maxRecognitionTime = 0; + int numberOfTrueRecognition = 0; List nonRecognizedIds = new ArrayList<>(); List potentialMisRecognized = new ArrayList<>(); List recognitionException = new ArrayList<>(); @@ -84,7 +94,10 @@ public RecognitionMetric computeRecognitionMetrics(List recogn averageScore += recognitionScore.getScoreValue(); if (recognitionScore.isRecognized()) { averageScoreOfCorrect += recognitionScore.getScoreValue(); - numberCorrect++; + numberCorrectTop5++; + } + if (recognitionScore.isTrueRecognized()) { + numberOfTrueRecognition++; } if (recognitionScore.isNotRecognized()) { nonRecognizedIds.add(recognitionScore); @@ -102,13 +115,24 @@ public RecognitionMetric computeRecognitionMetrics(List recogn continue; } averageRecognitionTime += ((double) recognitionScore.getRecognitionTime()) / numTemplates; + maxRecognitionTime = Math.max(maxRecognitionTime, recognitionScore.getRecognitionTime()); } averageScore /= numTemplates; - averageScoreOfCorrect /= numberCorrect; + averageScoreOfCorrect /= numberCorrectTop5; LOG.debug("Finished Computing metrics"); - return new RecognitionMetric(averageScore, averageScoreOfCorrect, - (int) numberCorrect, nonRecognizedIds, potentialMisRecognized, - (int) numTemplates, recognitionException, averageRecognitionTime); + RecognitionMetric recognitionMetric = new RecognitionMetric(); + recognitionMetric.setAverageScore(averageScore); + recognitionMetric.setAverageTime(averageRecognitionTime); + recognitionMetric.setAverageScoreOfCorrect(averageScoreOfCorrect); + recognitionMetric.setNumberCorrectTop5((int) numberCorrectTop5); + recognitionMetric.setNumberTrueCorrect(numberOfTrueRecognition); + recognitionMetric.setNonRecognizedIds(nonRecognizedIds); + recognitionMetric.setPotentialMisRecognized(potentialMisRecognized); + recognitionMetric.setTotalTemplates((int) numTemplates); + recognitionMetric.setExceptionList(recognitionException); + recognitionMetric.setAverageTime(averageRecognitionTime); + recognitionMetric.setMaxTime(maxRecognitionTime); + return recognitionMetric; } public String getSimpleName() { diff --git a/src/main/java/coursesketch/recognition/test/converter/ScoreMetricsConverterFactory.java b/src/main/java/coursesketch/recognition/test/converter/ScoreMetricsConverterFactory.java new file mode 100644 index 0000000..384e862 --- /dev/null +++ b/src/main/java/coursesketch/recognition/test/converter/ScoreMetricsConverterFactory.java @@ -0,0 +1,16 @@ +package coursesketch.recognition.test.converter; + +import coursesketch.recognition.framework.RecognitionInterface; +import coursesketch.recognition.test.score.RecognitionScore; +import coursesketch.recognition.test.score.TrainingScore; + +import java.util.List; + +/** + * Created by david on 7/5/16. + */ +public interface ScoreMetricsConverterFactory { + ScoreMetricsConverter getScoreMetricsConverter(RecognitionInterface recognitionSystem, + List trainingScores, + List recognitionScores); +} diff --git a/src/main/java/coursesketch/recognition/test/metric/RecognitionMetric.java b/src/main/java/coursesketch/recognition/test/metric/RecognitionMetric.java new file mode 100644 index 0000000..bbe3d64 --- /dev/null +++ b/src/main/java/coursesketch/recognition/test/metric/RecognitionMetric.java @@ -0,0 +1,143 @@ +package coursesketch.recognition.test.metric; + +import coursesketch.recognition.test.score.RecognitionScore; + +import java.util.List; + +/** + * Created by David Windows on 7/2/2016. + */ +public class RecognitionMetric implements TestingMetric { + private double averageScore; + private double averageScoreOfCorrect; + private int numberCorrect; + private List nonRecognizedIds; + private List potentialMisRecognized; + private int totalTemplates; + private List recognitionException; + private double averageRecognitionTime; + private double precision; + private double recall; + private double fscore; + private int numberTrueCorrect; + private double maxTime; + + public RecognitionMetric() { + } + + public void calculateExtraValues() { + precision = ((double) numberTrueCorrect) / ((double) potentialMisRecognized.size()); + recall = ((double) numberTrueCorrect) / ((double) nonRecognizedIds.size()); + fscore = 2.0 * (precision * recall) / (precision + recall); + } + + public double getAverageScore() { + return averageScore; + } + + public double getAverageScoreOfCorrect() { + return averageScoreOfCorrect; + } + + public int getNumberCorrect() { + return numberCorrect; + } + + public List getNonRecognizedIds() { + return nonRecognizedIds; + } + + public List getPotentialMisRecognized() { + return potentialMisRecognized; + } + + public String toString() { + return "Recognition Metrics: " + + "\n\tScores:" + + "\n\t\tRecall:" + recall + + "\n\t\tPrecision" + precision + + "\n\t\tFScore" + fscore + + "\n\tTotal Number of templates:" + totalTemplates + + "\n\tTotal Average Score:" + averageScore + + "\n\tCorrectness:" + + "\n\t\tNumber Correct: " + numberCorrect + + "\n\t\tCorrect Percentage: " + (((double) numberCorrect) / ((double) totalTemplates)) + + "\n\tWrongness:" + + "\n\t\tNumber Incorrect " + nonRecognizedIds.size() + + "\n\t\tNonRecognized Percentage: " + (((double) nonRecognizedIds.size())/ ((double) totalTemplates)) + + "\n\tFalse Positives:" + + "\n\t\tNumber of False Positives " + potentialMisRecognized.size() + + "\n\t\tNonRecognized Percentage: " + (((double) potentialMisRecognized.size())/ ((double) totalTemplates)) + + "\n\tNumber fo recognition exceptions: " + recognitionException.size() + + "\n\tPerformance:" + + "\n\t\tRecognitionTimeNanos: " + averageRecognitionTime + + "\n\t\tRecognitionTimeMillis: " + (averageRecognitionTime / 1000000.) + + "\n\t\tMax Recognition time: " + maxTime; + } + + public int getTotalTemplates() { + return totalTemplates; + } + + public void setAverageScore(double averageScore) { + this.averageScore = averageScore; + } + + public void setAverageScoreOfCorrect(double averageScoreOfCorrect) { + this.averageScoreOfCorrect = averageScoreOfCorrect; + } + + public void setNumberCorrectTop5(int numberCorrect) { + this.numberCorrect = numberCorrect; + } + + public void setNonRecognizedIds(List nonRecognizedIds) { + this.nonRecognizedIds = nonRecognizedIds; + } + + public void setPotentialMisRecognized(List potentialMisRecognized) { + this.potentialMisRecognized = potentialMisRecognized; + } + + public void setTotalTemplates(int totalTemplates) { + this.totalTemplates = totalTemplates; + } + + public void setExceptionList(List recognitionException) { + this.recognitionException = recognitionException; + } + + @Override + public void setAverageTime(double averageRecognitionTime) { + this.averageRecognitionTime = averageRecognitionTime; + } + + public void setNumberTrueCorrect(int numberTrueCorrect) { + this.numberTrueCorrect = numberTrueCorrect; + } + + @Override + public int getNumberOfExceptions() { + return recognitionException.size(); + } + + @Override + public List getExceptionList() { + return recognitionException; + } + + @Override + public double getAverageTime() { + return averageRecognitionTime; + } + + @Override + public double getMaxTime() { + return maxTime; + } + + @Override + public void setMaxTime(double time) { + this.maxTime = time; + } +} diff --git a/src/main/java/coursesketch/recognition/test/metric/TestingMetric.java b/src/main/java/coursesketch/recognition/test/metric/TestingMetric.java new file mode 100644 index 0000000..7d1d3e4 --- /dev/null +++ b/src/main/java/coursesketch/recognition/test/metric/TestingMetric.java @@ -0,0 +1,24 @@ +package coursesketch.recognition.test.metric; + +import java.util.List; + +/** + * Created by david on 7/4/16. + */ +public interface TestingMetric { + void setAverageTime(double time); + int getNumberOfExceptions(); + + void setExceptionList(List exceptions); + List getExceptionList(); + + double getAverageTime(); + + /** + * Returns the single largest value in metric + * @return + */ + double getMaxTime(); + + void setMaxTime(double time); +} diff --git a/src/main/java/coursesketch/recognition/test/metric/TrainingMetric.java b/src/main/java/coursesketch/recognition/test/metric/TrainingMetric.java new file mode 100644 index 0000000..ebcacf8 --- /dev/null +++ b/src/main/java/coursesketch/recognition/test/metric/TrainingMetric.java @@ -0,0 +1,62 @@ +package coursesketch.recognition.test.metric; + +import java.util.List; + +/** + * Created by David Windows on 7/3/2016. + */ +public class TrainingMetric implements TestingMetric { + private int totalTemplates; + private List exceptionList; + private double averageTrainingTime; + private double maxTime; + + public TrainingMetric(int numTemplates, List exceptionList, double averageTrainingTime) { + totalTemplates = numTemplates; + + this.exceptionList = exceptionList; + this.averageTrainingTime = averageTrainingTime; + } + + @Override + public void setAverageTime(double time) { + averageTrainingTime = time; + } + + public int getNumberOfExceptions() { + return exceptionList.size(); + } + + @Override + public void setExceptionList(List exceptions) { + exceptionList = exceptions; + } + + public List getExceptionList() { + return exceptionList; + } + + @Override + public double getAverageTime() { + return averageTrainingTime; + } + + @Override + public double getMaxTime() { + return maxTime; + } + + @Override + public void setMaxTime(double time) { + this.maxTime = time; + } + + public String toString() { + return "Training Metrics: " + + "\n\tTotal Number of templates:" + totalTemplates + + "\n\tNumber of training exceptions: " + exceptionList.size() + + "\n\tTime:" + + "\n\t\tTrainingTimeNanos: " + averageTrainingTime + + "\n\t\tTrainingTimeMillis: " + (averageTrainingTime / 1000000.); + } +} diff --git a/src/main/java/coursesketch/recognition/test/RecognitionScore.java b/src/main/java/coursesketch/recognition/test/score/RecognitionScore.java similarity index 83% rename from src/main/java/coursesketch/recognition/test/RecognitionScore.java rename to src/main/java/coursesketch/recognition/test/score/RecognitionScore.java index a786ab7..23d7793 100644 --- a/src/main/java/coursesketch/recognition/test/RecognitionScore.java +++ b/src/main/java/coursesketch/recognition/test/score/RecognitionScore.java @@ -1,4 +1,4 @@ -package coursesketch.recognition.test; +package coursesketch.recognition.test.score; import coursesketch.recognition.framework.RecognitionInterface; import coursesketch.recognition.framework.exceptions.RecognitionException; @@ -16,7 +16,7 @@ public class RecognitionScore { private final RecognitionInterface recognitionSystem; private String templateId; private Exception exception; - private boolean recognized; + private int recognizedIndex = -1; private double scoreValue; private boolean potentialMissRecognized; private boolean notRecognized; @@ -36,8 +36,8 @@ public void setFailed(Exception exception) { this.exception = exception; } - public void setRecognized(boolean recognized) { - this.recognized = recognized; + public void setRecognized(int recognizedIndex) { + this.recognizedIndex = recognizedIndex; } public void setScoreValue(double scoreValue) { @@ -65,7 +65,22 @@ public boolean isNotRecognized() { } public boolean isRecognized() { - return recognized; + return recognizedIndex > 0; + } + + public boolean isTrueRecognized() { + return recognizedIndex == 0; + } + + /** + * The order at which it was recognized. + * + * 0 being the best option and the larger the number the worse it is + * -1 is not recognized + * @return + */ + public int getRecognizedIndex() { + return recognizedIndex; } public Exception getException() { diff --git a/src/main/java/coursesketch/recognition/test/score/RecognitionScoreFactory.java b/src/main/java/coursesketch/recognition/test/score/RecognitionScoreFactory.java new file mode 100644 index 0000000..cdc29ba --- /dev/null +++ b/src/main/java/coursesketch/recognition/test/score/RecognitionScoreFactory.java @@ -0,0 +1,10 @@ +package coursesketch.recognition.test.score; + +import coursesketch.recognition.framework.RecognitionInterface; + +/** + * Created by david on 7/5/16. + */ +public interface RecognitionScoreFactory { + public RecognitionScore createRecognitionScore(RecognitionInterface recognitionSystem, String templateId); +} diff --git a/src/main/java/coursesketch/recognition/test/TrainingScore.java b/src/main/java/coursesketch/recognition/test/score/TrainingScore.java similarity index 54% rename from src/main/java/coursesketch/recognition/test/TrainingScore.java rename to src/main/java/coursesketch/recognition/test/score/TrainingScore.java index ff6c66d..cae0bf5 100644 --- a/src/main/java/coursesketch/recognition/test/TrainingScore.java +++ b/src/main/java/coursesketch/recognition/test/score/TrainingScore.java @@ -1,11 +1,21 @@ -package coursesketch.recognition.test; +package coursesketch.recognition.test.score; + +import coursesketch.recognition.framework.RecognitionInterface; +import coursesketch.recognition.test.RecognitionTestException; /** * Created by David Windows on 7/3/2016. */ public class TrainingScore { + private final RecognitionInterface recognitionSystem; + private final String templateId; private RecognitionTestException exception; private long trainingTime; + public TrainingScore(RecognitionInterface recognitionSystem, String templateId) { + + this.recognitionSystem = recognitionSystem; + this.templateId = templateId; + } public void addException(RecognitionTestException e) { diff --git a/src/main/java/coursesketch/recognition/test/score/TrainingScoreFactory.java b/src/main/java/coursesketch/recognition/test/score/TrainingScoreFactory.java new file mode 100644 index 0000000..b3b16ad --- /dev/null +++ b/src/main/java/coursesketch/recognition/test/score/TrainingScoreFactory.java @@ -0,0 +1,10 @@ +package coursesketch.recognition.test.score; + +import coursesketch.recognition.framework.RecognitionInterface; + +/** + * Created by david on 7/5/16. + */ +public interface TrainingScoreFactory { + public TrainingScore createTrainingScore(RecognitionInterface recognitionSystem, String templateId); +}