Skip to content

Commit

Permalink
propagate root MLContext through AutoML (instead of creating our own) (d…
Browse files Browse the repository at this point in the history
  • Loading branch information
daholste authored Feb 21, 2019
1 parent 8a9f1f2 commit 2d8a6ed
Show file tree
Hide file tree
Showing 19 changed files with 127 additions and 97 deletions.
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Auto/API/Pipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ internal Pipeline()
{
}

public IEstimator<ITransformer> ToEstimator()
public IEstimator<ITransformer> ToEstimator(MLContext context)
{
var inferredPipeline = SuggestedPipeline.FromPipeline(this);
var inferredPipeline = SuggestedPipeline.FromPipeline(context, this);
return inferredPipeline.ToEstimator();
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Auto/AutoFitter/AutoFitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public AutoFitter(MLContext context,
{
if (validationData == null)
{
(trainData, validationData) = context.Regression.TestValidateSplit(trainData);
(trainData, validationData) = context.Regression.TestValidateSplit(context, trainData);
}
_trainData = trainData;
_validationData = validationData;
Expand Down Expand Up @@ -85,7 +85,7 @@ public List<RunResult<T>> Fit()
var getPiplelineStopwatch = Stopwatch.StartNew();

// get next pipeline
pipeline = PipelineSuggester.GetNextInferredPipeline(_history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist);
pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist);

getPiplelineStopwatch.Stop();

Expand Down
12 changes: 5 additions & 7 deletions src/Microsoft.ML.Auto/AutoFitter/SuggestedPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ internal class SuggestedPipeline

public SuggestedPipeline(IEnumerable<SuggestedTransform> transforms,
SuggestedTrainer trainer,
MLContext context = null,
MLContext context,
bool autoNormalize = true)
{
Transforms = transforms.Select(t => t.Clone()).ToList();
Trainer = trainer.Clone();
_context = context ?? new MLContext();
_context = context;

if(autoNormalize)
{
Expand Down Expand Up @@ -64,10 +64,8 @@ public Pipeline ToPipeline()
return new Pipeline(pipelineElements.ToArray());
}

public static SuggestedPipeline FromPipeline(Pipeline pipeline)
public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipeline)
{
var context = new MLContext();

var transforms = new List<SuggestedTransform>();
SuggestedTrainer trainer = null;

Expand All @@ -84,13 +82,13 @@ public static SuggestedPipeline FromPipeline(Pipeline pipeline)
{
var estimatorName = (EstimatorName)Enum.Parse(typeof(EstimatorName), pipelineNode.Name);
var estimatorExtension = EstimatorExtensionCatalog.GetExtension(estimatorName);
var estimator = estimatorExtension.CreateInstance(new MLContext(), pipelineNode);
var estimator = estimatorExtension.CreateInstance(context, pipelineNode);
var transform = new SuggestedTransform(pipelineNode, estimator);
transforms.Add(transform);
}
}

return new SuggestedPipeline(transforms, trainer, null, false);
return new SuggestedPipeline(transforms, trainer, context, false);
}

public IEstimator<ITransformer> ToEstimator()
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Auto/AutoFitter/SuggestedPipelineResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ public SuggestedPipelineResult(SuggestedPipeline pipeline, double score, bool ru
RunSucceded = runSucceeded;
}

public static SuggestedPipelineResult FromPipelineRunResult(PipelineScore pipelineRunResult)
public static SuggestedPipelineResult FromPipelineRunResult(MLContext context, PipelineScore pipelineRunResult)
{
return new SuggestedPipelineResult(SuggestedPipeline.FromPipeline(pipelineRunResult.Pipeline), pipelineRunResult.Score, pipelineRunResult.RunSucceded);
return new SuggestedPipelineResult(SuggestedPipeline.FromPipeline(context, pipelineRunResult.Pipeline), pipelineRunResult.Score, pipelineRunResult.RunSucceded);
}

public IRunResult ToRunResult(bool isMetricMaximizing)
Expand Down
19 changes: 9 additions & 10 deletions src/Microsoft.ML.Auto/AutoMlUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,37 +23,36 @@ public static void Assert(bool boolVal, string message = null)
}
}

public static IDataView Take(this IDataView data, int count)
public static IDataView Take(this IDataView data, MLContext context, int count)
{
var context = new MLContext();
return TakeFilter.Create(context, data, count);
}

public static IDataView DropLastColumn(this IDataView data)
public static IDataView DropLastColumn(this IDataView data, MLContext context)
{
return new MLContext().Transforms.DropColumns(data.Schema[data.Schema.Count - 1].Name).Fit(data).Transform(data);
return context.Transforms.DropColumns(data.Schema[data.Schema.Count - 1].Name).Fit(data).Transform(data);
}

public static (IDataView testData, IDataView validationData) TestValidateSplit(this TrainCatalogBase catalog, IDataView trainData)
public static (IDataView testData, IDataView validationData) TestValidateSplit(this TrainCatalogBase catalog,
MLContext context, IDataView trainData)
{
IDataView validationData;
(trainData, validationData) = catalog.TrainTestSplit(trainData);
trainData = trainData.DropLastColumn();
validationData = validationData.DropLastColumn();
trainData = trainData.DropLastColumn(context);
validationData = validationData.DropLastColumn(context);
return (trainData, validationData);
}

public static IDataView Skip(this IDataView data, int count)
public static IDataView Skip(this IDataView data, MLContext context, int count)
{
var context = new MLContext();
return SkipFilter.Create(context, data, count);
}

public static (string, ColumnType, ColumnPurpose, ColumnDimensions)[] GetColumnInfoTuples(MLContext context,
IDataView data, ColumnInformation columnInfo)
{
var purposes = PurposeInference.InferPurposes(context, data, columnInfo);
var colDimensions = DatasetDimensionsApi.CalcColumnDimensions(data, purposes);
var colDimensions = DatasetDimensionsApi.CalcColumnDimensions(context, data, purposes);
var cols = new (string, ColumnType, ColumnPurpose, ColumnDimensions)[data.Schema.Count];
for (var i = 0; i < cols.Length; i++)
{
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
{
var sample = TextFileSample.CreateFromFullFile(path);
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader, labelColumnIndex, null);

// if no column is named label,
Expand All @@ -32,7 +32,7 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
{
var sample = TextFileSample.CreateFromFullFile(path);
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
var typeInference = InferColumnTypes(context, sample, splitInference, true, null, label);
return InferColumns(context, path, label, true, splitInference, typeInference, trimWhitespace, groupColumns);
}
Expand Down Expand Up @@ -93,10 +93,10 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
};
}

private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)
private static TextFileContents.ColumnSplitResult InferSplit(MLContext context, TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)
{
var separatorCandidates = separatorChar == null ? TextFileContents.DefaultSeparators : new char[] { separatorChar.Value };
var splitInference = TextFileContents.TrySplitColumns(sample, separatorCandidates);
var splitInference = TextFileContents.TrySplitColumns(context, sample, separatorCandidates);

// respect passed-in overrides
if (allowQuotedStrings != null)
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,12 @@ private static IEnumerable<ITypeInferenceExpert> GetExperts()
/// <summary>
/// Auto-detect column types of the file.
/// </summary>
public static InferenceResult InferTextFileColumnTypes(MLContext env, IMultiStreamSource fileSource, Arguments args)
public static InferenceResult InferTextFileColumnTypes(MLContext context, IMultiStreamSource fileSource, Arguments args)
{
return InferTextFileColumnTypesCore(env, fileSource, args);
return InferTextFileColumnTypesCore(context, fileSource, args);
}

private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMultiStreamSource fileSource, Arguments args)
private static InferenceResult InferTextFileColumnTypesCore(MLContext context, IMultiStreamSource fileSource, Arguments args)
{
if (args.ColumnCount == 0)
{
Expand All @@ -263,9 +263,9 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
AllowSparse = args.AllowSparse,
AllowQuoting = args.AllowQuote,
};
var textLoader = new TextLoader(env, textLoaderArgs);
var textLoader = new TextLoader(context, textLoaderArgs);
var idv = textLoader.Read(fileSource);
idv = idv.Take(args.MaxRowsToRead);
idv = idv.Take(context, args.MaxRowsToRead);

// read all the data into memory.
// list items are rows of the dataset.
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ private static IEnumerable<IPurposeInferenceExpert> GetExperts()
public static PurposeInference.Column[] InferPurposes(MLContext context, IDataView data,
ColumnInformation columnInfo)
{
data = data.Take(MaxRowsToRead);
data = data.Take(context, MaxRowsToRead);

var allColumns = new List<IntermediateColumn>();
var columnsToInfer = new List<IntermediateColumn>();
Expand Down
11 changes: 6 additions & 5 deletions src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public ColumnSplitResult(bool isSuccess, char? separator, bool allowQuote, bool
/// and this number of columns is more than 1.
/// We sweep on separator, allow sparse and allow quote parameter.
/// </summary>
public static ColumnSplitResult TrySplitColumns(IMultiStreamSource source, char[] separatorCandidates)
public static ColumnSplitResult TrySplitColumns(MLContext context, IMultiStreamSource source, char[] separatorCandidates)
{
var sparse = new[] { true, false };
var quote = new[] { true, false };
Expand All @@ -69,7 +69,7 @@ from _sep in separatorCandidates
AllowSparse = perm._allowSparse
};

if (TryParseFile(args, source, out result))
if (TryParseFile(context, args, source, out result))
{
foundAny = true;
break;
Expand All @@ -78,15 +78,16 @@ from _sep in separatorCandidates
return foundAny ? result : new ColumnSplitResult(false, null, true, true, 0);
}

private static bool TryParseFile(TextLoader.Arguments args, IMultiStreamSource source, out ColumnSplitResult result)
private static bool TryParseFile(MLContext context, TextLoader.Arguments args, IMultiStreamSource source,
out ColumnSplitResult result)
{
result = null;
// try to instantiate data view with swept arguments
try
{

var textLoader = new TextLoader(new MLContext(), args, source);
var idv = textLoader.Read(source).Take(1000);
var textLoader = new TextLoader(context, args, source);
var idv = textLoader.Read(source).Take(context, 1000);
var columnCounts = new List<int>();
var column = idv.Schema["C"];
var columnIndex = column.Index;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ internal class DatasetDimensionsApi
{
private const int MaxRowsToRead = 1000;

public static ColumnDimensions[] CalcColumnDimensions(IDataView data, PurposeInference.Column[] purposes)
public static ColumnDimensions[] CalcColumnDimensions(MLContext context, IDataView data, PurposeInference.Column[] purposes)
{
data = data.Take(MaxRowsToRead);
data = data.Take(context, MaxRowsToRead);

var colDimensions = new ColumnDimensions[data.Schema.Count];

Expand Down
27 changes: 14 additions & 13 deletions src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,32 @@ internal static class PipelineSuggester
{
private const int TopKTrainers = 3;

public static Pipeline GetNextPipeline(IEnumerable<PipelineScore> history,
public static Pipeline GetNextPipeline(MLContext context,
IEnumerable<PipelineScore> history,
(string, ColumnType, ColumnPurpose, ColumnDimensions)[] columns,
TaskKind task,
bool isMaximizingMetric = true)
{
var inferredHistory = history.Select(r => SuggestedPipelineResult.FromPipelineRunResult(r));
var nextInferredPipeline = GetNextInferredPipeline(inferredHistory, columns, task, isMaximizingMetric);
var inferredHistory = history.Select(r => SuggestedPipelineResult.FromPipelineRunResult(context, r));
var nextInferredPipeline = GetNextInferredPipeline(context, inferredHistory, columns, task, isMaximizingMetric);
return nextInferredPipeline?.ToPipeline();
}

public static SuggestedPipeline GetNextInferredPipeline(IEnumerable<SuggestedPipelineResult> history,
public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
IEnumerable<SuggestedPipelineResult> history,
(string, ColumnType, ColumnPurpose, ColumnDimensions)[] columns,
TaskKind task,
bool isMaximizingMetric,
IEnumerable<TrainerName> trainerWhitelist = null)
{
var context = new MLContext();

var availableTrainers = RecipeInference.AllowedTrainers(context, task, trainerWhitelist);
var transforms = CalculateTransforms(context, columns, task);
//var transforms = TransformInferenceApi.InferTransforms(context, columns, task);

// if we haven't run all pipelines once
if (history.Count() < availableTrainers.Count())
{
return GetNextFirstStagePipeline(history, availableTrainers, transforms);
return GetNextFirstStagePipeline(context, history, availableTrainers, transforms);
}

// get top trainers from stage 1 runs
Expand All @@ -63,14 +63,14 @@ public static SuggestedPipeline GetNextInferredPipeline(IEnumerable<SuggestedPip
do
{
// sample new hyperparameters for the learner
if (!SampleHyperparameters(newTrainer, history, isMaximizingMetric))
if (!SampleHyperparameters(context, newTrainer, history, isMaximizingMetric))
{
// if unable to sample new hyperparameters for the learner
// (i.e., SMAC returned 0 suggestions), break
break;
}

var suggestedPipeline = new SuggestedPipeline(transforms, newTrainer);
var suggestedPipeline = new SuggestedPipeline(transforms, newTrainer, context);

// make sure we have not seen pipeline before
if (!visitedPipelines.Contains(suggestedPipeline))
Expand Down Expand Up @@ -113,12 +113,13 @@ private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerabl
.Select(x => x.First().Pipeline.Trainer);
}

private static SuggestedPipeline GetNextFirstStagePipeline(IEnumerable<SuggestedPipelineResult> history,
private static SuggestedPipeline GetNextFirstStagePipeline(MLContext context,
IEnumerable<SuggestedPipelineResult> history,
IEnumerable<SuggestedTrainer> availableTrainers,
IEnumerable<SuggestedTransform> transforms)
{
var trainer = availableTrainers.ElementAt(history.Count());
return new SuggestedPipeline(transforms, trainer);
return new SuggestedPipeline(transforms, trainer, context);
}

private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableParam> hps)
Expand Down Expand Up @@ -184,10 +185,10 @@ private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableP
/// Samples new hyperparameters for the trainer, and sets them.
/// Returns true if success (new hyperparams were suggested and set). Else, returns false.
/// </summary>
private static bool SampleHyperparameters(SuggestedTrainer trainer, IEnumerable<SuggestedPipelineResult> history, bool isMaximizingMetric)
private static bool SampleHyperparameters(MLContext context, SuggestedTrainer trainer, IEnumerable<SuggestedPipelineResult> history, bool isMaximizingMetric)
{
var sps = ConvertToValueGenerators(trainer.SweepParams);
var sweeper = new SmacSweeper(
var sweeper = new SmacSweeper(context,
new SmacSweeper.Arguments
{
SweptParameters = sps
Expand Down
5 changes: 3 additions & 2 deletions src/Microsoft.ML.Auto/Sweepers/SmacSweeper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ public sealed class Arguments

private readonly ISweeper _randomSweeper;
private readonly Arguments _args;
private readonly MLContext _context = new MLContext();
private readonly MLContext _context;

private readonly IValueGenerator[] _sweepParameters;

public SmacSweeper(Arguments args)
public SmacSweeper(MLContext context, Arguments args)
{
_context = context;
_args = args;
_sweepParameters = args.SweptParameters;
_randomSweeper = new UniformRandomSweeper(new SweeperBase.ArgumentsBase(), _sweepParameters);
Expand Down
Loading

0 comments on commit 2d8a6ed

Please sign in to comment.