Onnx load model #5782

Merged: 9 commits, May 18, 2021
9 changes: 5 additions & 4 deletions src/Microsoft.ML.AutoML/API/ExperimentSettings.cs
@@ -42,13 +42,13 @@ public abstract class ExperimentSettings
public CancellationToken CancellationToken { get; set; }

/// <summary>
/// This is a pointer to a directory where all models trained during the AutoML experiment will be saved.
/// This is the name of the directory where all models trained during the AutoML experiment will be saved.
/// If <see langword="null"/>, models will be kept in memory instead of written to disk.
/// (Please note: for an experiment with high runtime operating on a large dataset, opting to keep models in
/// memory could cause a system to run out of memory.)
/// </summary>
/// <value>The default value is the directory named "Microsoft.ML.AutoML" in the current user's temporary folder.</value>
public DirectoryInfo CacheDirectory { get; set; }
/// <value>The default value is the directory named "Microsoft.ML.AutoML" in the in the location specified by the <see cref="MLContext.TempFilePath"/>.</value>
public string CacheDirectoryName { get; set; }

/// <summary>
/// Whether AutoML should cache before ML.NET trainers.
@@ -66,10 +66,11 @@ public ExperimentSettings()
{
MaxExperimentTimeInSeconds = 24 * 60 * 60;
CancellationToken = default;
CacheDirectory = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "Microsoft.ML.AutoML"));
CacheDirectoryName = "Microsoft.ML.AutoML";
CacheBeforeTrainer = CacheBeforeTrainer.Auto;
MaxModels = int.MaxValue;
}

}

/// <summary>
Expand Down
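For context, here is a minimal sketch of how the reworked settings fit together from caller code. The experiment type and target path are illustrative assumptions; only CacheDirectoryName and MLContext.TempFilePath come from this change:

using Microsoft.ML;
using Microsoft.ML.AutoML;

var mlContext = new MLContext(seed: 0);
// Hypothetical root for all ML.NET temp files, including AutoML's model cache.
mlContext.TempFilePath = @"D:\ml-scratch";

var settings = new RegressionExperimentSettings
{
    MaxExperimentTimeInSeconds = 60,
    // Now a plain directory name resolved under MLContext.TempFilePath,
    // instead of the old DirectoryInfo-valued CacheDirectory.
    CacheDirectoryName = "Microsoft.ML.AutoML"
};
var experiment = mlContext.Auto().CreateRegressionExperiment(settings);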
3 changes: 3 additions & 0 deletions src/Microsoft.ML.AutoML/API/RankingExperiment.cs
@@ -35,6 +35,9 @@ public sealed class RankingExperimentSettings : ExperimentSettings
/// </value>
public uint OptimizationMetricTruncationLevel { get; set; }

/// <summary>
/// Initializes a new instance of <see cref="RankingExperimentSettings"/>.
/// </summary>
public RankingExperimentSettings()
{
OptimizingMetric = RankingMetric.Ndcg;
3 changes: 3 additions & 0 deletions src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
@@ -31,6 +31,9 @@ public sealed class RegressionExperimentSettings : ExperimentSettings
/// </value>
public ICollection<RegressionTrainer> Trainers { get; }

/// <summary>
/// Initializes a new instance of <see cref="RegressionExperimentSettings"/>.
/// </summary>
public RegressionExperimentSettings()
{
OptimizingMetric = RegressionMetric.RSquared;
10 changes: 5 additions & 5 deletions src/Microsoft.ML.AutoML/Experiment/Experiment.cs
@@ -56,7 +56,7 @@ public Experiment(MLContext context,
_experimentSettings = experimentSettings;
_metricsAgent = metricsAgent;
_trainerAllowList = trainerAllowList;
_modelDirectory = GetModelDirectory(_experimentSettings.CacheDirectory);
_modelDirectory = GetModelDirectory(_context.TempFilePath, _experimentSettings.CacheDirectoryName);
_datasetColumnInfo = datasetColumnInfo;
_runner = runner;
_logger = logger;
@@ -140,7 +140,7 @@ public IList<TRunDetail> Execute()

// Pseudo-random number generator seeded with the main MLContext's seed, so that runs are deterministic
// while still maintaining variability between training iterations.
int? mainContextSeed = ((ISeededEnvironment)_context.Model.GetEnvironment()).Seed;
int? mainContextSeed = ((IHostEnvironmentInternal)_context.Model.GetEnvironment()).Seed;
_newContextSeedGenerator = (mainContextSeed.HasValue) ? RandomUtils.Create(mainContextSeed.Value) : null;

do
@@ -220,14 +220,14 @@ public IList<TRunDetail> Execute()
return iterationResults;
}

private static DirectoryInfo GetModelDirectory(DirectoryInfo rootDir)
private static DirectoryInfo GetModelDirectory(string tempDirectory, string cacheDirectoryName)
{
if (rootDir == null)
if (cacheDirectoryName == null)
{
return null;
}

var experimentDirFullPath = Path.Combine(rootDir.FullName, $"experiment_{Path.GetRandomFileName()}");
var experimentDirFullPath = Path.Combine(tempDirectory, cacheDirectoryName, $"experiment_{Path.GetRandomFileName()}");
var experimentDirInfo = new DirectoryInfo(experimentDirFullPath);
if (!experimentDirInfo.Exists)
{
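A quick sketch of the directory layout this composition produces under the default values (an illustration, not asserted by the diff):

using System;
using System.IO;

// <TempFilePath>/<CacheDirectoryName>/experiment_<random>
string tempFilePath = Path.GetTempPath();           // default MLContext.TempFilePath
string cacheDirectoryName = "Microsoft.ML.AutoML";  // default ExperimentSettings.CacheDirectoryName
string experimentDir = Path.Combine(tempFilePath, cacheDirectoryName,
    $"experiment_{Path.GetRandomFileName()}");
Directory.CreateDirectory(experimentDir);
Console.WriteLine(experimentDir); // e.g. /tmp/Microsoft.ML.AutoML/experiment_lcmgobc2.r3f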
7 changes: 6 additions & 1 deletion src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -83,12 +83,17 @@ internal interface ICancelable
}

[BestFriend]
internal interface ISeededEnvironment : IHostEnvironment
internal interface IHostEnvironmentInternal : IHostEnvironment
{
/// <summary>
/// The seed property that, if assigned, makes components requiring randomness behave deterministically.
/// </summary>
int? Seed { get; }

/// <summary>
/// The location for the temp files created by ML.NET.
/// </summary>
string TempFilePath { get; set; }
}

/// <summary>
Expand Down
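Call sites elsewhere in this PR reach the new property by downcasting an IHostEnvironment; a sketch of that recurring pattern (illustrative only, since the interface is [BestFriend] internal):

using Microsoft.ML.Runtime;

static string GetTempRoot(IHostEnvironment env)
{
    // Assumes env is one of ML.NET's HostEnvironmentBase-derived environments
    // (or MLContext), all of which implement IHostEnvironmentInternal.
    return ((IHostEnvironmentInternal)env).TempFilePath;
}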
6 changes: 5 additions & 1 deletion src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -93,7 +93,7 @@ internal interface IMessageSource
/// query progress.
/// </summary>
[BestFriend]
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, ISeededEnvironment, IChannelProvider, ICancelable
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironmentInternal, IChannelProvider, ICancelable
where TEnv : HostEnvironmentBase<TEnv>
{
void ICancelable.CancelExecution()
@@ -326,6 +326,10 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
}
}

#pragma warning disable MSML_NoInstanceInitializers // Need this to have a default value in case the user doesn't set it.
public string TempFilePath { get; set; } = System.IO.Path.GetTempPath();
#pragma warning restore MSML_NoInstanceInitializers

protected readonly TEnv Root;
// This is non-null iff this environment was a fork of another. Disposing a fork
// doesn't free temp files. That is handled when the master is disposed.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
@@ -188,7 +188,7 @@ private void RunCore(IChannel ch, string cmd)

ILegacyDataLoader testPipe;
bool hasOutfile = !string.IsNullOrEmpty(ImplOptions.OutputModelFile);
var tempFilePath = hasOutfile ? null : Path.GetTempFileName();
var tempFilePath = hasOutfile ? null : Path.Combine(((IHostEnvironmentInternal)Host).TempFilePath, Path.GetRandomFileName());

using (var file = new SimpleFileHandle(ch, hasOutfile ? ImplOptions.OutputModelFile : tempFilePath, true, !hasOutfile))
{
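The behavioral difference is worth noting: Path.GetTempFileName() eagerly creates a zero-byte file, always under the fixed system temp folder, while the replacement only composes a path under the configurable root. A standalone sketch:

using System;
using System.IO;

// Old approach: the file is created immediately in the system temp dir.
string eager = Path.GetTempFileName();
Console.WriteLine(File.Exists(eager)); // True

// New approach: only a name is composed (Path.GetTempPath() stands in for the
// environment's TempFilePath here); nothing exists until something writes it.
string lazy = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
Console.WriteLine(File.Exists(lazy));  // False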
@@ -534,7 +534,7 @@ internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView dat
}
else if(fallbackInEnvSeed)
{
ISeededEnvironment seededEnv = (ISeededEnvironment)env;
IHostEnvironmentInternal seededEnv = (IHostEnvironmentInternal)env;
seedToUse = seededEnv.Seed;
}
else
13 changes: 11 additions & 2 deletions src/Microsoft.ML.Data/MLContext.cs
@@ -14,7 +14,7 @@ namespace Microsoft.ML
/// create components for data preparation, feature engineering, training, prediction, and model evaluation.
/// It also allows logging, execution control, and the ability to set repeatable random numbers.
/// </summary>
public sealed class MLContext : ISeededEnvironment
public sealed class MLContext : IHostEnvironmentInternal
{
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation.
private readonly LocalEnvironment _env;
@@ -79,6 +79,15 @@ public sealed class MLContext : ISeededEnvironment
/// </summary>
public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;

/// <summary>
/// Gets or sets the location for the temp files created by ML.NET.
/// </summary>
public string TempFilePath
{
get { return _env.TempFilePath; }
set { _env.TempFilePath = value; }
}

/// <summary>
/// Create the ML context.
/// </summary>
@@ -140,7 +149,7 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
int? ISeededEnvironment.Seed => _env.Seed;
int? IHostEnvironmentInternal.Seed => _env.Seed;

[BestFriend]
internal void CancelExecution() => ((ICancelable)_env).CancelExecution();
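A short usage sketch of the new public property (the custom path is hypothetical):

using System;
using Microsoft.ML;

var mlContext = new MLContext();
Console.WriteLine(mlContext.TempFilePath); // defaults to Path.GetTempPath()

// Redirect every temp file ML.NET creates (cached models, temporary model copies, ...).
mlContext.TempFilePath = "/data/mlnet-temp";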
4 changes: 2 additions & 2 deletions src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
@@ -227,7 +227,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
{
// Entering this region means that the byte[] is passed as the model. To feed that byte[] to ONNXRuntime, we need
// to create a temporary file to store it and then call ONNXRuntime's API to load that file.
Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu, shapeDictionary: shapeDictionary);
Model = OnnxModel.CreateFromBytes(modelBytes, env, options.GpuDeviceId, options.FallbackToCpu, shapeDictionary: shapeDictionary);
}
}
catch (OnnxRuntimeException e)
@@ -304,7 +304,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(File.ReadAllBytes(Model.ModelFile)); });
ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(File.ReadAllBytes(Model.ModelStream.Name)); });

Host.CheckNonEmpty(Inputs, nameof(Inputs));
ctx.Writer.Write(Inputs.Length);
47 changes: 23 additions & 24 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
@@ -143,14 +143,9 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, Da
/// </summary>
private readonly InferenceSession _session;
/// <summary>
/// Indicates if <see cref="ModelFile"/> is a temporal file created by <see cref="CreateFromBytes(byte[], int?, bool, IDictionary{string, int[]})"/>
/// or <see cref="CreateFromBytes(byte[])"/>. If <see langword="true"/>, <see cref="Dispose(bool)"/> should delete <see cref="ModelFile"/>.
/// The FileStream holding onto the loaded ONNX model.
/// </summary>
private bool _ownModelFile;
/// <summary>
/// The location where the used ONNX model loaded from.
/// </summary>
internal string ModelFile { get; }
internal FileStream ModelStream { get; }
/// <summary>
/// The ONNX model's information from ONNXRuntime's perspective. ML.NET can change the input and output of that model in some ways.
/// For example, ML.NET can shuffle the inputs so that the i-th ONNX input becomes the j-th input column of <see cref="OnnxTransformer"/>.
@@ -172,9 +167,7 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, Da
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
bool ownModelFile=false, IDictionary<string, int[]> shapeDictionary = null)
{
ModelFile = modelFile;
// If we don't own the model file, _disposed should be false to prevent deleting user's file.
_ownModelFile = ownModelFile;
_disposed = false;

if (gpuDeviceId != null)
@@ -202,9 +195,15 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
{
// Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
// doesn't expose full type information via its C# APIs.
ModelFile = modelFile;
var model = new OnnxCSharpToProtoWrapper.ModelProto();
using (var modelStream = File.OpenRead(modelFile))
// If we own the model file, set the DeleteOnClose flag so it is always deleted.
if (ownModelFile)
ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, FileOptions.DeleteOnClose);
else
ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read);

// The CodedInputStream auto-closes the stream, and we need to make sure that our main stream stays open, so we create a new one here.
using (var modelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.Read))
using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
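The FileOptions.DeleteOnClose flag ties the temporary model file's lifetime to the stream handle, which is why Dispose below no longer deletes a file explicitly. A standalone sketch of those semantics:

using System;
using System.IO;

string path = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
File.WriteAllBytes(path, new byte[] { 0x08, 0x01 });

// FileShare.Delete | FileShare.Read lets a second reader (like the protobuf
// parse above) open the file while this delete-on-close handle is alive.
using (var owner = new FileStream(path, FileMode.Open, FileAccess.Read,
    FileShare.Delete | FileShare.Read, 4096, FileOptions.DeleteOnClose))
{
    Console.WriteLine(File.Exists(path)); // True: still present while the handle is open
}
Console.WriteLine(File.Exists(path));     // False: removed once the handle closed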

Expand Down Expand Up @@ -322,33 +321,35 @@ private static bool CheckOnnxShapeCompatibility(IEnumerable<int> left, IEnumerab

/// <summary>
/// Create an OnnxModel from a byte[]. Usually, an ONNX model is consumed by <see cref="OnnxModel"/> as a file.
/// With <see cref="CreateFromBytes(byte[])"/> and <see cref="CreateFromBytes(byte[], int?, bool, IDictionary{string, int[]})"/>,
/// With <see cref="CreateFromBytes(byte[], IHostEnvironment)"/> and <see cref="CreateFromBytes(byte[], IHostEnvironment, int?, bool, IDictionary{string, int[]})"/>,
/// it's possible to use an in-memory model (type: byte[]) to create <see cref="OnnxModel"/>.
/// </summary>
/// <param name="modelBytes">Bytes of the serialized model</param>
public static OnnxModel CreateFromBytes(byte[] modelBytes)
/// <param name="env">IHostEnvironment</param>
public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env)
{
return CreateFromBytes(modelBytes, null, false);
return CreateFromBytes(modelBytes, env, null, false);
}

/// <summary>
/// Create an OnnxModel from a byte[]. Set execution to GPU if required.
/// Usually, an ONNX model is consumed by <see cref="OnnxModel"/> as a file.
/// With <see cref="CreateFromBytes(byte[])"/> and
/// <see cref="CreateFromBytes(byte[], int?, bool, IDictionary{string, int[]})"/>,
/// With <see cref="CreateFromBytes(byte[], IHostEnvironment)"/> and
/// <see cref="CreateFromBytes(byte[], IHostEnvironment, int?, bool, IDictionary{string, int[]})"/>,
/// it's possible to use an in-memory model (type: byte[]) to create <see cref="OnnxModel"/>.
/// </summary>
/// <param name="modelBytes">Bytes of the serialized model.</param>
/// <param name="env">IHostEnvironment</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
/// <param name="shapeDictionary">User-provided shapes. If the key "myTensorName" is associated
/// with the value [1, 3, 5], the shape of "myTensorName" will be set to [1, 3, 5].
/// The shape loaded from <paramref name="modelBytes"/> would be overwritten.</param>
/// <returns>An <see cref="OnnxModel"/></returns>
public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = null, bool fallbackToCpu = false,
public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env, int? gpuDeviceId = null, bool fallbackToCpu = false,
IDictionary<string, int[]> shapeDictionary = null)
{
var tempModelFile = Path.GetTempFileName();
var tempModelFile = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, Path.GetRandomFileName());
File.WriteAllBytes(tempModelFile, modelBytes);
return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu,
ownModelFile: true, shapeDictionary: shapeDictionary);
@@ -366,7 +367,7 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(List<NamedOnn
}

/// <summary>
/// Flag used to indicate if the unmanaged resources (aka the model file <see cref="ModelFile"/>
/// Flag used to indicate if the unmanaged resources (aka the model file handle <see cref="ModelStream"/>
/// and <see cref="_session"/>) have been deleted.
/// </summary>
private bool _disposed;
@@ -378,8 +379,7 @@ public void Dispose()
}

/// <summary>
/// There are two unmanaged resources we can dispose, <see cref="_session"/> and <see cref="ModelFile"/>
/// if <see cref="_ownModelFile"/> is <see langword="true"/>.
/// There are two unmanaged resources we can dispose, <see cref="_session"/> and <see cref="ModelStream"/>.
/// </summary>
/// <param name="disposing"></param>
private void Dispose(bool disposing)
@@ -391,9 +391,8 @@ private void Dispose(bool disposing)
{
// First, we release the resource token by ONNXRuntime.
_session.Dispose();
// Second, we delete the model file if that file is not created by the user.
if (_ownModelFile && File.Exists(ModelFile))
File.Delete(ModelFile);
// Second, dispose of the model file stream.
ModelStream.Dispose();
}
_disposed = true;
}
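OnnxModel is internal, so only code inside the repo threads the environment through; an illustrative call under the new signature (the model file name is hypothetical):

using System.IO;
using Microsoft.ML;
using Microsoft.ML.Transforms.Onnx;

byte[] modelBytes = File.ReadAllBytes("model.onnx");
var mlContext = new MLContext();

// MLContext implements IHostEnvironmentInternal, so its TempFilePath decides
// where the temporary on-disk copy of these bytes is written.
var onnxModel = OnnxModel.CreateFromBytes(modelBytes, mlContext,
    gpuDeviceId: null, fallbackToCpu: false);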
2 changes: 1 addition & 1 deletion src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -143,7 +143,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
return new TensorFlowTransformer(env, LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched);
}

var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(TensorFlowTransformer) + "_" + Guid.NewGuid()));
var tempDirPath = Path.GetFullPath(Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, nameof(TensorFlowTransformer) + "_" + Guid.NewGuid()));
CreateFolderWithAclIfNotExists(env, tempDirPath);
try
{
4 changes: 2 additions & 2 deletions src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
@@ -630,9 +630,9 @@ public void Dispose()
}
}

internal static string GetTemporaryDirectory()
internal static string GetTemporaryDirectory(IHostEnvironment env)
{
string tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
string tempDirectory = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, Path.GetRandomFileName());
Directory.CreateDirectory(tempDirectory);
return tempDirectory;
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Vision/DnnRetrainTransform.cs
@@ -116,7 +116,7 @@ private static DnnRetrainTransformer Create(IHostEnvironment env, ModelLoadConte
null, false, addBatchDimensionInput, 1);
}

var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid()));
var tempDirPath = Path.GetFullPath(Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid()));
CreateFolderWithAclIfNotExists(env, tempDirPath);
try
{