From ce683fcbecb436a29b9a8415cd3f23ce200f3fde Mon Sep 17 00:00:00 2001 From: darth-vader-lg <luigi.generale@gmail.com> Date: Sat, 29 May 2021 10:42:15 +0200 Subject: [PATCH] Added unit test for save/load TensorFlow SavedModel; issue (#5797). Signed-off-by: darth-vader-lg <luigi.generale@gmail.com> --- .../TensorflowTransform.cs | 52 ++++++------- .../TensorflowTests.cs | 73 +++++++++++++++++++ 2 files changed, 99 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index d3af9a799c..80b9638f77 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -463,32 +463,32 @@ private protected override void SaveModel(ModelSaveContext ctx) } } else { - ctx.SaveBinaryStream("TFSavedModel", w => - { - // only these files need to be saved. - var modelFilePaths = new List<string> - { - Path.Combine(_savedModelPath, DefaultModelFileNames.Graph), - Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index) - }; - modelFilePaths.AddRange(Directory.GetFiles(Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder), DefaultModelFileNames.Data, SearchOption.TopDirectoryOnly)); - - w.Write(modelFilePaths.Count); - - foreach (var fullPath in modelFilePaths) - { - var relativePath = fullPath.Substring(_savedModelPath.Length + 1); - w.Write(relativePath); - - using (var fs = new FileStream(fullPath, FileMode.Open)) - { - long fileLength = fs.Length; - w.Write(fileLength); - long actualWritten = fs.CopyRange(w.BaseStream, fileLength); - Host.Assert(actualWritten == fileLength); - } - } - }); + ctx.SaveBinaryStream("TFSavedModel", w => + { + // only these files need to be saved. + var modelFilePaths = new List<string> + { + Path.Combine(_savedModelPath, DefaultModelFileNames.Graph), + Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index) + }; + modelFilePaths.AddRange(Directory.GetFiles(Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder), DefaultModelFileNames.Data, SearchOption.TopDirectoryOnly)); + + w.Write(modelFilePaths.Count); + + foreach (var fullPath in modelFilePaths) + { + var relativePath = fullPath.Substring(_savedModelPath.Length + 1); + w.Write(relativePath); + + using (var fs = new FileStream(fullPath, FileMode.Open)) + { + long fileLength = fs.Length; + w.Write(fileLength); + long actualWritten = fs.CopyRange(w.BaseStream, fileLength); + Host.Assert(actualWritten == fileLength); + } + } + }); } Host.AssertNonEmpty(Inputs); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 3508cfe481..30861521f9 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1152,6 +1152,79 @@ public void TensorFlowGettingSchemaMultipleTimes() } } + // This test has been created as result of https://github.com/dotnet/machinelearning/issues/5797. + [TensorFlowFact] + public void TensorFlowSaveAndLoadSavedModel() + { + // Create the model and do some predictions + var imageHeight = 32; + var imageWidth = 32; + var modelLocation = "cifar_saved_model"; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + + var data = TextLoader.Create(_mlContext, new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("ImagePath", DataKind.String, 0), + new TextLoader.Column("Label", DataKind.String, 1), + } + }, new MultiFileSource(dataFile)); + + var pipeEstimator = new ImageLoadingEstimator(_mlContext, imageFolder, ("ImageReal", "ImagePath")) + .Append(new ImageResizingEstimator(_mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal")) + .Append(new ImagePixelExtractingEstimator(_mlContext, "Input", "ImageCropped", interleavePixelColors: true)) + .Append(_mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel("Output", "Input")) + .Append(new ColumnConcatenatingEstimator(_mlContext, "Features", "Output")) + .Append(new ValueToKeyMappingEstimator(_mlContext, "Label")) + .AppendCacheCheckpoint(_mlContext) + .Append(_mlContext.MulticlassClassification.Trainers.NaiveBayes()); + + + using var transformer = pipeEstimator.Fit(data); + var transformedData = transformer.Transform(data); + var outputSchema = transformer.GetOutputSchema(data.Schema); + + var metrics = _mlContext.MulticlassClassification.Evaluate(transformedData); + Assert.Equal(1, metrics.MicroAccuracy, 2); + + var predictFunction = _mlContext.Model.CreatePredictionEngine<CifarData, CifarPrediction>(transformer); + var predictions = new[] + { + predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") }), + predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/hotdog.jpg") }), + predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/tomato.jpg") }) + }; + + // Save the model as a standard ML.NET zip repo + var mlModelLocation = DeleteOutputPath(Path.ChangeExtension(modelLocation, ".zip")); + _mlContext.Model.Save(transformer, data.Schema, mlModelLocation); + transformer.Dispose(); + predictFunction.Dispose(); + + // Reload the model and check the output schema consistency + DataViewSchema loadedInputschema; + var testTransformer = _mlContext.Model.Load(mlModelLocation, out loadedInputschema); + var testOutputSchema = transformer.GetOutputSchema(data.Schema); + Assert.True(TestCommon.CheckSameSchemas(outputSchema, testOutputSchema)); + + // Repeat the predictions with the model loaded as zip repo + var testPredictFunction = _mlContext.Model.CreatePredictionEngine<CifarData, CifarPrediction>(testTransformer); + var testPredictions = new[] + { + testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") }), + testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/hotdog.jpg") }), + testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/tomato.jpg") }) + }; + + // Check the predictions consistency + for (var i = 0; i < predictions.Length; i++) { + for (var j = 0; j < predictions[i].PredictedScores.Length; j++) + Assert.Equal(predictions[i].PredictedScores[j], testPredictions[i].PredictedScores[j], 2); + } + } + [TensorFlowFact] public void TensorFlowTransformCifarInvalidShape() {