
Commit

Use Timer and ctx.CancelExecution() to fix AutoML max-time experiment bug (#5445)

* Use ctx.CancelExecution() to fix AutoML max-time experiment bug

* Added unit test for checking canceled experiment

* Nit fix

* Different run time on Linux

* Review

* Testing for output

* Used reflection to test for contexts being canceled

* Reviews

* Reviews

* Added main MLContext listener-timer

* Added PRNG on _context, held onto timers for avoiding GC

* Addressed reviews

* Unit test edits

* Increase run time of experiment to guarantee probabilities

* Edited unit test to check produced schema of next run model's predictions

* Remove schema check as different CI builds result in varying schemas

* Decrease max experiment time unit test time

* Added Timers

* Increase second timer time, edit unit test

* Added try catch for OperationCanceledException in Execute()

* Add AggregateException try catch to slow unit tests for parallel testing

* Reviews

* Final reviews

* Added LightGBMFact to binary classification test

* Removed extra Operation Stopped exception try catch

* Add back OperationCanceledException to Experiment.cs
mstfbl authored Nov 3, 2020
1 parent a0e959c commit 600d48d
Showing 5 changed files with 215 additions and 74 deletions.
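
The commit title and messages above describe the core of the fix: a System.Threading.Timer fires once the configured MaxExperimentTimeInSeconds has elapsed and cancels the MLContext that is training the current model. The following is a minimal sketch of that pattern, not the actual Experiment.cs code; the class and member names are illustrative, and CancelExecution() comes from the internal ICancelable interface shown at the bottom of this diff, so outside callers may need reflection (the unit test mentioned above uses reflection to check for canceled contexts).

using System.Threading;
using Microsoft.ML;

// Minimal sketch (illustrative names): a one-shot Timer cancels the MLContext
// used for the in-flight model run once the allotted experiment time elapses.
internal class MaxTimeCancellationSketch
{
    private MLContext _currentModelMLContext;  // fresh context per model run
    private Timer _maxExperimentTimeTimer;     // held in a field so it is not garbage collected

    public void StartRun(uint maxExperimentTimeInSeconds)
    {
        _currentModelMLContext = new MLContext();
        _maxExperimentTimeTimer = new Timer(
            _ => _currentModelMLContext.CancelExecution(),  // stop the ongoing training
            null,
            maxExperimentTimeInSeconds * 1000,              // fire once after the limit
            Timeout.Infinite);                              // do not restart periodically
    }
}

In the actual change below, a second periodic timer also watches the user's main MLContext and forwards its cancellation to the per-run context, and the resulting OperationCanceledException is caught so Execute() can return the iterations that already finished.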
152 changes: 112 additions & 40 deletions src/Microsoft.ML.AutoML/Experiment/Experiment.cs
@@ -7,6 +7,8 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.AutoML
@@ -25,6 +27,11 @@ internal class Experiment<TRunDetail, TMetrics> where TRunDetail : RunDetail
private readonly IRunner<TRunDetail> _runner;
private readonly IList<SuggestedPipelineRunDetail> _history;
private readonly IChannel _logger;
private Timer _maxExperimentTimeTimer;
private Timer _mainContextCanceledTimer;
private bool _experimentTimerExpired;
private MLContext _currentModelMLContext;
private Random _newContextSeedGenerator;

public Experiment(MLContext context,
TaskKind task,
@@ -49,60 +56,125 @@ public Experiment(MLContext context,
_datasetColumnInfo = datasetColumnInfo;
_runner = runner;
_logger = logger;
_experimentTimerExpired = false;
}

private void MaxExperimentTimeExpiredEvent(object state)
{
// If at least one model was run, end experiment immediately.
// Else, wait for first model to run before experiment is concluded.
_experimentTimerExpired = true;
if (_history.Any(r => r.RunSucceeded))
{
_logger.Warning("Allocated time for Experiment of {0} seconds has elapsed with {1} models run. Ending experiment...",
_experimentSettings.MaxExperimentTimeInSeconds, _history.Count());
_currentModelMLContext.CancelExecution();
}
}

private void MainContextCanceledEvent(object state)
{
// If the main MLContext is canceled, cancel the ongoing model training and MLContext.
if ((_context.Model.GetEnvironment() as ICancelable).IsCanceled)
{
_logger.Warning("Main MLContext has been canceled. Ending experiment...");
// Stop the timer to prevent it from restarting and repeatedly calling
// MainContextCanceledEvent
_mainContextCanceledTimer.Change(Timeout.Infinite, Timeout.Infinite);
_currentModelMLContext.CancelExecution();
}
}

public IList<TRunDetail> Execute()
{
var stopwatch = Stopwatch.StartNew();
var iterationResults = new List<TRunDetail>();
// Create a timer for the max duration of experiment. When given time has
// elapsed, MaxExperimentTimeExpiredEvent is called to interrupt training
// of current model. Timer is not used if no experiment time is given, or
// is not a positive number.
if (_experimentSettings.MaxExperimentTimeInSeconds > 0)
{
_maxExperimentTimeTimer = new Timer(
new TimerCallback(MaxExperimentTimeExpiredEvent), null,
_experimentSettings.MaxExperimentTimeInSeconds * 1000, Timeout.Infinite
);
}
// If the given max experiment duration is 0, only one model will be trained.
// _experimentSettings.MaxExperimentTimeInSeconds is of type uint, so it is
// either 0 or positive.
else
_experimentTimerExpired = true;

// Add a second timer to check for the cancellation signal from the main MLContext
// to the active child MLContext. This timer will propagate the cancellation
// signal from the main to the child MLContexts if the main MLContext is
// canceled.
_mainContextCanceledTimer = new Timer(new TimerCallback(MainContextCanceledEvent), null, 1000, 1000);

// Pseudo-random number generator seeded from the provided main MLContext's seed, so runs
// are deterministic while still varying between training iterations.
int? mainContextSeed = ((ISeededEnvironment)_context.Model.GetEnvironment()).Seed;
_newContextSeedGenerator = (mainContextSeed.HasValue) ? RandomUtils.Create(mainContextSeed.Value) : null;

do
{
var iterationStopwatch = Stopwatch.StartNew();

// get next pipeline
var getPipelineStopwatch = Stopwatch.StartNew();
var pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, _datasetColumnInfo, _task,
_optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _logger, _trainerAllowList);

var pipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;

// break if no candidates returned, means no valid pipeline available
if (pipeline == null)
{
break;
}

// evaluate pipeline
_logger.Trace($"Evaluating pipeline {pipeline.ToString()}");
(SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail)
= _runner.Run(pipeline, _modelDirectory, _history.Count + 1);

_history.Add(suggestedPipelineRunDetail);
WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch);

runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds;
runDetail.PipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;

ReportProgress(runDetail);
iterationResults.Add(runDetail);

// if model is perfect, break
if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score))
try
{
break;
var iterationStopwatch = Stopwatch.StartNew();

// get next pipeline
var getPipelineStopwatch = Stopwatch.StartNew();

// A new MLContext is needed per model run. When max experiment time is reached, each used
// context is canceled to stop further model training. The cancellation of the main MLContext
// a user has instantiated is not desirable, thus additional MLContexts are used.
_currentModelMLContext = _newContextSeedGenerator == null ? new MLContext() : new MLContext(_newContextSeedGenerator.Next());
var pipeline = PipelineSuggester.GetNextInferredPipeline(_currentModelMLContext, _history, _datasetColumnInfo, _task,
_optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _logger, _trainerAllowList);
// break if no candidates returned, means no valid pipeline available
if (pipeline == null)
{
break;
}

// evaluate pipeline
_logger.Trace($"Evaluating pipeline {pipeline.ToString()}");
(SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail)
= _runner.Run(pipeline, _modelDirectory, _history.Count + 1);

_history.Add(suggestedPipelineRunDetail);
WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch);

runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds;
runDetail.PipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;

ReportProgress(runDetail);
iterationResults.Add(runDetail);

// if model is perfect, break
if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score))
{
break;
}

// If after third run, all runs have failed so far, throw exception
if (_history.Count() == 3 && _history.All(r => !r.RunSucceeded))
{
throw new InvalidOperationException($"Training failed with the exception: {_history.Last().Exception}");
}
}

// If after third run, all runs have failed so far, throw exception
if (_history.Count() == 3 && _history.All(r => !r.RunSucceeded))
catch (OperationCanceledException e)
{
throw new InvalidOperationException($"Training failed with the exception: {_history.Last().Exception}");
// This exception is thrown when the IHost/MLContext of the trainer is canceled due to
// reaching maximum experiment time. Simply catch this exception and return finished
// iteration results.
_logger.Warning("OperationCanceledException has been caught after maximum experiment time" +
"was reached, and the running MLContext was stopped. Details: {0}", e.Message);
return iterationResults;
}

} while (_history.Count < _experimentSettings.MaxModels &&
!_experimentSettings.CancellationToken.IsCancellationRequested &&
stopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxExperimentTimeInSeconds);

!_experimentTimerExpired);
return iterationResults;
}
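
Stepping outside the diff for a moment: from a caller's point of view, the effect of the rewritten Execute() loop is that an experiment configured with a maximum time now stops shortly after that time elapses and returns the runs that already completed (at least one model is always trained, even if it overshoots the limit). A hypothetical usage sketch that is not part of this commit (ModelInput, the data file name, and the 60-second limit are made up for illustration):

using Microsoft.ML;
using Microsoft.ML.AutoML;
using Microsoft.ML.Data;

public class ModelInput
{
    [LoadColumn(0)] public bool Label { get; set; }
    [LoadColumn(1)] public float Feature1 { get; set; }
}

public static class MaxTimeExperimentSketch
{
    public static void Run()
    {
        var mlContext = new MLContext(seed: 1);
        var data = mlContext.Data.LoadFromTextFile<ModelInput>(
            "data.csv", separatorChar: ',', hasHeader: true);

        var settings = new BinaryExperimentSettings
        {
            MaxExperimentTimeInSeconds = 60  // the timer cancels the in-flight run once this elapses
        };

        // Execute() returns only the iterations that finished before cancellation.
        var result = mlContext.Auto()
            .CreateBinaryClassificationExperiment(settings)
            .Execute(data, labelColumnName: "Label");
    }
}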

@@ -58,7 +58,7 @@ public CrossValSummaryRunner(MLContext context,
for (var i = 0; i < _trainDatasets.Length; i++)
{
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
var trainResult = RunnerUtil.TrainAndScorePipeline(pipeline.GetContext(), pipeline, _trainDatasets[i], _validDatasets[i],
_groupIdColumn, _labelColumn, _metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema,
_logger);
trainResults.Add(trainResult);
5 changes: 5 additions & 0 deletions src/Microsoft.ML.AutoML/Experiment/SuggestedPipeline.cs
@@ -52,6 +52,11 @@ public override int GetHashCode()
return ToString().GetHashCode();
}

public MLContext GetContext()
{
return _context;
}

public Pipeline ToPipeline()
{
var pipelineElements = new List<PipelineNode>();
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -72,7 +72,7 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
internal interface ICancelable
{
/// <summary>
/// Signal to stop exection in all the hosts.
/// Signal to stop execution in all the hosts.
/// </summary>
void CancelExecution();
