Skip to content

Commit

Permalink
Added unit test for save/load TensorFlow SavedModel; issue (#5797).
Browse files Browse the repository at this point in the history
Signed-off-by: darth-vader-lg <luigi.generale@gmail.com>
  • Loading branch information
darth-vader-lg committed May 29, 2021
1 parent 2ee9aa4 commit ce683fc
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 26 deletions.
52 changes: 26 additions & 26 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -463,32 +463,32 @@ private protected override void SaveModel(ModelSaveContext ctx)
}
}
else {
ctx.SaveBinaryStream("TFSavedModel", w =>
{
// only these files need to be saved.
var modelFilePaths = new List<string>
{
Path.Combine(_savedModelPath, DefaultModelFileNames.Graph),
Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index)
};
modelFilePaths.AddRange(Directory.GetFiles(Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder), DefaultModelFileNames.Data, SearchOption.TopDirectoryOnly));

w.Write(modelFilePaths.Count);

foreach (var fullPath in modelFilePaths)
{
var relativePath = fullPath.Substring(_savedModelPath.Length + 1);
w.Write(relativePath);

using (var fs = new FileStream(fullPath, FileMode.Open))
{
long fileLength = fs.Length;
w.Write(fileLength);
long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
Host.Assert(actualWritten == fileLength);
}
}
});
ctx.SaveBinaryStream("TFSavedModel", w =>
{
// only these files need to be saved.
var modelFilePaths = new List<string>
{
Path.Combine(_savedModelPath, DefaultModelFileNames.Graph),
Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index)
};
modelFilePaths.AddRange(Directory.GetFiles(Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder), DefaultModelFileNames.Data, SearchOption.TopDirectoryOnly));

w.Write(modelFilePaths.Count);

foreach (var fullPath in modelFilePaths)
{
var relativePath = fullPath.Substring(_savedModelPath.Length + 1);
w.Write(relativePath);

using (var fs = new FileStream(fullPath, FileMode.Open))
{
long fileLength = fs.Length;
w.Write(fileLength);
long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
Host.Assert(actualWritten == fileLength);
}
}
});
}

Host.AssertNonEmpty(Inputs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,79 @@ public void TensorFlowGettingSchemaMultipleTimes()
}
}

// This test has been created as result of https://github.com/dotnet/machinelearning/issues/5797.
[TensorFlowFact]
public void TensorFlowSaveAndLoadSavedModel()
{
// Create the model and do some predictions
var imageHeight = 32;
var imageWidth = 32;
var modelLocation = "cifar_saved_model";
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);

var data = TextLoader.Create(_mlContext, new TextLoader.Options()
{
Columns = new[]
{
new TextLoader.Column("ImagePath", DataKind.String, 0),
new TextLoader.Column("Label", DataKind.String, 1),
}
}, new MultiFileSource(dataFile));

var pipeEstimator = new ImageLoadingEstimator(_mlContext, imageFolder, ("ImageReal", "ImagePath"))
.Append(new ImageResizingEstimator(_mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal"))
.Append(new ImagePixelExtractingEstimator(_mlContext, "Input", "ImageCropped", interleavePixelColors: true))
.Append(_mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel("Output", "Input"))
.Append(new ColumnConcatenatingEstimator(_mlContext, "Features", "Output"))
.Append(new ValueToKeyMappingEstimator(_mlContext, "Label"))
.AppendCacheCheckpoint(_mlContext)
.Append(_mlContext.MulticlassClassification.Trainers.NaiveBayes());


using var transformer = pipeEstimator.Fit(data);
var transformedData = transformer.Transform(data);
var outputSchema = transformer.GetOutputSchema(data.Schema);

var metrics = _mlContext.MulticlassClassification.Evaluate(transformedData);
Assert.Equal(1, metrics.MicroAccuracy, 2);

var predictFunction = _mlContext.Model.CreatePredictionEngine<CifarData, CifarPrediction>(transformer);
var predictions = new[]
{
predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") }),
predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/hotdog.jpg") }),
predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/tomato.jpg") })
};

// Save the model as a standard ML.NET zip repo
var mlModelLocation = DeleteOutputPath(Path.ChangeExtension(modelLocation, ".zip"));
_mlContext.Model.Save(transformer, data.Schema, mlModelLocation);
transformer.Dispose();
predictFunction.Dispose();

// Reload the model and check the output schema consistency
DataViewSchema loadedInputschema;
var testTransformer = _mlContext.Model.Load(mlModelLocation, out loadedInputschema);
var testOutputSchema = transformer.GetOutputSchema(data.Schema);
Assert.True(TestCommon.CheckSameSchemas(outputSchema, testOutputSchema));

// Repeat the predictions with the model loaded as zip repo
var testPredictFunction = _mlContext.Model.CreatePredictionEngine<CifarData, CifarPrediction>(testTransformer);
var testPredictions = new[]
{
testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") }),
testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/hotdog.jpg") }),
testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/tomato.jpg") })
};

// Check the predictions consistency
for (var i = 0; i < predictions.Length; i++) {
for (var j = 0; j < predictions[i].PredictedScores.Length; j++)
Assert.Equal(predictions[i].PredictedScores[j], testPredictions[i].PredictedScores[j], 2);
}
}

[TensorFlowFact]
public void TensorFlowTransformCifarInvalidShape()
{
Expand Down

0 comments on commit ce683fc

Please sign in to comment.