diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index 5507176fb48..94c920c3c67 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -256,6 +256,8 @@ public Double MacroAvgAccuracy { get { + if (_numInstances == 0) + return 0; Double macroAvgAccuracy = 0; int countOfNonEmptyClasses = 0; for (int i = 0; i < _numClasses; ++i) @@ -267,8 +269,7 @@ public Double MacroAvgAccuracy } } - Contracts.Assert(countOfNonEmptyClasses > 0); - return macroAvgAccuracy / countOfNonEmptyClasses; + return countOfNonEmptyClasses > 0 ? macroAvgAccuracy / countOfNonEmptyClasses : 0; } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs index d927ba0a439..7ea557159e3 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs @@ -28,8 +28,6 @@ "OLS Linear Regression Executor", OlsLinearRegressionPredictor.LoaderSignature)] -[assembly: LoadableClass(typeof(void), typeof(OlsLinearRegressionTrainer), null, typeof(SignatureEntryPointModule), OlsLinearRegressionTrainer.LoadNameValue)] - namespace Microsoft.ML.Runtime.Learners { public sealed class OlsLinearRegressionTrainer : TrainerBase diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index e711b871171..395f6d81785 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -381,6 +381,10 @@ public static CommonOutputs.MacroOutput CrossValidate( // Set the input bindings for the CombineMetrics entry point. var combineInputBindingMap = new Dictionary>(); var combineInputMap = new Dictionary(); + + var warningsArray = new SimpleParameterBinding(nameof(combineArgs.Warnings)); + combineInputBindingMap.Add(nameof(combineArgs.Warnings), new List { warningsArray }); + combineInputMap.Add(warningsArray, new SimpleVariableBinding(warningsOutput.OutputData.VarName)); var overallArray = new SimpleParameterBinding(nameof(combineArgs.OverallMetrics)); combineInputBindingMap.Add(nameof(combineArgs.OverallMetrics), new List { overallArray }); combineInputMap.Add(overallArray, new SimpleVariableBinding(overallMetricsOutput.OutputData.VarName)); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 2607e427c1f..061c14a0346 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -528,6 +528,89 @@ public void TestCrossValidationMacroWithMultiClass() } Assert.Equal(0, rowCount); } + + var warnings = experiment.GetOutput(crossValidateOutput.Warnings); + using (var cursor = warnings.GetRowCursor(col => true)) + Assert.False(cursor.MoveNext()); + } + } + + [Fact] + public void TestCrossValidationMacroMultiClassWithWarnings() + { + var dataPath = GetDataPath(@"Train-Tiny-28x28.txt"); + using (var env = new TlcEnvironment(42)) + { + var subGraph = env.CreateExperiment(); + + var nop = new ML.Transforms.NoOperation(); + var nopOutput = subGraph.Add(nop); + + var learnerInput = new ML.Trainers.LogisticRegressionClassifier + { + TrainingData = nopOutput.OutputData, + NumThreads = 1 + }; + var learnerOutput = subGraph.Add(learnerInput); + + var experiment = env.CreateExperiment(); + var importInput = new ML.Data.TextLoader(dataPath); + var importOutput = experiment.Add(importInput); + + var filter = new ML.Transforms.RowRangeFilter(); + filter.Data = importOutput.Data; + filter.Column = "Label"; + filter.Min = 0; + filter.Max = 5; + var filterOutput = experiment.Add(filter); + + var term = new ML.Transforms.TextToKeyConverter(); + term.Column = new[] + { + new ML.Transforms.TermTransformColumn() + { + Source = "Label", Name = "Strat", Sort = ML.Transforms.TermTransformSortOrder.Value + } + }; + term.Data = filterOutput.OutputData; + var termOutput = experiment.Add(term); + + var crossValidate = new ML.Models.CrossValidator + { + Data = termOutput.OutputData, + Nodes = subGraph, + Kind = ML.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer, + TransformModel = null, + StratificationColumn = "Strat" + }; + crossValidate.Inputs.Data = nop.Data; + crossValidate.Outputs.PredictorModel = learnerOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + + experiment.Compile(); + importInput.SetInput(env, experiment); + experiment.Run(); + var warnings = experiment.GetOutput(crossValidateOutput.Warnings); + + var schema = warnings.Schema; + var b = schema.TryGetColumnIndex("WarningText", out int warningCol); + Assert.True(b); + using (var cursor = warnings.GetRowCursor(col => col == warningCol)) + { + var getter = cursor.GetGetter(warningCol); + + b = cursor.MoveNext(); + Assert.True(b); + var warning = default(DvText); + getter(ref warning); + Assert.Contains("test instances with class values not seen in the training set.", warning.ToString()); + b = cursor.MoveNext(); + Assert.True(b); + getter(ref warning); + Assert.Contains("Detected columns of variable length: SortedScores, SortedClasses", warning.ToString()); + b = cursor.MoveNext(); + Assert.False(b); + } } }