Skip to content

Commit

Permalink
Add SamplingKeyColumnName to AutoMLExperiment
Browse files Browse the repository at this point in the history
  • Loading branch information
torronen committed May 3, 2023
1 parent a18b9cb commit 34a0ca1
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,16 @@ public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, Trai
/// </summary>
/// <param name="experiment"><see cref="AutoMLExperiment"/></param>
/// <param name="dataset">dataset for cross-validation split.</param>
/// <param name="fold"></param>
/// <param name="fold">number of cross-validation folds</param>
/// <param name="samplingKeyColumnName">column name for sampling key</param>
/// <returns><see cref="AutoMLExperiment"/></returns>
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10)
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10, string samplingKeyColumnName = null)
{
var datasetManager = new CrossValidateDatasetManager()
{
Dataset = dataset,
Fold = fold,
SamplingKeyColumnName = samplingKeyColumnName,
};

experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ internal interface ICrossValidateDatasetManager
int? Fold { get; set; }

IDataView Dataset { get; set; }

string SamplingKeyColumnName { get; set; }
}

internal interface ITrainValidateDatasetManager
Expand All @@ -38,5 +40,6 @@ internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateData
public IDataView Dataset { get; set; }

public int? Fold { get; set; }
public string SamplingKeyColumnName { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public TrialResult Run(TrialSettings settings)
var mlnetPipeline = _pipeline.BuildFromOption(_mLContext, parameter);
if (_datasetManager is ICrossValidateDatasetManager crossValidateDatasetManager)
{
var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5);
var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5, crossValidateDatasetManager.SamplingKeyColumnName);
var metrics = new List<double>();
var models = new List<ITransformer>();
foreach (var split in datasetSplit)
Expand Down
20 changes: 20 additions & 0 deletions test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,26 @@ public async Task AutoMLExperiment_Taxi_Fare_CV_5_Test()
result.Metric.Should().BeGreaterThan(0.5);
}

[Fact]
public async Task AutoMLExperiment_Taxi_Fare_CV_5_SamplingKey_Test()
{
var context = new MLContext(1);
var train = DatasetUtil.GetTaxiFareTrainDataView();
var experiment = context.Auto().CreateExperiment();
var label = DatasetUtil.TaxiFareLabel;
var pipeline = context.Auto().Featurizer(train, excludeColumns: new[] { label })
.Append(context.Auto().Regression(label, useLgbm: false, useSdca: false, useLbfgsPoissonRegression: false));

experiment.SetDataset(train, 5, "vendor_id")
.SetRegressionMetric(RegressionMetric.RSquared, label)
.SetPipeline(pipeline)
.SetMaxModelToExplore(1);

var result = await experiment.RunAsync();
result.Metric.Should().BeGreaterThan(0.2);
result.Metric.Should().BeLessThan(0.5);
}

[Fact]
public void AutoMLExperiment_should_use_seed_from_context_if_provided()
{
Expand Down

0 comments on commit 34a0ca1

Please sign in to comment.