Propagate NaNs in the CPU min and max operators (#21492)
### Description

Propagates NaN values in the CPU Min and Max operators so that an elementwise
min or max with a NaN in either input always produces NaN, instead of silently
dropping it.
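
For reference, a minimal standalone sketch (not ONNX Runtime code; assumes Eigen 3.4+) of the behavioral difference this change relies on: Eigen's default coefficient-wise `min`/`max` leave the result for NaN inputs unspecified, while the `Eigen::PropagateNaN` variants guarantee a NaN result whenever either operand is NaN.

```cpp
// Sketch only: contrasts Eigen's default min policy with Eigen::PropagateNaN.
#include <Eigen/Core>
#include <iostream>
#include <limits>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  Eigen::ArrayXf a(3), b(3);
  a << nan, -0.5f, 0.5f;
  b << 0.25f, 0.25f, 0.25f;

  // Default policy (PropagateFast): NaN handling is unspecified, so the NaN
  // in `a` may be dropped depending on platform and vectorization.
  std::cout << a.min(b).transpose() << "\n";

  // PropagateNaN: a NaN in either operand always yields NaN.
  std::cout << a.min<Eigen::PropagateNaN>(b).transpose() << "\n";
  return 0;
}
```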

### Motivation and Context

Fixes #21455
adamreeve authored Jul 29, 2024
1 parent c39f1c4 commit 7543dd0
Showing 3 changed files with 187 additions and 21 deletions.
18 changes: 10 additions & 8 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.cc
@@ -705,7 +705,7 @@ Status Min_6<float>::Compute(OpKernelContext* ctx) const {
   for (int index = 1; index < inputCount; index++) {
     auto& data_n = *ctx->Input<Tensor>(index);
     ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape");
-    min = min.array().min(EigenMap<float>(data_n).array());
+    min = min.array().template min<Eigen::PropagateNaN>(EigenMap<float>(data_n).array());
   }
 
   return Status::OK();
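
The loop above folds the running minimum against each additional input. A simplified standalone sketch of that fold (hypothetical helper, not the actual kernel; Eigen 3.4+ assumed) shows why NaN now survives to the output: once a NaN enters the running result, `Eigen::PropagateNaN` keeps it there through every later reduction step.

```cpp
// Simplified sketch of the variadic fold performed by Min_6 (assumed helper):
// a NaN in any input propagates through every subsequent step.
#include <Eigen/Core>
#include <cstddef>
#include <vector>

Eigen::ArrayXf MinOfAll(const std::vector<Eigen::ArrayXf>& inputs) {
  Eigen::ArrayXf result = inputs.front();
  for (std::size_t i = 1; i < inputs.size(); ++i) {
    result = result.min<Eigen::PropagateNaN>(inputs[i]);
  }
  return result;
}
```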
@@ -721,15 +721,16 @@ struct Min_8::ComputeImpl {
     ProcessBroadcastSpanFuncs funcs{
         [](BroadcastHelper& per_iter_bh) {
           per_iter_bh.OutputEigen<T>() =
-              per_iter_bh.EigenInput1<T>().array().min(per_iter_bh.ScalarInput0<T>());
+              per_iter_bh.EigenInput1<T>().array().template min<Eigen::PropagateNaN>(per_iter_bh.ScalarInput0<T>());
         },
         [](BroadcastHelper& per_iter_bh) {
           per_iter_bh.OutputEigen<T>() =
-              per_iter_bh.EigenInput0<T>().array().min(per_iter_bh.ScalarInput1<T>());
+              per_iter_bh.EigenInput0<T>().array().template min<Eigen::PropagateNaN>(per_iter_bh.ScalarInput1<T>());
         },
         [](BroadcastHelper& per_iter_bh) {
           per_iter_bh.OutputEigen<T>() =
-              per_iter_bh.EigenInput0<T>().array().min(per_iter_bh.EigenInput1<T>().array());
+              per_iter_bh.EigenInput0<T>().array().template min<Eigen::PropagateNaN>(
+                  per_iter_bh.EigenInput1<T>().array());
         }};
 
     int input_count = inst.Node().InputArgCount().front();
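
The three lambdas above cover the three span shapes the broadcaster produces: scalar vs. tensor, tensor vs. scalar, and tensor vs. tensor. A hedged, self-contained sketch of that dispatch (the real `ProcessBroadcastSpanFuncs` machinery is ONNX Runtime internal; this helper is invented for illustration):

```cpp
// Hedged sketch of the broadcast dispatch: each broadcast iteration reduces
// to one of three cases, and all three need the NaN-propagating min.
#include <Eigen/Core>

template <typename T>
Eigen::Array<T, Eigen::Dynamic, 1> BroadcastMin(
    const Eigen::Array<T, Eigen::Dynamic, 1>& lhs,
    const Eigen::Array<T, Eigen::Dynamic, 1>& rhs) {
  if (lhs.size() == 1) {  // scalar vs. tensor
    return rhs.template min<Eigen::PropagateNaN>(lhs(0));
  }
  if (rhs.size() == 1) {  // tensor vs. scalar
    return lhs.template min<Eigen::PropagateNaN>(rhs(0));
  }
  return lhs.template min<Eigen::PropagateNaN>(rhs);  // element-wise
}
```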
@@ -827,7 +828,7 @@ Status Max_6<float>::Compute(OpKernelContext* ctx) const {
   for (int index = 1; index < inputCount; index++) {
     auto& data_n = *ctx->Input<Tensor>(index);
     ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape");
-    max = max.array().max(EigenMap<float>(data_n).array());
+    max = max.array().template max<Eigen::PropagateNaN>(EigenMap<float>(data_n).array());
   }
 
   return Status::OK();
@@ -843,15 +844,16 @@ struct Max_8::ComputeImpl {
     ProcessBroadcastSpanFuncs funcs{
         [](BroadcastHelper& per_iter_bh) {
           per_iter_bh.OutputEigen<T>() =
-              per_iter_bh.EigenInput1<T>().array().max(per_iter_bh.ScalarInput0<T>());
+              per_iter_bh.EigenInput1<T>().array().template max<Eigen::PropagateNaN>(per_iter_bh.ScalarInput0<T>());
         },
         [](BroadcastHelper& per_iter_bh) {
           per_iter_bh.OutputEigen<T>() =
-              per_iter_bh.EigenInput0<T>().array().max(per_iter_bh.ScalarInput1<T>());
+              per_iter_bh.EigenInput0<T>().array().template max<Eigen::PropagateNaN>(per_iter_bh.ScalarInput1<T>());
         },
         [](BroadcastHelper& per_iter_bh) {
           per_iter_bh.OutputEigen<T>() =
-              per_iter_bh.EigenInput0<T>().array().max(per_iter_bh.EigenInput1<T>().array());
+              per_iter_bh.EigenInput0<T>().array().template max<Eigen::PropagateNaN>(
+                  per_iter_bh.EigenInput1<T>().array());
         }};
 
     int input_count = inst.Node().InputArgCount().front();
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/checkers.cc
@@ -427,7 +427,7 @@ struct TensorCheck<MLFloat16> {
 
   for (int64_t i = 0; i < size; ++i) {
     if (std::isnan(f_expected[i])) {
-      EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i;
+      EXPECT_TRUE(std::isnan(f_actual[i])) << "Expected NaN. i:" << i;
    } else if (std::isinf(f_expected[i])) {  // Test infinity for equality
      EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i;
    } else {
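
The old assertion tested `f_expected[i]` inside a branch already guarded by that same condition, so it could never fail; checking `f_actual[i]` makes the NaN expectation meaningful. A minimal sketch (hypothetical helper, not the ONNX Runtime test API) of the NaN-aware comparison the fixed check performs:

```cpp
// Sketch of a NaN-aware expected/actual comparison: NaN != NaN under ordinary
// floating-point equality, so an expected NaN must be matched explicitly.
#include <cmath>

bool MatchesExpected(float expected, float actual, float tolerance) {
  if (std::isnan(expected)) return std::isnan(actual);  // NaN matches only NaN
  if (std::isinf(expected)) return expected == actual;  // infinities: exact
  return std::fabs(expected - actual) <= tolerance;     // finite: tolerance
}
```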
188 changes: 176 additions & 12 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -1553,6 +1553,47 @@ TEST(MathOpTest, Min_12_Float_Nan) {
   }
 }
 
+TEST(MathOpTest, Min_12_Float_Nan_with_scalar) {
+  OpTester test("Min", 12);
+  test.AddInput<float>("data_1", {3, 1},
+                       {std::numeric_limits<float>::quiet_NaN(), -0.5f, 0.5f});
+  test.AddInput<float>("data_2", {1}, {0.25f});
+  test.AddOutput<float>("min", {3, 1},
+                        {std::numeric_limits<float>::quiet_NaN(), -0.5f, 0.25f});
+  if (nullptr != DefaultCpuExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+TEST(MathOpTest, Min_12_Float_with_scalar_Nan) {
+  OpTester test("Min", 12);
+  test.AddInput<float>("data_1", {2, 2},
+                       {0.25f, -0.25f, -0.5f, 0.5f});
+  test.AddInput<float>("data_2", {1}, {std::numeric_limits<float>::quiet_NaN()});
+  test.AddOutput<float>("min", {2, 2},
+                        {std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN()});
+  if (nullptr != DefaultCpuExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Min_12_Double) {
   OpTester test("Min", 12);
   test.AddInput<double>("data_0", {1, 3},
@@ -1586,12 +1627,53 @@ TEST(MathOpTest, Min_12_Double_Nan) {
                          std::numeric_limits<double>::quiet_NaN(),
                          -1.0, -1.0, -2.0,
                          0.5, 0.0, 1.0});
-  if (nullptr != DefaultCpuExecutionProvider().get()) {
+  if (nullptr != DefaultCpuExecutionProvider()) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCpuExecutionProvider());
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
-  if (nullptr != DefaultCudaExecutionProvider().get()) {
+  if (nullptr != DefaultCudaExecutionProvider()) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCudaExecutionProvider());
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
 }
 
+TEST(MathOpTest, Min_12_Double_Nan_with_scalar) {
+  OpTester test("Min", 12);
+  test.AddInput<double>("data_1", {3, 1},
+                        {std::numeric_limits<double>::quiet_NaN(), -0.5, 0.5});
+  test.AddInput<double>("data_2", {1}, {0.25});
+  test.AddOutput<double>("min", {3, 1},
+                         {std::numeric_limits<double>::quiet_NaN(), -0.5, 0.25});
+  if (nullptr != DefaultCpuExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+TEST(MathOpTest, Min_12_Double_with_scalar_Nan) {
+  OpTester test("Min", 12);
+  test.AddInput<double>("data_1", {2, 2},
+                        {0.25, -0.25, -0.5, 0.5});
+  test.AddInput<double>("data_2", {1}, {std::numeric_limits<double>::quiet_NaN()});
+  test.AddOutput<double>("min", {2, 2},
+                         {std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN()});
+  if (nullptr != DefaultCpuExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
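
The paired `*_Nan_with_scalar` and `*_with_scalar_Nan` tests put the NaN on either side of the broadcast, so both scalar-broadcast lambdas in the kernel are exercised. A tiny standalone check (hypothetical, GoogleTest plus Eigen 3.4+ assumed) of the symmetry those pairs pin down:

```cpp
// Sketch: NaN propagation should not depend on which operand holds the NaN.
#include <gtest/gtest.h>
#include <Eigen/Core>
#include <cmath>
#include <limits>

TEST(PropagateNaNSketch, MinNaNIsSymmetric) {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  Eigen::ArrayXf a(1), b(1);
  a << nan;
  b << 0.25f;
  EXPECT_TRUE(std::isnan(a.min<Eigen::PropagateNaN>(b)(0)));
  EXPECT_TRUE(std::isnan(b.min<Eigen::PropagateNaN>(a)(0)));
}
```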
@@ -1666,7 +1748,7 @@ TEST(MathOpTest, Min_12_UInt64) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Min_12_MLFLoat16) {
+TEST(MathOpTest, Min_12_MLFloat16) {
   OpTester test("Min", 12);
   test.AddInput<MLFloat16>("data_0", {1, 3},
                            MakeMLFloat16({1.f, 1.f, 1.f}));
@@ -1679,7 +1761,7 @@ TEST(MathOpTest, Min_12_MLFLoat16) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) {
+TEST(MathOpTest, Min_12_MLFloat16_Scalar0) {
   OpTester test("Min", 12);
   test.AddInput<MLFloat16>("data_0", {},
                            MakeMLFloat16({-10.f}));
@@ -1692,7 +1774,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) {
+TEST(MathOpTest, Min_12_MLFloat16_Scalar1) {
   OpTester test("Min", 12);
   test.AddInput<MLFloat16>("data_0", {1, 3},
                            MakeMLFloat16({2.f, 3.f, 4.f}));
@@ -1809,12 +1891,53 @@ TEST(MathOpTest, Max_12_Float_Nan) {
                          std::numeric_limits<float>::quiet_NaN(),
                          -0.5f, 0.0f, -1.0f,
                          1.0f, 1.0f, 2.0f});
-  if (nullptr != DefaultCpuExecutionProvider().get()) {
+  if (nullptr != DefaultCpuExecutionProvider()) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCpuExecutionProvider());
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
-  if (nullptr != DefaultCudaExecutionProvider().get()) {
+  if (nullptr != DefaultCudaExecutionProvider()) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCudaExecutionProvider());
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
 }
 
+TEST(MathOpTest, Max_12_Float_Nan_with_scalar) {
+  OpTester test("Max", 12);
+  test.AddInput<float>("data_1", {3, 1},
+                       {std::numeric_limits<float>::quiet_NaN(), -0.5f, 0.5f});
+  test.AddInput<float>("data_2", {1}, {0.25f});
+  test.AddOutput<float>("max", {3, 1},
+                        {std::numeric_limits<float>::quiet_NaN(), 0.25f, 0.5f});
+  if (nullptr != DefaultCpuExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+TEST(MathOpTest, Max_12_Float_with_scalar_Nan) {
+  OpTester test("Max", 12);
+  test.AddInput<float>("data_1", {2, 2},
+                       {0.25f, -0.25f, -0.5f, 0.5f});
+  test.AddInput<float>("data_2", {1}, {std::numeric_limits<float>::quiet_NaN()});
+  test.AddOutput<float>("max", {2, 2},
+                        {std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN()});
+  if (nullptr != DefaultCpuExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
@@ -1854,12 +1977,53 @@ TEST(MathOpTest, Max_12_Double_Nan) {
                          std::numeric_limits<double>::quiet_NaN(),
                          -0.5, 0.0, -1.0,
                          1.0, 1.0, 2.0});
-  if (nullptr != DefaultCpuExecutionProvider().get()) {
+  if (nullptr != DefaultCpuExecutionProvider()) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCpuExecutionProvider());
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
-  if (nullptr != DefaultCudaExecutionProvider().get()) {
+  if (nullptr != DefaultCudaExecutionProvider()) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCudaExecutionProvider());
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
 }
 
+TEST(MathOpTest, Max_12_Double_Nan_with_scalar) {
+  OpTester test("Max", 12);
+  test.AddInput<double>("data_1", {3, 1},
+                        {std::numeric_limits<double>::quiet_NaN(), -0.5, 0.5});
+  test.AddInput<double>("data_2", {1}, {0.25});
+  test.AddOutput<double>("max", {3, 1},
+                         {std::numeric_limits<double>::quiet_NaN(), 0.25, 0.5});
+  if (nullptr != DefaultCpuExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+TEST(MathOpTest, Max_12_Double_with_scalar_Nan) {
+  OpTester test("Max", 12);
+  test.AddInput<double>("data_1", {2, 2},
+                        {0.25, -0.25, -0.5, 0.5});
+  test.AddInput<double>("data_2", {1}, {std::numeric_limits<double>::quiet_NaN()});
+  test.AddOutput<double>("max", {2, 2},
+                         {std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN()});
+  if (nullptr != DefaultCpuExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
@@ -1934,7 +2098,7 @@ TEST(MathOpTest, Max_12_UInt64) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Max_12_MLFLoat16) {
+TEST(MathOpTest, Max_12_MLFloat16) {
   OpTester test("Max", 12);
   test.AddInput<MLFloat16>("data_0", {1, 3},
                            MakeMLFloat16({-1.f, -1.f, -1.f}));
@@ -1947,7 +2111,7 @@ TEST(MathOpTest, Max_12_MLFLoat16) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) {
+TEST(MathOpTest, Max_12_MLFloat16_Scalar0) {
   OpTester test("Max", 12);
   test.AddInput<MLFloat16>("data_0", {},
                            MakeMLFloat16({-1.f}));
@@ -1960,7 +2124,7 @@ TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Max_12_MLFLoat16_Scalar1) {
+TEST(MathOpTest, Max_12_MLFloat16_Scalar1) {
   OpTester test("Max", 12);
   test.AddInput<MLFloat16>("data_0", {1, 3},
                            MakeMLFloat16({-1.f, -2.f, -3.f}));
