diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc
index fc8f7c5917873..a688241d71867 100644
--- a/onnxruntime/core/providers/cpu/generator/random.cc
+++ b/onnxruntime/core/providers/cpu/generator/random.cc
@@ -76,8 +76,6 @@ void GenerateData(std::default_random_engine& generator, TDistribution distribut
 static Status RandomNormalCompute(float mean, float scale, std::default_random_engine& generator, TensorProto::DataType dtype, Tensor& Y);
 static Status RandomUniformCompute(float high, float low, std::default_random_engine& generator, TensorProto::DataType dtype, Tensor& Y);
 
-// Leaving in case we need to change to this approach
-//static Status CreateOutputTensorFromTensorValues(OpKernelContext* ctx, const Tensor& X,Tensor** Y);
 static Status CreateOutputTensorFromTensorShape(OpKernelContext* ctx, const Tensor& X, Tensor** Y);
 
 static TensorProto::DataType InferDataType(const Tensor& tensor);
@@ -168,53 +166,48 @@ static Status MultinomialCompute(OpKernelContext* ctx,
   Eigen::array<int64_t, 2> Y_dims = {{batch_size, num_samples}};
   Matrix<OutputType> output = Matrix<OutputType>(Y.template MutableData<OutputType>(), Y_dims);
 
-  // TODO (perf optimization) - the idea behind making this a lambda is so that we can parallelize across batches.
-  // When we do that this lamdba will act as one task given to a thread
-  auto DoWork = [ctx, num_samples, num_classes, &generator, &logits, &output](int64_t start_row,
-                                                                              int64_t limit_row) {
-    std::default_random_engine generator_copy = generator;
-    // BEGIN create temporary tensor
-    AllocatorPtr alloc;
-    ctx->GetTempSpaceAllocator(&alloc);
-    auto cdf_data = static_cast<double*>(alloc->Alloc(sizeof(double) * num_classes));
-    BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(alloc));
-    Eigen::array<int64_t, 1> cdf_dims = {{num_classes}};
-    auto cdf = EigenVector<double>(cdf_data, cdf_dims);
-    // END create temporary tensor
-
-    std::uniform_real_distribution<double> dist(0.0, 1.0);  // TODO: should this be initialized per batch?
-    for (int64_t b = start_row; b < limit_row; ++b) {
-      const float* logits_row = &(logits(b, 0));
-      // Takes an along-class maximum (for numerical stability).
-      float maxx = std::numeric_limits<float>::lowest();
-      for (int64_t j = 0; j < num_classes; ++j) {
-        if (Eigen::numext::isfinite(logits_row[j])) {
-          maxx = std::max(maxx, logits_row[j]);
-        }
+  // BEGIN create temporary tensor
+  AllocatorPtr alloc;
+  ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
+  auto cdf_data = static_cast<double*>(alloc->Alloc(sizeof(double) * num_classes));
+  BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(alloc));
+  Eigen::array<int64_t, 1> cdf_dims = {{num_classes}};
+  auto cdf = EigenVector<double>(cdf_data, cdf_dims);
+  // END create temporary tensor
+
+  std::uniform_real_distribution<double> dist(0.0, 1.0);  // TODO: should this be initialized per batch?
+
+  for (int64_t b = 0; b < batch_size; ++b) {
+    const float* logits_row = &(logits(b, 0));
+    // Takes an along-class maximum (for numerical stability).
+    float maxx = std::numeric_limits<float>::lowest();
+    for (int64_t j = 0; j < num_classes; ++j) {
+      if (Eigen::numext::isfinite(logits_row[j])) {
+        maxx = std::max(maxx, logits_row[j]);
       }
-      const auto max_logit = static_cast<double>(maxx);
-
-      // Precompute cumulative probability distribution across classes.
-      // Note: This isn't normalized.
-      cdf = (logits.chip<0>(b).cast<double>() - max_logit).exp();
-      double running_total = 0;
-      for (int64_t j = 0; j < num_classes; ++j) {
-        if (Eigen::numext::isfinite(logits_row[j])) {
-          running_total += cdf(j);
-        }
-        cdf(j) = running_total;
-      }
-      // Generate each sample.
-      const double* cdf_begin = cdf.data();
-      const double* cdf_end = cdf.data() + num_classes;
-      for (int64_t j = 0; j < num_samples; ++j) {
-        const double to_find = dist(generator_copy) * running_total;
-        auto found_iter = std::upper_bound(cdf_begin, cdf_end, to_find);
-        output(b, j) = static_cast<OutputType>(std::distance(cdf_begin, found_iter));
+    }
+    const auto max_logit = static_cast<double>(maxx);
+
+    // Precompute cumulative probability distribution across classes.
+    // Note: This isn't normalized.
+    cdf = (logits.chip<0>(b).cast<double>() - max_logit).exp();
+    double running_total = 0;
+    for (int64_t j = 0; j < num_classes; ++j) {
+      if (Eigen::numext::isfinite(logits_row[j])) {
+        running_total += cdf(j);
       }
+      cdf(j) = running_total;
+    }
+    // Generate each sample.
+    const double* cdf_begin = cdf.data();
+    const double* cdf_end = cdf.data() + num_classes;
+    for (int64_t j = 0; j < num_samples; ++j) {
+      const double to_find = dist(generator) * running_total;
+      auto found_iter = std::upper_bound(cdf_begin, cdf_end, to_find);
+      output(b, j) = static_cast<OutputType>(std::distance(cdf_begin, found_iter));
     }
-  };
-  DoWork(0, batch_size);
+  }
+
   return Status::OK();
 }
@@ -262,32 +255,6 @@ Status Multinomial::Compute(OpKernelContext* ctx) const {
   return status;
 }
 
-/*
-alternative interpretation of the spec is that the input tensor contains the dimensions as ints.
-Keeping this temporarily in case we go back to that.
-
-// read shape information from input tensor and create output tensor with it
-static Status CreateOutputTensorFromTensorValues(OpKernelContext* ctx, const Tensor& X, Tensor** Y) {
-  const TensorShape& shape = X.Shape();
-  auto size = shape.Size();
-  auto num_dims = shape.NumDimensions();
-
-  if (num_dims != 1) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Expected 1 dimension tensor with shape information. Dimensions=", num_dims);
Dimensions=", num_dims); - } - - std::vector dims; - dims.reserve(shape.Size()); - - auto data = gsl::make_span(tensor.template Data(), shape.Size()); - dims.insert(dims.cbegin(), data.cbegin(), data.cend()); - - *Y = ctx->Output(0, TensorShape(dims)); - - return Status::OK(); -} -*/ - // create output tensor using shape of input tensor static Status CreateOutputTensorFromTensorShape(OpKernelContext* ctx, const Tensor& X, Tensor** Y) { const TensorShape& shape = X.Shape(); @@ -363,9 +330,11 @@ static Status RandomUniformCompute(float low, float high, template void GenerateData(std::default_random_engine& generator, TDistribution distribution, Tensor& tensor) { - auto out = gsl::make_span(tensor.template MutableData(), tensor.Shape().Size()); - - std::for_each(out.begin(), out.end(), [&generator, &distribution](T& value) { value = distribution(generator); }); + T* out = tensor.MutableData(); + for (int64_t i = 0, end = tensor.Shape().Size(); i < end; ++i) { + *out = distribution(generator); + ++out; + } } } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index 9e5a804054779..ca5864b97d6c2 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -246,7 +246,7 @@ TEST(Random, MultinomialGoodCase) { const std::vector output_dims{batch_size, num_samples}; #ifdef _WIN32 const std::vector expected_output{2, 0, 0, 2, 2, 2, 0, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 0}; -#elif defined(__MACH__) || defined (__ANDROID__) +#elif defined(__MACH__) || defined(__ANDROID__) const std::vector expected_output{1, 1, 2, 2, 0, 2, 2, 2, 0, 2, 1, 1, 2, 0, 2, 2, 0, 2, 1, 1}; #else const std::vector expected_output{2, 0, 0, 1, 0, 1, 2, 0, 1, 0, 0, 1, 1, 0, 1, 0, 2, 0, 2, 0}; @@ -257,31 +257,46 @@ TEST(Random, MultinomialGoodCase) { } TEST(Random, MultinomialDefaultDType) { - OpTester test("Multinomial"); + auto run_test = [](int num_run_calls, const std::vector& expected_output) { + OpTester test("Multinomial"); + const int64_t num_samples = 10; + const int batch_size = 2; + const float seed = 1618.f; + + const std::vector input_dims{2, 3}; + std::vector input(TensorShape(input_dims).Size()); + std::fill(input.begin(), input.end(), -10.f); + test.AddInput("X", input_dims, input); + + test.AddAttribute("sample_size", num_samples); + test.AddAttribute("seed", seed); - const int64_t num_samples = 10; - const int batch_size = 2; - const float seed = 1618.f; + const std::vector output_dims{batch_size, num_samples}; + test.AddOutput("Y", output_dims, expected_output); - const std::vector input_dims{2, 3}; - std::vector input(TensorShape(input_dims).Size()); - std::fill(input.begin(), input.end(), -10.f); - test.AddInput("X", input_dims, input); + // test.Run() re-loads the model each time, so we need to do multiple calls to InferenceSession::Run inside of it + // to test that the second call to Compute produces different data + test.SetNumRunCalls(num_run_calls); - test.AddAttribute("sample_size", num_samples); - test.AddAttribute("seed", seed); + test.Run(); + }; - const std::vector output_dims{batch_size, num_samples}; #ifdef _WIN32 - const std::vector expected_output{2, 0, 0, 2, 2, 2, 0, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 0}; -#elif defined(__MACH__) || defined (__ANDROID__) - const std::vector expected_output{1, 1, 2, 2, 0, 2, 2, 2, 0, 2, 1, 1, 2, 0, 2, 2, 0, 2, 1, 1}; + const std::vector expected_output_1{2, 0, 0, 2, 2, 2, 0, 2, 2, 1, 1, 2, 
+  const std::vector<int32_t> expected_output_2{0, 0, 1, 0, 2, 2, 2, 0, 2, 1, 2, 1, 0, 2, 0, 2, 2, 1, 2, 1};
+#elif defined(__MACH__) || defined(__ANDROID__)
+  const std::vector<int32_t> expected_output_1{1, 1, 2, 2, 0, 2, 2, 2, 0, 2, 1, 1, 2, 0, 2, 2, 0, 2, 1, 1};
+  const std::vector<int32_t> expected_output_2{1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 2, 0, 1, 1, 0, 2, 2, 2, 1};
 #else
-  const std::vector<int32_t> expected_output{2, 0, 0, 1, 0, 1, 2, 0, 1, 0, 0, 1, 1, 0, 1, 0, 2, 0, 2, 0};
+  const std::vector<int32_t> expected_output_1{2, 0, 0, 1, 0, 1, 2, 0, 1, 0, 0, 1, 1, 0, 1, 0, 2, 0, 2, 0};
+  const std::vector<int32_t> expected_output_2{2, 2, 1, 1, 0, 2, 2, 1, 1, 2, 0, 0, 0, 2, 0, 1, 1, 1, 0, 0};
 #endif
-  test.AddOutput<int32_t>("Y", output_dims, expected_output);
-  test.Run();
+  // Test output from a single call to Multinomial::Compute
+  run_test(1, expected_output_1);
+
+  // Test output from 2 calls to Multinomial::Compute
+  run_test(2, expected_output_2);
 }
 
 TEST(Random, MultinomialInvalidDtype) {
diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc
index 4abbc94b827a9..679d90a953806 100644
--- a/onnxruntime/test/providers/provider_test_utils.cc
+++ b/onnxruntime/test/providers/provider_test_utils.cc
@@ -30,7 +30,7 @@ void Check(const OpTester::Data& expected_data, const Tensor& output_tensor, con
   auto size = output_tensor.Shape().Size();
 
   for (int i = 0; i < size; ++i) {
-    EXPECT_EQ(expected[i], output[i]) << "provider_type: " << provider_type;
+    EXPECT_EQ(expected[i], output[i]) << "i:" << i << ", provider_type: " << provider_type;
   }
 }
@@ -51,19 +51,21 @@ void Check(const OpTester::Data& expected_data, const Tensor& output_ten
   for (int i = 0; i < size; ++i) {
     if (std::isinf(expected[i])) {  // Test infinity for equality
-      EXPECT_EQ(expected[i], output[i]);
+      EXPECT_EQ(expected[i], output[i]) << "i:" << i;
     } else if (std::isnan(expected[i])) {
       EXPECT_TRUE(std::isnan(output[i])) << "Expected output " << i << " to be NaN";
     } else {
       if (!has_abs_err && !has_rel_err) {
         // the default for existing tests
-        EXPECT_NEAR(expected[i], output[i], threshold) << "provider_type: " << provider_type;
+        EXPECT_NEAR(expected[i], output[i], threshold) << "i:" << i << ", provider_type: " << provider_type;
       } else {
         if (has_abs_err) {
-          EXPECT_NEAR(expected[i], output[i], expected_data.absolute_error_.value()) << "provider_type: " << provider_type;
+          EXPECT_NEAR(expected[i], output[i], expected_data.absolute_error_.value())
+              << "i:" << i << ", provider_type: " << provider_type;
         }
         if (has_rel_err) {
-          EXPECT_NEAR(expected[i], output[i], expected_data.relative_error_.value() * std::abs(expected[i])) << "provider_type: " << provider_type;
+          EXPECT_NEAR(expected[i], output[i], expected_data.relative_error_.value() * std::abs(expected[i]))
+              << "i:" << i << ", provider_type: " << provider_type;
         }
       }
     }
@@ -87,19 +89,21 @@ void Check(const OpTester::Data& expected_data, const Tensor& output_tens
   for (int i = 0; i < size; ++i) {
     if (std::isinf(expected[i])) {  // Test infinity for equality
-      EXPECT_EQ(expected[i], output[i]);
+      EXPECT_EQ(expected[i], output[i]) << "i:" << i;
     } else if (std::isnan(expected[i])) {
       EXPECT_TRUE(std::isnan(output[i])) << "Expected output " << i << " to be NaN";
     } else {
       if (!has_abs_err && !has_rel_err) {
         // the default for existing tests
-        EXPECT_NEAR(expected[i], output[i], threshold) << "provider_type: " << provider_type;
+        EXPECT_NEAR(expected[i], output[i], threshold) << "i:" << i << ", provider_type: " << provider_type;
       } else {
         if (has_abs_err) {
-          EXPECT_NEAR(expected[i], output[i], expected_data.absolute_error_.value()) << "provider_type: " << provider_type;
+          EXPECT_NEAR(expected[i], output[i], expected_data.absolute_error_.value())
+              << "i:" << i << ", provider_type: " << provider_type;
         }
         if (has_rel_err) {
-          EXPECT_NEAR(expected[i], output[i], expected_data.relative_error_.value() * std::abs(expected[i])) << "provider_type: " << provider_type;
+          EXPECT_NEAR(expected[i], output[i], expected_data.relative_error_.value() * std::abs(expected[i]))
+              << "i:" << i << ", provider_type: " << provider_type;
         }
       }
     }
@@ -121,10 +125,10 @@ void Check(const OpTester::Data& expected_data, const Tensor& output_
   float threshold = 0.001f;
   for (int i = 0; i < size; ++i) {
     if (std::isinf(f_expected[i]))  // Test infinity for equality
-      EXPECT_EQ(f_expected[i], f_output[i]);
+      EXPECT_EQ(f_expected[i], f_output[i]) << "i:" << i;
     else {
       // the default for existing tests
-      EXPECT_NEAR(f_expected[i], f_output[i], threshold) << "provider_type: " << provider_type;
+      EXPECT_NEAR(f_expected[i], f_output[i], threshold) << "i:" << i << ", provider_type: " << provider_type;
     }
   }
 }
@@ -342,23 +346,27 @@ void OpTester::ExecuteModel(Model& model, InferenceSession& session_object, Expe
   default_run_options.run_log_verbosity_level = 1;
 
   std::vector fetches;
-  status = session_object.Run(run_options ? *run_options : default_run_options, feeds, output_names, &fetches);
-  if (status.IsOK()) {
-    EXPECT_TRUE(expect_result == ExpectResult::kExpectSuccess) << "Expected failure but Run was successful";
-    if (expect_result == ExpectResult::kExpectFailure) {
-      return;
-    }
-  } else {
-    if (expect_result == ExpectResult::kExpectFailure) {
-      // Disable expected_failure_string checks for MKL-DNN and nGraph EP's
-      if (provider_type != kMklDnnExecutionProvider && provider_type != kNGraphExecutionProvider) {
-        EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr(expected_failure_string));
+  for (int i = 0; i < num_run_calls_; ++i) {
+    fetches.clear();
+    status = session_object.Run(run_options ? *run_options : default_run_options, feeds, output_names, &fetches);
+
+    if (status.IsOK()) {
+      EXPECT_TRUE(expect_result == ExpectResult::kExpectSuccess) << "Expected failure but Run was successful";
+      if (expect_result == ExpectResult::kExpectFailure) {
+        return;
       }
     } else {
-      LOGS_DEFAULT(ERROR) << "Run failed with status: " << status.ErrorMessage();
-      EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+      if (expect_result == ExpectResult::kExpectFailure) {
+        // Disable expected_failure_string checks for MKL-DNN and nGraph EP's
+        if (provider_type != kMklDnnExecutionProvider && provider_type != kNGraphExecutionProvider) {
+          EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr(expected_failure_string));
+        }
+      } else {
+        LOGS_DEFAULT(ERROR) << "Run failed with status: " << status.ErrorMessage();
+        EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+      }
+      return;
     }
-    return;
   }
 
   // Verify the outputs
@@ -515,7 +523,9 @@ void OpTester::Run(ExpectResult expect_result,
         //if node is not registered for the provider, skip
         node.SetExecutionProviderType(provider_type);
-        if (provider_type == onnxruntime::kNGraphExecutionProvider || provider_type == onnxruntime::kTensorrtExecutionProvider || provider_type == onnxruntime::kOpenVINOExecutionProvider)
+        if (provider_type == onnxruntime::kNGraphExecutionProvider ||
+            provider_type == onnxruntime::kTensorrtExecutionProvider ||
+            provider_type == onnxruntime::kOpenVINOExecutionProvider)
           continue;
         auto reg = execution_provider->GetKernelRegistry();
         const KernelCreateInfo* kci = reg->TryFindKernel(node, execution_provider->Type());
diff --git a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h
index 5f93eabbcc714..94a2d8e5544f5 100644
--- a/onnxruntime/test/providers/provider_test_utils.h
+++ b/onnxruntime/test/providers/provider_test_utils.h
@@ -227,6 +227,13 @@ class OpTester {
   void SetOutputAbsErr(const char* name, float v);
   void SetOutputRelErr(const char* name, float v);
 
+  // Number of times to call InferenceSession::Run. The same feeds are used each time.
+  // e.g. used to verify the generator ops behave as expected
+  void SetNumRunCalls(int n) {
+    ORT_ENFORCE(n > 0);
+    num_run_calls_ = n;
+  }
+
   template <typename T>
   void AddAttribute(std::string name, T value) {
     // Generate a the proper AddAttribute call for later
@@ -318,6 +325,7 @@ class OpTester {
   int opset_version_;
   bool add_shape_to_tensor_data_ = true;
   int add_symbolic_dim_to_tensor_data_ = -1;
+  int num_run_calls_ = 1;
   std::vector<Data> input_data_;
   std::vector<Data> output_data_;
   std::vector<size_t> initializer_index_;