# Patch summary (descriptive preamble placed ahead of the first `diff --git` header; not part of any hunk).
#
# Threads an optional sampling-key column name through AutoML cross-validation:
#   * AutoMLExperimentExtension.cs - the SetDataset(IDataView dataset, int fold = 10) overload gains a
#     trailing optional parameter `string samplingKeyColumnName = null` and assigns it to
#     CrossValidateDatasetManager.SamplingKeyColumnName. The doc comment gains text for the fold
#     and sampling-key parameters.
#   * IDatasetManager.cs - ICrossValidateDatasetManager and CrossValidateDatasetManager each gain a
#     `string SamplingKeyColumnName { get; set; }` property.
#   * SweepablePipelineRunner.cs - Run now passes SamplingKeyColumnName as the third argument to
#     Data.CrossValidationSplit. The fold count still falls back to 5 when Fold is null, whereas
#     SetDataset defaults fold to 10.
#   * AutoMLExperimentTests.cs - adds AutoMLExperiment_Taxi_Fare_CV_5_SamplingKey_Test, which calls
#     SetDataset(train, 5, "vendor_id"), explores one model, and asserts the RSquared metric is
#     greater than 0.2 and less than 0.5.
#
# NOTE(review): the hunks below appear to have lost their original line breaks and any
# angle-bracketed text (XML doc tags, generic type arguments) - confirm against the original
# commit before attempting to apply this patch.
diff --git a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs index 727e443852..d752153d51 100644 --- a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs +++ b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs @@ -57,14 +57,16 @@ public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, Trai /// /// /// dataset for cross-validation split. - /// + /// number of cross-validation folds + /// column name for sampling key /// - public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10) + public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10, string samplingKeyColumnName = null) { var datasetManager = new CrossValidateDatasetManager() { Dataset = dataset, Fold = fold, + SamplingKeyColumnName = samplingKeyColumnName, }; experiment.ServiceCollection.AddSingleton(datasetManager); diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs index e1cac220bc..0ab7591043 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs @@ -17,6 +17,8 @@ internal interface ICrossValidateDatasetManager int? Fold { get; set; } IDataView Dataset { get; set; } + + string SamplingKeyColumnName { get; set; } } internal interface ITrainValidateDatasetManager @@ -38,5 +40,6 @@ internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateData public IDataView Dataset { get; set; } public int? 
Fold { get; set; } + public string SamplingKeyColumnName { get; set; } } } diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs index c0b2b7d6a8..f69d38e541 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs @@ -40,7 +40,7 @@ public TrialResult Run(TrialSettings settings) var mlnetPipeline = _pipeline.BuildFromOption(_mLContext, parameter); if (_datasetManager is ICrossValidateDatasetManager crossValidateDatasetManager) { - var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5); + var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5, crossValidateDatasetManager.SamplingKeyColumnName); var metrics = new List(); var models = new List(); foreach (var split in datasetSplit) diff --git a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs index 1190c17490..1820a9447f 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs @@ -354,6 +354,26 @@ public async Task AutoMLExperiment_Taxi_Fare_CV_5_Test() result.Metric.Should().BeGreaterThan(0.5); } + [Fact] + public async Task AutoMLExperiment_Taxi_Fare_CV_5_SamplingKey_Test() + { + var context = new MLContext(1); + var train = DatasetUtil.GetTaxiFareTrainDataView(); + var experiment = context.Auto().CreateExperiment(); + var label = DatasetUtil.TaxiFareLabel; + var pipeline = context.Auto().Featurizer(train, excludeColumns: new[] { label }) + .Append(context.Auto().Regression(label, useLgbm: false, useSdca: false, useLbfgsPoissonRegression: false)); + + experiment.SetDataset(train, 5, "vendor_id") 
.SetRegressionMetric(RegressionMetric.RSquared, label) + .SetPipeline(pipeline) + .SetMaxModelToExplore(1); + + var result = await experiment.RunAsync(); + result.Metric.Should().BeGreaterThan(0.2); + result.Metric.Should().BeLessThan(0.5); + } + [Fact] public void AutoMLExperiment_should_use_seed_from_context_if_provided() {