Skip to content

Commit

Permalink
Fix CV macro to output the warnings data view properly. (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaeldMS authored Jun 20, 2018
1 parent e5de547 commit ead943e
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ public Double MacroAvgAccuracy
{
get
{
if (_numInstances == 0)
return 0;
Double macroAvgAccuracy = 0;
int countOfNonEmptyClasses = 0;
for (int i = 0; i < _numClasses; ++i)
Expand All @@ -267,8 +269,7 @@ public Double MacroAvgAccuracy
}
}

Contracts.Assert(countOfNonEmptyClasses > 0);
return macroAvgAccuracy / countOfNonEmptyClasses;
return countOfNonEmptyClasses > 0 ? macroAvgAccuracy / countOfNonEmptyClasses : 0;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
"OLS Linear Regression Executor",
OlsLinearRegressionPredictor.LoaderSignature)]

[assembly: LoadableClass(typeof(void), typeof(OlsLinearRegressionTrainer), null, typeof(SignatureEntryPointModule), OlsLinearRegressionTrainer.LoadNameValue)]

namespace Microsoft.ML.Runtime.Learners
{
public sealed class OlsLinearRegressionTrainer : TrainerBase<RoleMappedData, OlsLinearRegressionPredictor>
Expand Down
4 changes: 4 additions & 0 deletions src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
// Set the input bindings for the CombineMetrics entry point.
var combineInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
var combineInputMap = new Dictionary<ParameterBinding, VariableBinding>();

var warningsArray = new SimpleParameterBinding(nameof(combineArgs.Warnings));
combineInputBindingMap.Add(nameof(combineArgs.Warnings), new List<ParameterBinding> { warningsArray });
combineInputMap.Add(warningsArray, new SimpleVariableBinding(warningsOutput.OutputData.VarName));
var overallArray = new SimpleParameterBinding(nameof(combineArgs.OverallMetrics));
combineInputBindingMap.Add(nameof(combineArgs.OverallMetrics), new List<ParameterBinding> { overallArray });
combineInputMap.Add(overallArray, new SimpleVariableBinding(overallMetricsOutput.OutputData.VarName));
Expand Down
83 changes: 83 additions & 0 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,89 @@ public void TestCrossValidationMacroWithMultiClass()
}
Assert.Equal(0, rowCount);
}

var warnings = experiment.GetOutput(crossValidateOutput.Warnings);
using (var cursor = warnings.GetRowCursor(col => true))
Assert.False(cursor.MoveNext());
}
}

[Fact]
public void TestCrossValidationMacroMultiClassWithWarnings()
{
var dataPath = GetDataPath(@"Train-Tiny-28x28.txt");
using (var env = new TlcEnvironment(42))
{
var subGraph = env.CreateExperiment();

var nop = new ML.Transforms.NoOperation();
var nopOutput = subGraph.Add(nop);

var learnerInput = new ML.Trainers.LogisticRegressionClassifier
{
TrainingData = nopOutput.OutputData,
NumThreads = 1
};
var learnerOutput = subGraph.Add(learnerInput);

var experiment = env.CreateExperiment();
var importInput = new ML.Data.TextLoader(dataPath);
var importOutput = experiment.Add(importInput);

var filter = new ML.Transforms.RowRangeFilter();
filter.Data = importOutput.Data;
filter.Column = "Label";
filter.Min = 0;
filter.Max = 5;
var filterOutput = experiment.Add(filter);

var term = new ML.Transforms.TextToKeyConverter();
term.Column = new[]
{
new ML.Transforms.TermTransformColumn()
{
Source = "Label", Name = "Strat", Sort = ML.Transforms.TermTransformSortOrder.Value
}
};
term.Data = filterOutput.OutputData;
var termOutput = experiment.Add(term);

var crossValidate = new ML.Models.CrossValidator
{
Data = termOutput.OutputData,
Nodes = subGraph,
Kind = ML.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer,
TransformModel = null,
StratificationColumn = "Strat"
};
crossValidate.Inputs.Data = nop.Data;
crossValidate.Outputs.PredictorModel = learnerOutput.PredictorModel;
var crossValidateOutput = experiment.Add(crossValidate);

experiment.Compile();
importInput.SetInput(env, experiment);
experiment.Run();
var warnings = experiment.GetOutput(crossValidateOutput.Warnings);

var schema = warnings.Schema;
var b = schema.TryGetColumnIndex("WarningText", out int warningCol);
Assert.True(b);
using (var cursor = warnings.GetRowCursor(col => col == warningCol))
{
var getter = cursor.GetGetter<DvText>(warningCol);

b = cursor.MoveNext();
Assert.True(b);
var warning = default(DvText);
getter(ref warning);
Assert.Contains("test instances with class values not seen in the training set.", warning.ToString());
b = cursor.MoveNext();
Assert.True(b);
getter(ref warning);
Assert.Contains("Detected columns of variable length: SortedScores, SortedClasses", warning.ToString());
b = cursor.MoveNext();
Assert.False(b);
}
}
}

Expand Down

0 comments on commit ead943e

Please sign in to comment.