From ce683fcbecb436a29b9a8415cd3f23ce200f3fde Mon Sep 17 00:00:00 2001
From: darth-vader-lg <luigi.generale@gmail.com>
Date: Sat, 29 May 2021 10:42:15 +0200
Subject: [PATCH] Added unit test for save/load TensorFlow SavedModel; issue
 (#5797).

Signed-off-by: darth-vader-lg <luigi.generale@gmail.com>
---
 .../TensorflowTransform.cs                    | 52 ++++++-------
 .../TensorflowTests.cs                        | 73 +++++++++++++++++++
 2 files changed, 99 insertions(+), 26 deletions(-)

diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index d3af9a799c..80b9638f77 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -463,32 +463,32 @@ private protected override void SaveModel(ModelSaveContext ctx)
                 }
             }
             else {
-               ctx.SaveBinaryStream("TFSavedModel", w =>
-               {
-                   // only these files need to be saved.
-                   var modelFilePaths = new List<string>
-                   {
-                       Path.Combine(_savedModelPath, DefaultModelFileNames.Graph),
-                       Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index)
-                   };
-                   modelFilePaths.AddRange(Directory.GetFiles(Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder), DefaultModelFileNames.Data, SearchOption.TopDirectoryOnly));
-
-                   w.Write(modelFilePaths.Count);
-
-                   foreach (var fullPath in modelFilePaths)
-                   {
-                       var relativePath = fullPath.Substring(_savedModelPath.Length + 1);
-                       w.Write(relativePath);
-
-                       using (var fs = new FileStream(fullPath, FileMode.Open))
-                       {
-                           long fileLength = fs.Length;
-                           w.Write(fileLength);
-                           long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
-                           Host.Assert(actualWritten == fileLength);
-                       }
-                   }
-               });
+                ctx.SaveBinaryStream("TFSavedModel", w =>
+                {
+                    // only these files need to be saved.
+                    var modelFilePaths = new List<string>
+                    {
+                        Path.Combine(_savedModelPath, DefaultModelFileNames.Graph),
+                        Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index)
+                    };
+                    modelFilePaths.AddRange(Directory.GetFiles(Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder), DefaultModelFileNames.Data, SearchOption.TopDirectoryOnly));
+
+                    w.Write(modelFilePaths.Count);
+
+                    foreach (var fullPath in modelFilePaths)
+                    {
+                        var relativePath = fullPath.Substring(_savedModelPath.Length + 1);
+                        w.Write(relativePath);
+
+                        using (var fs = new FileStream(fullPath, FileMode.Open))
+                        {
+                            long fileLength = fs.Length;
+                            w.Write(fileLength);
+                            long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
+                            Host.Assert(actualWritten == fileLength);
+                        }
+                    }
+                });
             }
 
             Host.AssertNonEmpty(Inputs);
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 3508cfe481..30861521f9 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -1152,6 +1152,79 @@ public void TensorFlowGettingSchemaMultipleTimes()
             }
         }
 
+        // This test has been created as result of https://github.com/dotnet/machinelearning/issues/5797.
+        [TensorFlowFact]
+        public void TensorFlowSaveAndLoadSavedModel()
+        {
+            // Create the model and do some predictions
+            var imageHeight = 32;
+            var imageWidth = 32;
+            var modelLocation = "cifar_saved_model";
+            var dataFile = GetDataPath("images/images.tsv");
+            var imageFolder = Path.GetDirectoryName(dataFile);
+
+            var data = TextLoader.Create(_mlContext, new TextLoader.Options()
+            {
+                Columns = new[]
+                    {
+                        new TextLoader.Column("ImagePath", DataKind.String, 0),
+                        new TextLoader.Column("Label", DataKind.String, 1),
+                    }
+            }, new MultiFileSource(dataFile));
+
+            var pipeEstimator = new ImageLoadingEstimator(_mlContext, imageFolder, ("ImageReal", "ImagePath"))
+                    .Append(new ImageResizingEstimator(_mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal"))
+                    .Append(new ImagePixelExtractingEstimator(_mlContext, "Input", "ImageCropped", interleavePixelColors: true))
+                    .Append(_mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel("Output", "Input"))
+                    .Append(new ColumnConcatenatingEstimator(_mlContext, "Features", "Output"))
+                    .Append(new ValueToKeyMappingEstimator(_mlContext, "Label"))
+                    .AppendCacheCheckpoint(_mlContext)
+                    .Append(_mlContext.MulticlassClassification.Trainers.NaiveBayes());
+
+
+            using var transformer = pipeEstimator.Fit(data);
+            var transformedData = transformer.Transform(data);
+            var outputSchema = transformer.GetOutputSchema(data.Schema);
+
+            var metrics = _mlContext.MulticlassClassification.Evaluate(transformedData);
+            Assert.Equal(1, metrics.MicroAccuracy, 2);
+
+            var predictFunction = _mlContext.Model.CreatePredictionEngine<CifarData, CifarPrediction>(transformer);
+            var predictions = new[]
+            {
+                predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") }),
+                predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/hotdog.jpg") }),
+                predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/tomato.jpg") })
+            };
+
+            // Save the model as a standard ML.NET zip repo
+            var mlModelLocation = DeleteOutputPath(Path.ChangeExtension(modelLocation, ".zip"));
+            _mlContext.Model.Save(transformer, data.Schema, mlModelLocation);
+            transformer.Dispose();
+            predictFunction.Dispose();
+
+            // Reload the model and check the output schema consistency
+            DataViewSchema loadedInputschema;
+            var testTransformer = _mlContext.Model.Load(mlModelLocation, out loadedInputschema);
+            var testOutputSchema = transformer.GetOutputSchema(data.Schema);
+            Assert.True(TestCommon.CheckSameSchemas(outputSchema, testOutputSchema));
+
+            // Repeat the predictions with the model loaded as zip repo
+            var testPredictFunction = _mlContext.Model.CreatePredictionEngine<CifarData, CifarPrediction>(testTransformer);
+            var testPredictions = new[]
+            {
+                testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") }),
+                testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/hotdog.jpg") }),
+                testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/tomato.jpg") })
+            };
+
+            // Check the predictions consistency
+            for (var i = 0; i < predictions.Length; i++) {
+                for (var j = 0; j < predictions[i].PredictedScores.Length; j++)
+                    Assert.Equal(predictions[i].PredictedScores[j], testPredictions[i].PredictedScores[j], 2);
+            }
+        }
+
         [TensorFlowFact]
         public void TensorFlowTransformCifarInvalidShape()
         {