Skip to content

Commit

Permalink
Tensorflow fixes from PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelgsharp committed May 12, 2021
1 parent ce99388 commit 35ef427
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Vision/DnnRetrainTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ private static DnnRetrainTransformer Create(IHostEnvironment env, ModelLoadConte
null, false, addBatchDimensionInput, 1);
}

var tempDirPath = Path.GetFullPath(Path.Combine((env as IHostEnvironmentInternal).TempFilePath, nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid()));
var tempDirPath = Path.GetFullPath(Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid()));
CreateFolderWithAclIfNotExists(env, tempDirPath);
try
{
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
Host.CheckNonEmpty(options.ScoreColumnName, nameof(options.ScoreColumnName));
Host.CheckNonEmpty(options.PredictedLabelColumnName, nameof(options.PredictedLabelColumnName));
tf.compat.v1.disable_eager_execution();
_resourcePath = Path.Combine((env as IHostEnvironmentInternal).TempFilePath, "MLNET");
_resourcePath = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, "MLNET");

if (string.IsNullOrEmpty(options.WorkspacePath))
{
Expand Down Expand Up @@ -1320,12 +1320,12 @@ private void AddTransferLearningLayer(string labelColumn,

}

private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch)
private TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch)
{
var modelFileName = ModelFileName[arch];
var modelFilePath = Path.Combine((env as IHostEnvironmentInternal).TempFilePath, "MLNET", modelFileName);
var modelFilePath = Path.Combine(_resourcePath, "MLNET", modelFileName);
int timeout = 10 * 60 * 1000;
DownloadIfNeeded(env, @"meta\" + modelFileName, Path.Combine((env as IHostEnvironmentInternal).TempFilePath, "MLNET"), modelFileName, timeout);
DownloadIfNeeded(env, @"meta\" + modelFileName, Path.Combine(_resourcePath, "MLNET"), modelFileName, timeout);
return new TensorFlowSessionWrapper(GetSession(env, modelFilePath, true), modelFilePath);
}

Expand Down

0 comments on commit 35ef427

Please sign in to comment.