diff --git a/test/Microsoft.ML.Functional.Tests/ONNX.cs b/test/Microsoft.ML.Functional.Tests/ONNX.cs index 88438305bd..2776a1299e 100644 --- a/test/Microsoft.ML.Functional.Tests/ONNX.cs +++ b/test/Microsoft.ML.Functional.Tests/ONNX.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.IO; +using Microsoft.ML.Data; using Microsoft.ML.Functional.Tests.Datasets; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; @@ -14,6 +15,12 @@ namespace Microsoft.ML.Functional.Tests { + internal sealed class OnnxScoreColumn + { + [ColumnName("Score0")] + public float[] Score { get; set; } + } + public class ONNX : BaseTestClass { public ONNX(ITestOutputHelper output) : base(output) @@ -51,15 +58,9 @@ public void SaveOnnxModelLoadAndScoreFastTree() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath); var onnxModel = onnxEstimator.Fit(data); - // TODO #2980: ONNX outputs don't match the outputs of the model, so we must hand-correct this for now. - // TODO #2981: ONNX models cannot be fit as part of a pipeline, so we must use a workaround like this. - var onnxWorkaroundPipeline = onnxModel.Append( - mlContext.Transforms.CopyColumns("Score", "Score0").Fit(onnxModel.Transform(data))); - // Create prediction engine and test predictions. var originalPredictionEngine = mlContext.Model.CreatePredictionEngine(model); - // TODO #2982: ONNX produces vector types and not the original output type. - var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine(onnxWorkaroundPipeline); + var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine(onnxModel); // Take a handful of examples out of the dataset and compute predictions. var dataEnumerator = mlContext.Data.CreateEnumerable(mlContext.Data.TakeRows(data, 5), false);