Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf improvement for TopK Accuracy and return all topK in Classification Evaluator #5395

Merged
merged 29 commits into master from jasallenbranch
Dec 9, 2020
Merged
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5fbf740
Fix for issue 744
jasallen Sep 8, 2020
1747d3e
cleanup
jasallen Sep 9, 2020
32c244a
fixing report output
jasallen Sep 12, 2020
968b58d
fixedTestReferenceOutputs
jasallen Sep 12, 2020
b7ded43
Fixed test reference outputs for NetCore31
jasallen Sep 12, 2020
685eeb4
change top k acc output string format
jasallen Nov 5, 2020
1eacec7
Ranking algorithm now uses first appearance in dataset rather than wo…
jasallen Nov 6, 2020
ea057ff
fixed benchmark
jasallen Nov 6, 2020
ac08554
various minor changes from code review
jasallen Nov 6, 2020
f0de3ea
limit TopK to OutputTopKAcc parameter
jasallen Nov 6, 2020
30fbd6f
top k output name changes
jasallen Nov 6, 2020
495b4b0
make old TopK readOnly
jasallen Nov 6, 2020
c3afe15
restored old baselineOutputs since respecting outputTopK param means …
jasallen Nov 6, 2020
bfcda22
fix test fails, re-add names parameter
jasallen Nov 6, 2020
563768c
Clean up commented code
jasallen Nov 6, 2020
4a5597a
that'll teach me to edit from the github webpage
jasallen Nov 6, 2020
71390bd
use existing method, fix nits
jasallen Nov 19, 2020
32ab9fa
Slight comment change
jasallen Nov 20, 2020
db2b6b5
Comment change / Touch to kick off build pipeline
jasallen Nov 21, 2020
0d0493b
fix whitespace
jasallen Nov 23, 2020
e6aec98
Merge branch 'master' into jasallenbranch
antoniovs1029 Dec 3, 2020
05e7f91
Added new test
antoniovs1029 Dec 4, 2020
49786ed
Code formatting nits
justinormont Dec 8, 2020
9259031
Code formatting nit
justinormont Dec 8, 2020
98458ba
Fixed undefined rankofCorrectLabel and trailing whitespace warning
antoniovs1029 Dec 8, 2020
86f5c3f
Removed _numUnknownClassInstances and added test for unknown labels
antoniovs1029 Dec 8, 2020
741e9fb
Add weight to seenRanks
antoniovs1029 Dec 8, 2020
dadf793
Nits
antoniovs1029 Dec 9, 2020
9e67751
Removed FastTree import
antoniovs1029 Dec 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix for issue 744
jasallen committed Sep 12, 2020

Unverified

This user has not yet uploaded their public signing key.
commit 5fbf7404aa01e967d7a1fcd73b237214c05aac8c
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ public static void Example()
// Create testing data. Use different random seed to make it different
// from training data.
var testData = mlContext.Data
.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));
.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123, greatestLabel: 4));

// Run the model on test data set.
var transformedTestData = model.Transform(testData);
@@ -57,7 +57,7 @@ public static void Example()
reuseRowObject: false).ToList();

// Look at 5 predictions
foreach (var p in predictions.Take(5))
foreach (var p in predictions.Take(100))
Console.WriteLine($"Label: {p.Label}, " +
$"Prediction: {p.PredictedLabel}");

@@ -70,7 +70,7 @@ public static void Example()

// Evaluate the overall metrics
var metrics = mlContext.MulticlassClassification
.Evaluate(transformedTestData);
.Evaluate(transformedTestData, topKPredictionCount:3);

PrintMetrics(metrics);

@@ -93,18 +93,18 @@ public static void Example()
// Generates random uniform doubles in [-0.5, 0.5)
// range with labels 1, 2 or 3.
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
int seed=0)
int seed=0, int greatestLabel = 4)

{
var random = new Random(seed);
float randomFloat() => (float)(random.NextDouble() - 0.5);
for (int i = 0; i < count; i++)
{
// Generate Labels that are integers 1, 2 or 3
var label = random.Next(1, 4);
var label = random.Next(1, greatestLabel);
yield return new DataPoint
{
Label = (uint)label,
Label = label.ToString() + "@",
// Create random features that are correlated with the label.
// The feature values are slightly increased by adding a
// constant multiple of label.
@@ -119,7 +119,7 @@ private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
// such examples.
private class DataPoint
{
public uint Label { get; set; }
public string Label { get; set; }
[VectorType(20)]
public float[] Features { get; set; }
}
@@ -142,6 +142,11 @@ public static void PrintMetrics(MulticlassClassificationMetrics metrics)
Console.WriteLine(
$"Log Loss Reduction: {metrics.LogLossReduction:F2}\n");

for (int k=0; k < metrics.TopKAccuracyForAllK.Count(); k++)
{
Console.WriteLine($"Top {k} Accuracy: {metrics.TopKAccuracyForAllK[k]:F2}");
}

Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
}
}
34 changes: 19 additions & 15 deletions docs/samples/Microsoft.ML.Samples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
using System;
using System.Reflection;
using Samples.Dynamic;
using Samples.Dynamic.Trainers.MulticlassClassification;

namespace Microsoft.ML.Samples
{
public static class Program
{
public static void Main(string[] args) => RunAll(args == null || args.Length == 0 ? null : args[0]);


internal static void RunAll(string name = null)
{
int samples = 0;
foreach (var type in Assembly.GetExecutingAssembly().GetTypes())
{
if (name == null || name.Equals(type.Name))
{
var sample = type.GetMethod("Example", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy);
//int samples = 0;
//foreach (var type in Assembly.GetExecutingAssembly().GetTypes())
//{
// if (name == null || name.Equals(type.Name))
// {
// var sample = type.GetMethod("Example", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy);

// if (sample != null)
// {
// Console.WriteLine(type.Name);
// sample.Invoke(null, null);
// samples++;
// }
// }
//}

if (sample != null)
{
Console.WriteLine(type.Name);
sample.Invoke(null, null);
samples++;
}
}
}
SdcaMaximumEntropy.Example();

Console.WriteLine("Number of samples that ran without any exception: " + samples);
Console.WriteLine("Number of samples that ran without any exception: ");
}
}
}
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.Data
@@ -81,6 +82,11 @@ public sealed class MulticlassClassificationMetrics
/// </summary>
public int TopKPredictionCount { get; }
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Returns the top K for all K from 1 to the number of classes
/// </summary>
public IReadOnlyList<double> TopKAccuracyForAllK { get; }

/// <summary>
/// Gets the log-loss of the classifier for each class. Log-loss measures the performance of a classifier
/// with respect to how much the predicted probabilities diverge from the true class label. Lower
@@ -114,9 +120,10 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult,
MacroAccuracy = FetchDouble(MulticlassClassificationEvaluator.AccuracyMacro);
LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss);
LogLossReduction = FetchDouble(MulticlassClassificationEvaluator.LogLossReduction);
TopKAccuracyForAllK = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.AllTopKAccuracy).DenseValues().ToImmutableArray();
jasallen marked this conversation as resolved.
Show resolved Hide resolved
TopKPredictionCount = topKPredictionCount;
if (topKPredictionCount > 0)
TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy);
TopKAccuracy = TopKAccuracyForAllK[topKPredictionCount-1];

var perClassLogLoss = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss);
PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray();
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ public sealed class Arguments
public const string AccuracyMicro = "Accuracy(micro-avg)";
public const string AccuracyMacro = "Accuracy(macro-avg)";
public const string TopKAccuracy = "Top K accuracy";
public const string AllTopKAccuracy = "Top K accuracy(All K)";
public const string PerClassLogLoss = "Per class log-loss";
public const string LogLoss = "Log-loss";
public const string LogLossReduction = "Log-loss reduction";
@@ -60,15 +61,13 @@ public enum Metrics
internal const string LoadName = "MultiClassClassifierEvaluator";

private readonly int? _outputTopKAcc;
private readonly bool _names;
jasallen marked this conversation as resolved.
Show resolved Hide resolved

public MulticlassClassificationEvaluator(IHostEnvironment env, Arguments args)
: base(env, LoadName)
{
Host.AssertValue(args, "args");
Host.CheckUserArg(args.OutputTopKAcc == null || args.OutputTopKAcc > 0, nameof(args.OutputTopKAcc));
_outputTopKAcc = args.OutputTopKAcc;
_names = args.Names;
}

private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
@@ -147,6 +146,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
var logLoss = new List<double>();
var logLossRed = new List<double>();
var topKAcc = new List<double>();
var allTopK = new List<double[]>();
var perClassLogLoss = new List<double[]>();
var counts = new List<double[]>();
var weights = new List<double[]>();
@@ -172,6 +172,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
logLossRed.Add(agg.UnweightedCounters.Reduction);
if (agg.UnweightedCounters.OutputTopKAcc > 0)
topKAcc.Add(agg.UnweightedCounters.TopKAccuracy);
allTopK.Add(agg.UnweightedCounters.AllTopKAccuracy);
jasallen marked this conversation as resolved.
Show resolved Hide resolved
perClassLogLoss.Add(agg.UnweightedCounters.PerClassLogLoss);

confStratCol.AddRange(agg.UnweightedCounters.ConfusionTable.Select(x => stratColKey));
@@ -189,6 +190,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
logLossRed.Add(agg.WeightedCounters.Reduction);
if (agg.WeightedCounters.OutputTopKAcc > 0)
topKAcc.Add(agg.WeightedCounters.TopKAccuracy);
allTopK.Add(agg.WeightedCounters.AllTopKAccuracy);
perClassLogLoss.Add(agg.WeightedCounters.PerClassLogLoss);
weights.AddRange(agg.WeightedCounters.ConfusionTable);
}
@@ -211,6 +213,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
overallDvBldr.AddColumn(LogLossReduction, NumberDataViewType.Double, logLossRed.ToArray());
if (aggregator.UnweightedCounters.OutputTopKAcc > 0)
overallDvBldr.AddColumn(TopKAccuracy, NumberDataViewType.Double, topKAcc.ToArray());
overallDvBldr.AddColumn(AllTopKAccuracy, NumberDataViewType.Double, allTopK.ToArray());
overallDvBldr.AddColumn(PerClassLogLoss, aggregator.GetSlotNames, NumberDataViewType.Double, perClassLogLoss.ToArray());

var confDvBldr = new ArrayDataViewBuilder(Host);
@@ -246,9 +249,11 @@ public sealed class Counters
private double _totalLogLoss;
private double _numInstances;
private double _numCorrect;
private double _numCorrectTopK;
private int _numUnknownClassInstances;
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
private readonly double[] _sumWeightsOfClass;
private readonly double[] _totalPerClassLogLoss;
private readonly long[] _seenRanks;

public readonly double[][] ConfusionTable;

public double MicroAvgAccuracy { get { return _numInstances > 0 ? _numCorrect / _numInstances : 0; } }
@@ -291,7 +296,8 @@ public double Reduction
}
}

public double TopKAccuracy { get { return _numInstances > 0 ? _numCorrectTopK / _numInstances : 0; } }
public double TopKAccuracy => !(OutputTopKAcc is null) ? AllTopKAccuracy[OutputTopKAcc.Value] : 0d;
public double[] AllTopKAccuracy => CumulativeSum(_seenRanks.Select(l => l / (double)(_numInstances - _numUnknownClassInstances))).ToArray();

// The per class average log loss is calculated by dividing the weighted sum of the log loss of examples
// in each class by the total weight of examples in that class.
@@ -316,14 +322,12 @@ public Counters(int numClasses, int? outputTopKAcc)
ConfusionTable = new double[numClasses][];
for (int i = 0; i < ConfusionTable.Length; i++)
ConfusionTable[i] = new double[numClasses];

_seenRanks = new long[numClasses + 1];
}

public void Update(int[] indices, double loglossCurr, int label, float weight)
public void Update(int seenRank, int assigned, double loglossCurr, int label, float weight)
{
Contracts.Assert(Utils.Size(indices) == _numClasses);

int assigned = indices[0];

_numInstances += weight;

if (label < _numClasses)
@@ -334,23 +338,34 @@ public void Update(int[] indices, double loglossCurr, int label, float weight)
if (label < _numClasses)
_totalPerClassLogLoss[label] += loglossCurr * weight;

if (assigned == label)
_seenRanks[seenRank]++;
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved

if (seenRank == 0) //prediction matched label
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
{
_numCorrect += weight;
ConfusionTable[label][label] += weight;
_numCorrectTopK += weight;
}
else if (label < _numClasses)
{
if (OutputTopKAcc > 0)
{
int idx = Array.IndexOf(indices, label);
if (0 <= idx && idx < OutputTopKAcc)
_numCorrectTopK += weight;
}
ConfusionTable[label][assigned] += weight;
}
else
{
_numUnknownClassInstances++;
}
}

private static IEnumerable<double> CumulativeSum(IEnumerable<double> s)
{
double sum = 0;
;
jasallen marked this conversation as resolved.
Show resolved Hide resolved
foreach (var x in s)
{
sum += x;
yield return sum;
}
}

}

private ValueGetter<float> _labelGetter;
@@ -359,7 +374,6 @@ public void Update(int[] indices, double loglossCurr, int label, float weight)

private VBuffer<float> _scores;
private readonly float[] _scoresArr;
private int[] _indicesArr;

private const float Epsilon = (float)1e-15;

@@ -380,6 +394,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory<char>[] classNames, int s
Host.Assert(Utils.Size(classNames) == scoreVectorSize);

_scoresArr = new float[scoreVectorSize];

UnweightedCounters = new Counters(scoreVectorSize, outputTopKAcc);
Weighted = weighted;
WeightedCounters = Weighted ? new Counters(scoreVectorSize, outputTopKAcc) : null;
@@ -400,6 +415,7 @@ internal override void InitializeNextPass(DataViewRow row, RoleMappedSchema sche

if (schema.Weight.HasValue)
_weightGetter = row.GetGetter<float>(schema.Weight.Value);

}

public override void ProcessRow()
@@ -437,16 +453,12 @@ public override void ProcessRow()
}
}

// Sort classes by prediction strength.
// Use stable OrderBy instead of Sort(), which may give different results on different machines.
if (Utils.Size(_indicesArr) < _scoresArr.Length)
_indicesArr = new int[_scoresArr.Length];
int j = 0;
foreach (var index in Enumerable.Range(0, _scoresArr.Length).OrderByDescending(i => _scoresArr[i]))
_indicesArr[j++] = index;

var intLabel = (int)label;

var assigned = Array.IndexOf(_scoresArr, _scoresArr.Max()); //perf could be improved

var wasKnownLabel = true;

// log-loss
double logloss;
if (intLabel < _scoresArr.Length)
@@ -461,11 +473,21 @@ public override void ProcessRow()
// Penalize logloss if the label was not seen during training
logloss = -Math.Log(Epsilon);
_numUnknownClassInstances++;
wasKnownLabel = false;
}

UnweightedCounters.Update(_indicesArr, logloss, intLabel, 1);
// Get the probability that the CORRECT label has: (best case is that it's the highest probability):
var correctProba = !wasKnownLabel ? 0 : _scoresArr[intLabel];

// Find the rank of the *correct* label (in Scores[]). If 0 => Good, correct. And the lower the better.
jasallen marked this conversation as resolved.
Show resolved Hide resolved
// The rank will be from 0 to N. (Not N-1).
jasallen marked this conversation as resolved.
Show resolved Hide resolved
// Problem: What if we have probabilities that are equal to the correct prediction (eg, .6 .1 .1 .1 .1).
// This actually happens a lot with some models. Here we assign the worst rank in the case of a tie (so 4 in this example)
var correctRankWorstCase = !wasKnownLabel ? _scoresArr.Length : _scoresArr.Count(score => score >= correctProba) - 1;

UnweightedCounters.Update(correctRankWorstCase, assigned, logloss, intLabel, 1);
if (WeightedCounters != null)
WeightedCounters.Update(_indicesArr, logloss, intLabel, weight);
WeightedCounters.Update(correctRankWorstCase, assigned, logloss, intLabel, weight);
}

protected override List<string> GetWarningsCore()
@@ -909,6 +931,7 @@ private protected override IDataView CombineOverallMetricsCore(IDataView[] metri
for (int i = 0; i < metrics.Length; i++)
{
var idv = metrics[i];
idv = DropAllTopKColumn(idv);
if (!_outputPerClass)
idv = DropPerClassColumn(idv);

@@ -964,6 +987,15 @@ private IDataView DropPerClassColumn(IDataView input)
return input;
}

private IDataView DropAllTopKColumn(IDataView input)
jasallen marked this conversation as resolved.
Show resolved Hide resolved
{
if (input.Schema.TryGetColumnIndex(MulticlassClassificationEvaluator.AllTopKAccuracy, out int AllTopKCol))
{
input = ColumnSelectingTransformer.CreateDrop(Host, input, MulticlassClassificationEvaluator.AllTopKAccuracy);
}
return input;
}

public override IEnumerable<MetricColumn> GetOverallMetricColumns()
{
yield return new MetricColumn("AccuracyMicro", MulticlassClassificationEvaluator.AccuracyMicro);
Loading