Skip to content

Commit 12cff38

Browse files
Bug bash feedback Feb 27. API changes and sample changes (dotnet#240)
* Bug bash feedback Feb 27. API changes Sample changes Exception fix
1 parent d8d9295 commit 12cff38

20 files changed

+116
-103
lines changed

src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,20 @@
66

77
namespace Microsoft.ML.Auto
88
{
9-
public sealed class AutoInferenceCatalog
9+
public sealed class AutoMLCatalog
1010
{
1111
private readonly MLContext _context;
1212

13-
internal AutoInferenceCatalog(MLContext context)
13+
internal AutoMLCatalog(MLContext context)
1414
{
1515
_context = context;
1616
}
1717

18-
public RegressionExperiment CreateRegressionExperiment(uint maxInferenceTimeInSeconds)
18+
public RegressionExperiment CreateRegressionExperiment(uint maxExperimentTimeInSeconds)
1919
{
2020
return new RegressionExperiment(_context, new RegressionExperimentSettings()
2121
{
22-
MaxInferenceTimeInSeconds = maxInferenceTimeInSeconds
22+
MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds
2323
});
2424
}
2525

@@ -28,11 +28,11 @@ public RegressionExperiment CreateRegressionExperiment(RegressionExperimentSetti
2828
return new RegressionExperiment(_context, experimentSettings);
2929
}
3030

31-
public BinaryClassificationExperiment CreateBinaryClassificationExperiment(uint maxInferenceTimeInSeconds)
31+
public BinaryClassificationExperiment CreateBinaryClassificationExperiment(uint maxExperimentTimeInSeconds)
3232
{
3333
return new BinaryClassificationExperiment(_context, new BinaryExperimentSettings()
3434
{
35-
MaxInferenceTimeInSeconds = maxInferenceTimeInSeconds
35+
MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds
3636
});
3737
}
3838

@@ -41,11 +41,11 @@ public BinaryClassificationExperiment CreateBinaryClassificationExperiment(Binar
4141
return new BinaryClassificationExperiment(_context, experimentSettings);
4242
}
4343

44-
public MulticlassClassificationExperiment CreateMulticlassClassificationExperiment(uint maxInferenceTimeInSeconds)
44+
public MulticlassClassificationExperiment CreateMulticlassClassificationExperiment(uint maxExperimentTimeInSeconds)
4545
{
4646
return new MulticlassClassificationExperiment(_context, new MulticlassExperimentSettings()
4747
{
48-
MaxInferenceTimeInSeconds = maxInferenceTimeInSeconds
48+
MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds
4949
});
5050
}
5151

src/Microsoft.ML.Auto/API/ExperimentSettings.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Microsoft.ML.Auto
88
{
99
public class ExperimentSettings
1010
{
11-
public uint MaxInferenceTimeInSeconds = 24 * 60 * 60;
11+
public uint MaxExperimentTimeInSeconds = 24 * 60 * 60;
1212
public CancellationToken CancellationToken;
1313

1414
internal bool EnableCaching;

src/Microsoft.ML.Auto/API/MLContextExtension.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ namespace Microsoft.ML.Auto
66
{
77
public static class MLContextExtension
88
{
9-
public static AutoInferenceCatalog AutoInference(this MLContext mlContext)
9+
public static AutoMLCatalog Auto(this MLContext mlContext)
1010
{
11-
return new AutoInferenceCatalog(mlContext);
11+
return new AutoMLCatalog(mlContext);
1212
}
1313
}
1414
}

src/Microsoft.ML.Auto/API/RunResult.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Microsoft.ML.Auto
99
{
1010
public sealed class RunResult<T>
1111
{
12-
public readonly T Metrics;
12+
public readonly T ValidationMetrics;
1313
public readonly ITransformer Model;
1414
public readonly Exception Exception;
1515
public readonly string TrainerName;
@@ -27,7 +27,7 @@ internal RunResult(
2727
int pipelineInferenceTimeInSeconds)
2828
{
2929
Model = model;
30-
Metrics = metrics;
30+
ValidationMetrics = metrics;
3131
Pipeline = pipeline;
3232
Exception = exception;
3333
RuntimeInSeconds = runtimeInSeconds;

src/Microsoft.ML.Auto/Experiment/Experiment.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ public List<RunResult<T>> Execute()
116116
iterationResults.Add(iterationResult);
117117

118118
// if model is perfect, break
119-
if (_metricsAgent.IsModelPerfect(iterationResult.Metrics))
119+
if (_metricsAgent.IsModelPerfect(iterationResult.ValidationMetrics))
120120
{
121121
break;
122122
}
123123

124124
} while (_history.Count < _experimentSettings.MaxModels &&
125125
!_experimentSettings.CancellationToken.IsCancellationRequested &&
126-
stopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxInferenceTimeInSeconds);
126+
stopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxExperimentTimeInSeconds);
127127

128128
return iterationResults;
129129
}

src/Microsoft.ML.Auto/Utils/RunResultUtil.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ internal class RunResultUtil
1212
public static RunResult<T> GetBestRunResult<T>(IEnumerable<RunResult<T>> results,
1313
IMetricsAgent<T> metricsAgent)
1414
{
15-
results = results.Where(r => r.Metrics != null);
15+
results = results.Where(r => r.ValidationMetrics != null);
1616
if (!results.Any()) { return null; }
17-
double maxScore = results.Select(r => metricsAgent.GetScore(r.Metrics)).Max();
18-
return results.First(r => metricsAgent.GetScore(r.Metrics) == maxScore);
17+
double maxScore = results.Select(r => metricsAgent.GetScore(r.ValidationMetrics)).Max();
18+
return results.First(r => metricsAgent.GetScore(r.ValidationMetrics) == maxScore);
1919
}
2020
}
2121
}

src/Samples/AutoTrainBinaryClassification.cs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.IO;
7+
using System.Linq;
78
using Microsoft.Data.DataView;
89
using Microsoft.ML;
910
using Microsoft.ML.Auto;
@@ -18,31 +19,34 @@ public class AutoTrainBinaryClassification
1819
private static string TestDataPath = $"{BaseDatasetsLocation}/wikipedia-detox-250-line-test.tsv";
1920
private static string ModelPath = $"{BaseDatasetsLocation}/SentimentModel.zip";
2021
private static string LabelColumn = "Sentiment";
22+
private static uint ExperimentTime = 60;
2123

2224
public static void Run()
2325
{
2426
MLContext mlContext = new MLContext();
2527

2628
// STEP 1: Infer columns
27-
var columnInference = mlContext.AutoInference().InferColumns(TrainDataPath, LabelColumn);
29+
var columnInference = mlContext.Auto().InferColumns(TrainDataPath, LabelColumn);
2830

2931
// STEP 2: Load data
30-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs);
31-
IDataView trainDataView = textLoader.Read(TrainDataPath);
32-
IDataView testDataView = textLoader.Read(TestDataPath);
32+
var textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs);
33+
var trainDataView = textLoader.Read(TrainDataPath);
34+
var testDataView = textLoader.Read(TestDataPath);
3335

3436
// STEP 3: Auto featurize, auto train and auto hyperparameter tune
35-
Console.WriteLine($"Invoking new AutoML binary classification experiment...");
36-
var runResults = mlContext.AutoInference()
37-
.CreateBinaryClassificationExperiment(60)
37+
Console.WriteLine($"Running AutoML binary classification experiment for {ExperimentTime} seconds...");
38+
var runResults = mlContext.Auto()
39+
.CreateBinaryClassificationExperiment(ExperimentTime)
3840
.Execute(trainDataView, LabelColumn);
3941

4042
// STEP 4: Print metric from the best model
4143
var best = runResults.Best();
42-
Console.WriteLine($"Accuracy of best model from validation data: {best.Metrics.Accuracy}");
44+
Console.WriteLine($"Total models produced: {runResults.Count()}");
45+
Console.WriteLine($"Best model's trainer: {best.TrainerName}");
46+
Console.WriteLine($"Accuracy of best model from validation data: {best.ValidationMetrics.Accuracy}");
4347

4448
// STEP 5: Evaluate test data
45-
IDataView testDataViewWithBestScore = best.Model.Transform(testDataView);
49+
var testDataViewWithBestScore = best.Model.Transform(testDataView);
4650
var testMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(testDataViewWithBestScore, label: LabelColumn);
4751
Console.WriteLine($"Accuracy of best model on test data: {testMetrics.Accuracy}");
4852

@@ -51,7 +55,7 @@ public static void Run()
5155
best.Model.SaveTo(mlContext, fs);
5256

5357
Console.WriteLine("Press any key to continue...");
54-
Console.ReadLine();
58+
Console.ReadKey();
5559
}
5660
}
5761
}

src/Samples/AutoTrainMulticlassClassification.cs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.IO;
7+
using System.Linq;
78
using Microsoft.Data.DataView;
89
using Microsoft.ML;
910
using Microsoft.ML.Auto;
@@ -17,31 +18,34 @@ public class AutoTrainMulticlassClassification
1718
private static string TrainDataPath = $"{BaseDatasetsLocation}/iris-train.txt";
1819
private static string TestDataPath = $"{BaseDatasetsLocation}/iris-test.txt";
1920
private static string ModelPath = $"{BaseDatasetsLocation}/IrisClassificationModel.zip";
21+
private static uint ExperimentTime = 60;
2022

2123
public static void Run()
2224
{
2325
MLContext mlContext = new MLContext();
2426

2527
// STEP 1: Infer columns
26-
var columnInference = mlContext.AutoInference().InferColumns(TrainDataPath);
28+
var columnInference = mlContext.Auto().InferColumns(TrainDataPath);
2729

2830
// STEP 2: Load data
29-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs);
30-
IDataView trainDataView = textLoader.Read(TrainDataPath);
31-
IDataView testDataView = textLoader.Read(TestDataPath);
31+
var textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs);
32+
var trainDataView = textLoader.Read(TrainDataPath);
33+
var testDataView = textLoader.Read(TestDataPath);
3234

3335
// STEP 3: Auto featurize, auto train and auto hyperparameter tune
34-
Console.WriteLine($"Invoking new AutoML multiclass classification experiment...");
35-
var runResults = mlContext.AutoInference()
36+
Console.WriteLine($"Running AutoML multiclass classification experiment for {ExperimentTime} seconds...");
37+
var runResults = mlContext.Auto()
3638
.CreateMulticlassClassificationExperiment(60)
3739
.Execute(trainDataView);
3840

3941
// STEP 4: Print metric from the best model
4042
var best = runResults.Best();
41-
Console.WriteLine($"AccuracyMacro of best model from validation data: {best.Metrics.AccuracyMacro}");
43+
Console.WriteLine($"Total models produced: {runResults.Count()}");
44+
Console.WriteLine($"Best model's trainer: {best.TrainerName}");
45+
Console.WriteLine($"AccuracyMacro of best model from validation data: {best.ValidationMetrics.AccuracyMacro}");
4246

4347
// STEP 5: Evaluate test data
44-
IDataView testDataViewWithBestScore = best.Model.Transform(testDataView);
48+
var testDataViewWithBestScore = best.Model.Transform(testDataView);
4549
var testMetrics = mlContext.MulticlassClassification.Evaluate(testDataViewWithBestScore);
4650
Console.WriteLine($"AccuracyMacro of best model on test data: {testMetrics.AccuracyMacro}");
4751

@@ -50,7 +54,7 @@ public static void Run()
5054
best.Model.SaveTo(mlContext, fs);
5155

5256
Console.WriteLine("Press any key to continue...");
53-
Console.ReadLine();
57+
Console.ReadKey();
5458
}
5559
}
5660
}

src/Samples/AutoTrainRegression.cs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.IO;
7+
using System.Linq;
78
using Microsoft.Data.DataView;
89
using Microsoft.ML;
910
using Microsoft.ML.Auto;
@@ -18,31 +19,34 @@ static class AutoTrainRegression
1819
private static string TestDataPath = $"{BaseDatasetsLocation}/taxi-fare-test.csv";
1920
private static string ModelPath = $"{BaseDatasetsLocation}/TaxiFareModel.zip";
2021
private static string LabelColumn = "fare_amount";
22+
private static uint ExperimentTime = 60;
2123

2224
public static void Run()
2325
{
2426
MLContext mlContext = new MLContext();
2527

2628
// STEP 1: Infer columns
27-
var columnInference = mlContext.AutoInference().InferColumns(TrainDataPath, LabelColumn);
29+
var columnInference = mlContext.Auto().InferColumns(TrainDataPath, LabelColumn);
2830

2931
// STEP 2: Load data
30-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs);
31-
IDataView trainDataView = textLoader.Read(TrainDataPath);
32-
IDataView testDataView = textLoader.Read(TestDataPath);
32+
var textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs);
33+
var trainDataView = textLoader.Read(TrainDataPath);
34+
var testDataView = textLoader.Read(TestDataPath);
3335

3436
// STEP 3: Auto featurize, auto train and auto hyperparameter tune
35-
Console.WriteLine($"Invoking new AutoML regression experiment...");
36-
var runResults = mlContext.AutoInference()
37+
Console.WriteLine($"Running AutoML multiclass classification experiment for {ExperimentTime} seconds...");
38+
var runResults = mlContext.Auto()
3739
.CreateRegressionExperiment(60)
3840
.Execute(trainDataView, LabelColumn);
3941

4042
// STEP 4: Print metric from best model
4143
var best = runResults.Best();
42-
Console.WriteLine($"RSquared of best model from validation data: {best.Metrics.RSquared}");
44+
Console.WriteLine($"Total models produced: {runResults.Count()}");
45+
Console.WriteLine($"Best model's trainer: {best.TrainerName}");
46+
Console.WriteLine($"RSquared of best model from validation data: {best.ValidationMetrics.RSquared}");
4347

4448
// STEP 5: Evaluate test data
45-
IDataView testDataViewWithBestScore = best.Model.Transform(testDataView);
49+
var testDataViewWithBestScore = best.Model.Transform(testDataView);
4650
var testMetrics = mlContext.Regression.Evaluate(testDataViewWithBestScore, label: LabelColumn);
4751
Console.WriteLine($"RSquared of best model on test data: {testMetrics.RSquared}");
4852

@@ -51,7 +55,7 @@ public static void Run()
5155
best.Model.SaveTo(mlContext, fs);
5256

5357
Console.WriteLine("Press any key to continue...");
54-
Console.ReadLine();
58+
Console.ReadKey();
5559
}
5660
}
5761
}

src/Samples/Cancellation.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ public static void Run()
2626
MLContext mlContext = new MLContext();
2727

2828
// STEP 1: Infer columns
29-
var columnInference = mlContext.AutoInference().InferColumns(TrainDataPath, LabelColumn, ',');
29+
var columnInference = mlContext.Auto().InferColumns(TrainDataPath, LabelColumn, ',');
3030

3131
// STEP 2: Load data
32-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs);
33-
IDataView trainDataView = textLoader.Read(TrainDataPath);
34-
IDataView testDataView = textLoader.Read(TestDataPath);
32+
var textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs);
33+
var trainDataView = textLoader.Read(TrainDataPath);
34+
var testDataView = textLoader.Read(TestDataPath);
3535

3636
int cancelAfterInSeconds = 20;
3737
CancellationTokenSource cts = new CancellationTokenSource();
@@ -40,19 +40,19 @@ public static void Run()
4040
Stopwatch watch = Stopwatch.StartNew();
4141

4242
// STEP 3: Auto inference with a cancellation token
43-
Console.WriteLine($"Invoking new AutoML regression experiment...");
44-
var runResults = mlContext.AutoInference()
43+
Console.WriteLine($"Invoking an experiment that will be cancelled after {cancelAfterInSeconds} seconds");
44+
var runResults = mlContext.Auto()
4545
.CreateRegressionExperiment(new RegressionExperimentSettings()
4646
{
47-
MaxInferenceTimeInSeconds = 60,
47+
MaxExperimentTimeInSeconds = 60,
4848
CancellationToken = cts.Token
4949
})
5050
.Execute(trainDataView, LabelColumn);
5151

5252
Console.WriteLine($"{runResults.Count()} models were returned after {cancelAfterInSeconds} seconds");
5353

5454
Console.WriteLine("Press any key to continue...");
55-
Console.ReadLine();
55+
Console.ReadKey();
5656
}
5757
}
5858
}

0 commit comments

Comments
 (0)