Skip to content

Commit

Permalink
The test pipeline for consuming an ONNX model would fail due to the
Browse files Browse the repository at this point in the history
Score column being named "Score0". The ONNX model will rename the output
columns by design, therefore a different class with the ColumnName of
"Score0" is needed. This fixes the test pipeline to address this issue.

Fixes dotnet#2981
  • Loading branch information
singlis committed Apr 9, 2019
1 parent 1724da8 commit 925bdf0
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions test/Microsoft.ML.Functional.Tests/ONNX.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Functional.Tests.Datasets;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework;
Expand All @@ -14,6 +15,12 @@

namespace Microsoft.ML.Functional.Tests
{
internal sealed class OnnxScoreColumn
{
[ColumnName("Score0")]
public float[] Score { get; set; }
}

public class ONNX : BaseTestClass
{
public ONNX(ITestOutputHelper output) : base(output)
Expand Down Expand Up @@ -51,15 +58,9 @@ public void SaveOnnxModelLoadAndScoreFastTree()
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath);
var onnxModel = onnxEstimator.Fit(data);

// TODO #2980: ONNX outputs don't match the outputs of the model, so we must hand-correct this for now.
// TODO #2981: ONNX models cannot be fit as part of a pipeline, so we must use a workaround like this.
var onnxWorkaroundPipeline = onnxModel.Append(
mlContext.Transforms.CopyColumns("Score", "Score0").Fit(onnxModel.Transform(data)));

// Create prediction engine and test predictions.
var originalPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(model);
// TODO #2982: ONNX produces vector types and not the original output type.
var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, VectorScoreColumn>(onnxWorkaroundPipeline);
var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, OnnxScoreColumn>(onnxModel);

// Take a handful of examples out of the dataset and compute predictions.
var dataEnumerator = mlContext.Data.CreateEnumerable<HousingRegression>(mlContext.Data.TakeRows(data, 5), false);
Expand Down

0 comments on commit 925bdf0

Please sign in to comment.