
Bug fix: Saving problem with TF2 SavedModel fmt in TensorflowTransform class. #5797

Merged (3 commits) on Jun 4, 2021

Conversation

@darth-vader-lg (Contributor) commented:

The SaveModel function of the TensorflowTransform class did not save the TensorFlow saved_model directory into the model zip archive. Saving was implemented only for frozen graphs and was missing for the SavedModel format.

To fix it, I followed the same scheme used in the DnnRetrainTransform class:

```csharp
internal static class DefaultModelFileNames
{
    public const string VariablesFolder = "variables";
    public const string Index = "variables.index";
    public const string Data = "variables.data-00000-of-00001";
    public const string Graph = "saved_model.pb";
    public const string TmpMlnetModel = "mlnet_model";
}
```

and
```csharp
ctx.SaveBinaryStream("TFSavedModel", w =>
{
    // Only these files need to be saved.
    string[] modelFilePaths =
    {
        Path.Combine(_modelLocation, DefaultModelFileNames.Graph),
        Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Data),
        Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index),
    };
    w.Write(modelFilePaths.Length);
    foreach (var fullPath in modelFilePaths)
    {
        var relativePath = fullPath.Substring(_modelLocation.Length + 1);
        w.Write(relativePath);
        using (var fs = new FileStream(fullPath, FileMode.Open))
        {
            long fileLength = fs.Length;
            w.Write(fileLength);
            long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
            Host.Assert(actualWritten == fileLength);
        }
    }
});
```
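For context, the loading counterpart that consumes this "TFSavedModel" stream (following the same DnnRetrainTransform scheme) reads each entry back and rebuilds the SavedModel directory on disk. A minimal sketch, assuming `ctx` is a `ModelLoadContext` and `tempDirPath` is a scratch directory chosen by the caller (both names are illustrative, not verbatim from this PR):

```csharp
// Sketch of the matching load path; mirrors the save format above:
// int (file count), then per file: string (relative path), long (length), raw bytes.
ctx.TryLoadBinaryStream("TFSavedModel", br =>
{
    int count = br.ReadInt32();
    for (int i = 0; i < count; i++)
    {
        string relativePath = br.ReadString();   // e.g. "variables/variables.index"
        long fileLength = br.ReadInt64();
        string fullPath = Path.Combine(tempDirPath, relativePath);
        Directory.CreateDirectory(Path.GetDirectoryName(fullPath));
        using (var fs = new FileStream(fullPath, FileMode.Create, FileAccess.Write))
            br.BaseStream.CopyRange(fs, fileLength);
    }
});
```

Because `BinaryWriter.Write(string)` is length-prefixed, `ReadString` pairs with it exactly; the stored per-file length then tells the reader how many raw bytes to copy out of the stream.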

This logic was not present in the TensorflowTransform class; only the frozen-graph saving was implemented:

```csharp
private protected override void SaveModel(ModelSaveContext ctx)
{
    Host.AssertValue(ctx);
    ctx.CheckAtModel();
    ctx.SetVersionInfo(GetVersionInfo());

    // *** Binary format ***
    // byte: indicator for frozen models
    // byte: indicator for adding batch dimension in input
    // byte: indicator for treating output as batched
    // stream: TensorFlow model
    // int: number of input columns
    // for each input column
    //   int: id of input column name
    // int: number of output columns
    // for each output column
    //   int: id of output column name
    var isFrozen = string.IsNullOrEmpty(_savedModelPath);
    ctx.Writer.WriteBoolByte(isFrozen);
    ctx.Writer.WriteBoolByte(_addBatchDimensionInput);
    ctx.Writer.WriteBoolByte(_treatOutputAsBatched);
    if (isFrozen)
    {
        using (var status = new Status())
        using (var buffer = Session.graph.ToGraphDef(status))
        {
            ctx.SaveBinaryStream("TFModel", w =>
            {
                w.WriteByteArray(buffer.DangerousMemoryBlock.ToArray());
            });
        }
    }

    Host.AssertNonEmpty(Inputs);
    ctx.Writer.Write(Inputs.Length);
    foreach (var colName in Inputs)
        ctx.SaveNonEmptyString(colName);

    Host.AssertNonEmpty(Outputs);
    ctx.Writer.Write(Outputs.Length);
    foreach (var colName in Outputs)
        ctx.SaveNonEmptyString(colName);
}
```

This produced an incomplete zip archive that could not be reloaded afterwards.

After this fix the zip archive can be saved and reloaded for inference (example attachment: saved_model.pb.zip).

@dnfadmin commented May 17, 2021:

CLA assistant check — all CLA requirements met.

@codecov bot commented May 17, 2021:

Codecov Report

Merging #5797 (ce683fc) into main (43c49f6) will increase coverage by 0.00%.
The diff coverage is 100.00%.

```
@@           Coverage Diff            @@
##             main    #5797    +/-   ##
========================================
  Coverage   68.35%   68.36%
========================================
  Files        1131     1131
  Lines      241210   241372   +162
  Branches    25039    25055    +16
========================================
+ Hits       164887   165011   +124
- Misses      69819    69857    +38
  Partials     6504     6504
```
| Flag | Coverage Δ |
|------|------------|
| Debug | 68.36% <100.00%> (+<0.01%) ⬆️ |
| production | 62.97% <100.00%> (+<0.01%) ⬆️ |
| test | 89.25% <100.00%> (+0.01%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|----------------|------------|
| src/Microsoft.ML.TensorFlow/TensorflowTransform.cs | 84.70% <100.00%> (+5.22%) ⬆️ |
| ...cenariosWithDirectInstantiation/TensorflowTests.cs | 92.35% <100.00%> (+0.29%) ⬆️ |
| ...c/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs | 79.48% <0.00%> (-20.52%) ⬇️ |
| src/Microsoft.ML.Core/Data/ProgressReporter.cs | 70.95% <0.00%> (-6.99%) ⬇️ |
| src/Microsoft.ML.FastTree/FastTreeRanking.cs | 50.79% <0.00%> (-4.28%) ⬇️ |
| src/Microsoft.ML.Data/MLContext.cs | 90.47% <0.00%> (-2.03%) ⬇️ |
| src/Microsoft.ML.FastTree/Dataset/IntArray.cs | 12.10% <0.00%> (-0.11%) ⬇️ |
| ...oft.ML.Tests/OnnxSequenceTypeWithAttributesTest.cs | 94.33% <0.00%> (-0.11%) ⬇️ |
| src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs | 86.63% <0.00%> (-0.06%) ⬇️ |
| src/Microsoft.ML.FastTree/FastTree.cs | 80.16% <0.00%> (-0.06%) ⬇️ |
| ... and 17 more | |

@michaelgsharp (Member) commented:

@darth-vader-lg this looks great! Thanks for taking the time to submit this.

Can you add a unit test around this new saving/loading?

@darth-vader-lg (Contributor, PR author) commented:

Hello @michaelgsharp,
I will add it ASAP.
Sorry I didn't do it before; I tried, but I had some trouble running the unit test system on my PC. As usual I had to quickly deliver to my customer, who was pushing me, the robotic vision application I'm working on, and... bye bye 🙆‍♂️.
In any case, the fix is already working well in the field. 😉
I will prepare the test while working on the integration of TensorFlow 2.5.0 with ML.NET.

@darth-vader-lg (Contributor, PR author) commented:

Hello @michaelgsharp,

The unit test is added: TensorFlowSaveAndLoadSavedModel.

It performs the following steps:

  • Loads the standard test model cifar_saved_model and creates the transformer.
  • Runs some predictions.
  • Saves the transformer as an ML.NET zip archive.
  • Reloads the transformer from the saved zip archive.
  • Runs the same predictions again.
  • Compares the results for equality.
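In outline, a test following those steps looks roughly like this sketch (the column names, data helper, and file path are illustrative assumptions, not the exact test code from the PR):

```csharp
[TensorFlowFact]
public void TensorFlowSaveAndLoadSavedModel()
{
    var mlContext = new MLContext(seed: 1);
    IDataView data = GetCifarTestData(mlContext);   // hypothetical helper producing test images

    // Build and fit the TensorFlow scoring transformer from the SavedModel directory.
    var transformer = mlContext.Model.LoadTensorFlowModel("cifar_saved_model")
        .ScoreTensorFlowModel("Output", "Input")
        .Fit(data);

    // Predictions before saving.
    var before = transformer.Transform(data).GetColumn<float[]>("Output").ToArray();

    // Save as an ML.NET zip archive, then reload it.
    mlContext.Model.Save(transformer, data.Schema, "tf_savedmodel_test.zip");
    var reloaded = mlContext.Model.Load("tf_savedmodel_test.zip", out _);

    // The reloaded transformer must produce identical predictions.
    var after = reloaded.Transform(data).GetColumn<float[]>("Output").ToArray();
    Assert.Equal(before, after);
}
```

The save/load round trip is the part that exercises the fix: without the "TFSavedModel" stream in the archive, `mlContext.Model.Load` would fail to rebuild the SavedModel directory.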

Please check that everything is OK for you, and have a good merge.

@michaelgsharp (Member) left a review comment:
LGTM. Thanks!

@michaelgsharp michaelgsharp merged commit 992d989 into dotnet:main Jun 4, 2021
@darth-vader-lg darth-vader-lg deleted the bugfix-tf-saved-model branch June 26, 2021 08:28
@ghost ghost locked as resolved and limited conversation to collaborators Mar 17, 2022