diff --git a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
index 727e443852..d752153d51 100644
--- a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
+++ b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
@@ -57,14 +57,16 @@ public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, Trai
///
///
/// dataset for cross-validation split.
- ///
+ /// <param name="fold">number of cross-validation folds.</param>
+ /// <param name="samplingKeyColumnName">column name for sampling key.</param>
///
- public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10)
+ public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10, string samplingKeyColumnName = null)
{
var datasetManager = new CrossValidateDatasetManager()
{
Dataset = dataset,
Fold = fold,
+ SamplingKeyColumnName = samplingKeyColumnName,
};
experiment.ServiceCollection.AddSingleton(datasetManager);
diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
index e1cac220bc..0ab7591043 100644
--- a/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
+++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
@@ -17,6 +17,8 @@ internal interface ICrossValidateDatasetManager
int? Fold { get; set; }
IDataView Dataset { get; set; }
+
+ string SamplingKeyColumnName { get; set; }
}
internal interface ITrainValidateDatasetManager
@@ -38,5 +40,6 @@ internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateData
public IDataView Dataset { get; set; }
public int? Fold { get; set; }
+ public string SamplingKeyColumnName { get; set; }
}
}
diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs
index c0b2b7d6a8..f69d38e541 100644
--- a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs
+++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs
@@ -40,7 +40,7 @@ public TrialResult Run(TrialSettings settings)
var mlnetPipeline = _pipeline.BuildFromOption(_mLContext, parameter);
if (_datasetManager is ICrossValidateDatasetManager crossValidateDatasetManager)
{
- var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5);
+ var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5, crossValidateDatasetManager.SamplingKeyColumnName);
var metrics = new List();
var models = new List();
foreach (var split in datasetSplit)
diff --git a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs
index 1190c17490..1820a9447f 100644
--- a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs
+++ b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs
@@ -354,6 +354,26 @@ public async Task AutoMLExperiment_Taxi_Fare_CV_5_Test()
result.Metric.Should().BeGreaterThan(0.5);
}
+ [Fact]
+ public async Task AutoMLExperiment_Taxi_Fare_CV_5_SamplingKey_Test()
+ {
+     // Single-trial, 5-fold cross-validated regression experiment on the taxi-fare
+     // training data, with "vendor_id" passed as the sampling key column for the split.
+     var mlContext = new MLContext(seed: 1);
+     var trainData = DatasetUtil.GetTaxiFareTrainDataView();
+     var autoExperiment = mlContext.Auto().CreateExperiment();
+     var labelColumn = DatasetUtil.TaxiFareLabel;
+
+     // Featurize every column except the label, then append the sweepable regressor.
+     var featurizer = mlContext.Auto().Featurizer(trainData, excludeColumns: new[] { labelColumn });
+     var regressor = mlContext.Auto().Regression(labelColumn, useLgbm: false, useSdca: false, useLbfgsPoissonRegression: false);
+     var sweepablePipeline = featurizer.Append(regressor);
+
+     autoExperiment
+         .SetDataset(trainData, fold: 5, samplingKeyColumnName: "vendor_id")
+         .SetRegressionMetric(RegressionMetric.RSquared, labelColumn)
+         .SetPipeline(sweepablePipeline)
+         .SetMaxModelToExplore(1);
+
+     var trialResult = await autoExperiment.RunAsync();
+
+     // NOTE(review): the upper bound presumably distinguishes this run from the plain
+     // CV-5 test above, which expects a metric greater than 0.5 — confirm the intent.
+     var rSquared = trialResult.Metric;
+     rSquared.Should().BeGreaterThan(0.2);
+     rSquared.Should().BeLessThan(0.5);
+ }
+
[Fact]
public void AutoMLExperiment_should_use_seed_from_context_if_provided()
{