Skip to content

Commit

Permalink
Console helper bug in generated code for multiclass (dotnet#323)
Browse files Browse the repository at this point in the history
* fix

* fix test

* looping perlogclass

* fix test
  • Loading branch information
srsaggam authored and Dmitry-A committed Aug 22, 2019
1 parent 14f9d17 commit 6162944
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ namespace TestNamespace.Train
Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
for (int i = 0; i < metrics.PerClassLogLoss.Length; i++)
{
Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
}
Console.WriteLine($"************************************************************");
}

Expand Down
104 changes: 51 additions & 53 deletions src/mlnet/Templates/Console/ConsoleHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,61 +104,59 @@ namespace ");
"ole.WriteLine($\" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value betw" +
"een 0 and 1, the closer to 1, the better\");\r\n Console.WriteLine($\" " +
" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better\");\r\n " +
" Console.WriteLine($\" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.###" +
"#}, the closer to 0, the better\");\r\n Console.WriteLine($\" LogLoss " +
"for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better\")" +
";\r\n Console.WriteLine($\" LogLoss for class 3 = {metrics.PerClassLo" +
"gLoss[2]:0.####}, the closer to 0, the better\");\r\n Console.WriteLine(" +
"$\"************************************************************\");\r\n }\r\n\r\n" +
" public static void PrintMulticlassClassificationFoldsAverageMetrics(Trai" +
"nCatalogBase.CrossValidationResult<MultiClassClassifierMetrics>[] crossValResult" +
"s)\r\n {\r\n var metricsInMultipleFolds = crossValResults.Select(r" +
" => r.Metrics);\r\n\r\n var microAccuracyValues = metricsInMultipleFolds." +
"Select(m => m.AccuracyMicro);\r\n var microAccuracyAverage = microAccur" +
"acyValues.Average();\r\n var microAccuraciesStdDeviation = CalculateSta" +
"ndardDeviation(microAccuracyValues);\r\n var microAccuraciesConfidenceI" +
"nterval95 = CalculateConfidenceInterval95(microAccuracyValues);\r\n\r\n v" +
"ar macroAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMacro);\r\n " +
" var macroAccuracyAverage = macroAccuracyValues.Average();\r\n " +
" var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValu" +
"es);\r\n var macroAccuraciesConfidenceInterval95 = CalculateConfidenceI" +
"nterval95(macroAccuracyValues);\r\n\r\n var logLossValues = metricsInMult" +
"ipleFolds.Select(m => m.LogLoss);\r\n var logLossAverage = logLossValue" +
"s.Average();\r\n var logLossStdDeviation = CalculateStandardDeviation(l" +
"ogLossValues);\r\n var logLossConfidenceInterval95 = CalculateConfidenc" +
"eInterval95(logLossValues);\r\n\r\n var logLossReductionValues = metricsI" +
"nMultipleFolds.Select(m => m.LogLossReduction);\r\n var logLossReductio" +
"nAverage = logLossReductionValues.Average();\r\n var logLossReductionSt" +
"dDeviation = CalculateStandardDeviation(logLossReductionValues);\r\n va" +
"r logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossRe" +
"ductionValues);\r\n\r\n Console.WriteLine($\"*****************************" +
" for (int i = 0; i < metrics.PerClassLogLoss.Length; i++)\r\n {\r\n " +
" Console.WriteLine($\" LogLoss for class {i + 1} = {metrics.PerClassL" +
"ogLoss[i]:0.####}, the closer to 0, the better\");\r\n }\r\n Co" +
"nsole.WriteLine($\"************************************************************\")" +
";\r\n }\r\n\r\n public static void PrintMulticlassClassificationFoldsAve" +
"rageMetrics(TrainCatalogBase.CrossValidationResult<MultiClassClassifierMetrics>[" +
"] crossValResults)\r\n {\r\n var metricsInMultipleFolds = crossVal" +
"Results.Select(r => r.Metrics);\r\n\r\n var microAccuracyValues = metrics" +
"InMultipleFolds.Select(m => m.AccuracyMicro);\r\n var microAccuracyAver" +
"age = microAccuracyValues.Average();\r\n var microAccuraciesStdDeviatio" +
"n = CalculateStandardDeviation(microAccuracyValues);\r\n var microAccur" +
"aciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);\r" +
"\n\r\n var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.Ac" +
"curacyMacro);\r\n var macroAccuracyAverage = macroAccuracyValues.Averag" +
"e();\r\n var macroAccuraciesStdDeviation = CalculateStandardDeviation(m" +
"acroAccuracyValues);\r\n var macroAccuraciesConfidenceInterval95 = Calc" +
"ulateConfidenceInterval95(macroAccuracyValues);\r\n\r\n var logLossValues" +
" = metricsInMultipleFolds.Select(m => m.LogLoss);\r\n var logLossAverag" +
"e = logLossValues.Average();\r\n var logLossStdDeviation = CalculateSta" +
"ndardDeviation(logLossValues);\r\n var logLossConfidenceInterval95 = Ca" +
"lculateConfidenceInterval95(logLossValues);\r\n\r\n var logLossReductionV" +
"alues = metricsInMultipleFolds.Select(m => m.LogLossReduction);\r\n var" +
" logLossReductionAverage = logLossReductionValues.Average();\r\n var lo" +
"gLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);" +
"\r\n var logLossReductionConfidenceInterval95 = CalculateConfidenceInte" +
"rval95(logLossReductionValues);\r\n\r\n Console.WriteLine($\"*************" +
"********************************************************************************" +
"\");\r\n Console.WriteLine($\"* Metrics for Multi-class Classificat" +
"ion model \");\r\n Console.WriteLine($\"*---------------------------" +
"****************\");\r\n Console.WriteLine($\"* Metrics for Multi-c" +
"lass Classification model \");\r\n Console.WriteLine($\"*-----------" +
"--------------------------------------------------------------------------------" +
"-\");\r\n Console.WriteLine($\"* Average MicroAccuracy: {microAc" +
"curacyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}" +
") - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})\");\r\n" +
" Console.WriteLine($\"* Average MacroAccuracy: {macroAccuracy" +
"Average:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - C" +
"onfidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})\");\r\n " +
" Console.WriteLine($\"* Average LogLoss: {logLossAverage:#.##" +
"#} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 9" +
"5%: ({logLossConfidenceInterval95:#.###})\");\r\n Console.WriteLine($\"* " +
" Average LogLossReduction: {logLossReductionAverage:#.###} - Standard devi" +
"ation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logL" +
"ossReductionConfidenceInterval95:#.###})\");\r\n Console.WriteLine($\"***" +
"********************************************************************************" +
"**************************\");\r\n\r\n }\r\n\r\n public static double Calcu" +
"lateStandardDeviation(IEnumerable<double> values)\r\n {\r\n double" +
" average = values.Average();\r\n double sumOfSquaresOfDifferences = val" +
"ues.Select(val => (val - average) * (val - average)).Sum();\r\n double " +
"standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1));" +
"\r\n return standardDeviation;\r\n }\r\n\r\n public static doub" +
"le CalculateConfidenceInterval95(IEnumerable<double> values)\r\n {\r\n " +
" double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / M" +
"ath.Sqrt((values.Count() - 1));\r\n return confidenceInterval95;\r\n " +
" }\r\n }\r\n}\r\n");
"-----------------\");\r\n Console.WriteLine($\"* Average MicroAccur" +
"acy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStd" +
"Deviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterva" +
"l95:#.###})\");\r\n Console.WriteLine($\"* Average MacroAccuracy: " +
" {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviat" +
"ion:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#." +
"###})\");\r\n Console.WriteLine($\"* Average LogLoss: {log" +
"LossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confi" +
"dence Interval 95%: ({logLossConfidenceInterval95:#.###})\");\r\n Consol" +
"e.WriteLine($\"* Average LogLossReduction: {logLossReductionAverage:#.###} " +
" - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Inte" +
"rval 95%: ({logLossReductionConfidenceInterval95:#.###})\");\r\n Console" +
".WriteLine($\"*******************************************************************" +
"******************************************\");\r\n\r\n }\r\n\r\n public sta" +
"tic double CalculateStandardDeviation(IEnumerable<double> values)\r\n {\r\n " +
" double average = values.Average();\r\n double sumOfSquaresOfD" +
"ifferences = values.Select(val => (val - average) * (val - average)).Sum();\r\n " +
" double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (value" +
"s.Count() - 1));\r\n return standardDeviation;\r\n }\r\n\r\n pu" +
"blic static double CalculateConfidenceInterval95(IEnumerable<double> values)\r\n " +
" {\r\n double confidenceInterval95 = 1.96 * CalculateStandardDevia" +
"tion(values) / Math.Sqrt((values.Count() - 1));\r\n return confidenceIn" +
"terval95;\r\n }\r\n }\r\n}\r\n");
return this.GenerationEnvironment.ToString();
}

Expand Down
7 changes: 4 additions & 3 deletions src/mlnet/Templates/Console/ConsoleHelper.tt
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ namespace <#= Namespace #>.Train
Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
for (int i = 0; i < metrics.PerClassLogLoss.Length; i++)
{
Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
}
Console.WriteLine($"************************************************************");
}

Expand Down

0 comments on commit 6162944

Please sign in to comment.