From 7543dd040b2d32109a2718d7276d3aca1edadaae Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Jul 2024 10:50:13 +1200 Subject: [PATCH] Propagate NaNs in the CPU min and max operators (#21492) ### Description Propagates NaN values in the min and max operators so that min or max with a NaN in either input always produces NaN. ### Motivation and Context Fixes #21455 --- .../providers/cpu/math/element_wise_ops.cc | 18 +- onnxruntime/test/providers/checkers.cc | 2 +- .../cpu/math/element_wise_ops_test.cc | 188 ++++++++++++++++-- 3 files changed, 187 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 1d524a90302e7..5ea6000da1cba 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -705,7 +705,7 @@ Status Min_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - min = min.array().min(EigenMap(data_n).array()); + min = min.array().template min(EigenMap(data_n).array()); } return Status::OK(); @@ -721,15 +721,16 @@ struct Min_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().min(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template min(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template min(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template min( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); @@ -827,7 +828,7 @@ Status Max_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - max = max.array().max(EigenMap(data_n).array()); + max = max.array().template max(EigenMap(data_n).array()); } return Status::OK(); @@ -843,15 +844,16 @@ struct Max_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().max(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template max(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template max(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template max( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 5f332ddcddb8d..182fa4729a88f 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -427,7 +427,7 @@ struct TensorCheck { for (int64_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { - EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(f_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i; } else { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index eb3575f2cde88..bd3d21d4929f3 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1553,6 +1553,47 @@ TEST(MathOpTest, Min_12_Float_Nan) { } } +TEST(MathOpTest, Min_12_Float_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.25f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Float_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Min_12_Double) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, @@ -1586,12 +1627,53 @@ TEST(MathOpTest, Min_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -1.0, -1.0, -2.0, 0.5, 0.0, 1.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.25}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1666,7 +1748,7 @@ TEST(MathOpTest, Min_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16) { +TEST(MathOpTest, Min_12_MLFloat16) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({1.f, 1.f, 1.f})); @@ -1679,7 +1761,7 @@ TEST(MathOpTest, Min_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar0) { OpTester test("Min", 12); test.AddInput("data_0", {}, MakeMLFloat16({-10.f})); @@ -1692,7 +1774,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({2.f, 3.f, 4.f})); @@ -1809,12 +1891,53 @@ TEST(MathOpTest, Max_12_Float_Nan) { std::numeric_limits::quiet_NaN(), -0.5f, 0.0f, -1.0f, 1.0f, 1.0f, 2.0f}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25f, 0.5f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1854,12 +1977,53 @@ TEST(MathOpTest, Max_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -0.5, 0.0, -1.0, 1.0, 1.0, 2.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25, 0.5}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1934,7 +2098,7 @@ TEST(MathOpTest, Max_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16) { +TEST(MathOpTest, Max_12_MLFloat16) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -1.f, -1.f})); @@ -1947,7 +2111,7 @@ TEST(MathOpTest, Max_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar0) { OpTester test("Max", 12); test.AddInput("data_0", {}, MakeMLFloat16({-1.f})); @@ -1960,7 +2124,7 @@ TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -2.f, -3.f}));