Skip to content

Commit

Permalink
Merge pull request microsoft#24 from NonStatic2014/bohu/binary_input
Browse files Browse the repository at this point in the history
Support binary input and content type headers
  • Loading branch information
NonStatic2014 authored Apr 4, 2019
2 parents 05847b0 + 1a15081 commit 6e5f6f8
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 23 deletions.
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -536,14 +536,15 @@ if (onnxruntime_BUILD_HOSTING)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties("${TEST_SRC_DIR}/hosting/json_handling_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties("${TEST_SRC_DIR}/hosting/converter_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties("${TEST_SRC_DIR}/hosting/util_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()

find_package(Boost 1.69 COMPONENTS system context thread program_options REQUIRED)
add_library(onnxruntime_test_utils_for_hosting ${onnxruntime_test_hosting_src})
onnxruntime_add_include_to_target(onnxruntime_test_utils_for_hosting onnxruntime_test_utils gtest gmock gsl onnx onnx_proto hosting_proto)
add_dependencies(onnxruntime_test_utils_for_hosting onnxruntime_hosting ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_test_utils_for_hosting PUBLIC ${Boost_INCLUDE_DIR} ${REPO_ROOT}/cmake/external/re2 ${CMAKE_CURRENT_BINARY_DIR}/onnx PRIVATE ${ONNXRUNTIME_ROOT} )
target_include_directories(onnxruntime_test_utils_for_hosting PUBLIC ${Boost_INCLUDE_DIR} ${REPO_ROOT}/cmake/external/re2 ${CMAKE_CURRENT_BINARY_DIR}/onnx ${ONNXRUNTIME_ROOT}/hosting/http PRIVATE ${ONNXRUNTIME_ROOT} )
target_link_libraries(onnxruntime_test_utils_for_hosting ${Boost_LIBRARIES} ${onnx_test_libs})

AddTest(
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/hosting/http/json_handling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ protobufutil::Status GenerateResponseInJson(const onnxruntime::hosting::PredictR
}

std::string CreateJsonError(const http::status error_code, const std::string& error_message) {
return "{\"error_code\": " + std::to_string(int(error_code)) + ", \"error_message\": " + error_message + " }";
return "{\"error_code\": " + std::to_string(int(error_code)) + ", \"error_message\": " + error_message + " }\n";
}

} // namespace hosting
Expand Down
110 changes: 89 additions & 21 deletions onnxruntime/hosting/http/predict_request_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,23 @@ namespace hosting {

namespace protobufutil = google::protobuf::util;

#define GenerateErrorResponse(logger, error_code, status, context) \
{ \
auto http_error_code = (error_code); \
auto error_message = CreateJsonError(http_error_code, (status).error_message()); \
LOGS((*logger), VERBOSE) << error_message; \
(context).response.result(http_error_code); \
(context).response.body() = error_message; \
(context).response.set(http::field::content_type, "application/json"); \
#define GenerateErrorResponse(logger, error_code, message, context, ms_request_id, client_request_id) \
{ \
auto http_error_code = (error_code); \
(context).response.insert("x-ms-request-id", (ms_request_id)); \
if (!client_request_id.empty()) { \
(context).response.insert("x-ms-client-request-id", (client_request_id)); \
} \
auto json_error_message = CreateJsonError(http_error_code, (message)); \
LOGS((*logger), VERBOSE) << json_error_message; \
(context).response.result(http_error_code); \
(context).response.body() = json_error_message; \
(context).response.set(http::field::content_type, "application/json"); \
}

// TODO: decide whether this should be a class
static bool ParseRequestPayload(const HttpContext& context, SupportedContentType request_type,
/* out */ PredictRequest& predictRequest, /* out */ http::status& error_code, /* out */ std::string& error_message);

void Predict(const std::string& name,
const std::string& version,
const std::string& action,
Expand All @@ -33,33 +39,95 @@ void Predict(const std::string& name,
auto logger = env->GetLogger(context.uuid);
LOGS(*logger, VERBOSE) << "Name: " << name << " Version: " << version << " Action: " << action;

// We need to persist the "x-ms-client-request-id" field for the user to track the request from client side
std::string client_request_id{};
if (context.request.find("x-ms-client-request-id") != context.request.end()) {
client_request_id = context.request["x-ms-client-request-id"].to_string();
LOGS(*logger, VERBOSE) << "x-ms-client-request-id: [" << client_request_id << "]";
}

// Request and Response content type information
SupportedContentType request_type = GetRequestContentType(context);
SupportedContentType response_type = GetResponseContentType(context);
if (response_type == SupportedContentType::Unknown) {
GenerateErrorResponse(logger, http::status::bad_request, "Unknown 'Accept' header field in the request", context, context.uuid, client_request_id);
}

// Deserialize the payload
auto body = context.request.body();
PredictRequest predictRequest{};
auto status = GetRequestFromJson(body, predictRequest);
if (!status.ok()) {
GenerateErrorResponse(logger, GetHttpStatusCode((status)), status, context);
PredictRequest predict_request{};
http::status error_code;
std::string error_message;
bool parse_succeeded = ParseRequestPayload(context, request_type, predict_request, error_code, error_message);
if (!parse_succeeded) {
GenerateErrorResponse(logger, error_code, error_message, context, context.uuid, client_request_id);
return;
}

// Run Prediction
protobufutil::Status status;
Executor executor(env);
PredictResponse predictResponse{};
status = executor.Predict(name, version, "request_id", predictRequest, predictResponse);
PredictResponse predict_response{};
status = executor.Predict(name, version, context.uuid, predict_request, predict_response);
if (!status.ok()) {
GenerateErrorResponse(logger, GetHttpStatusCode((status)), status, context);
GenerateErrorResponse(logger, GetHttpStatusCode((status)), status.error_message(), context, context.uuid, client_request_id);
return;
}

// Serialize to proper output format
std::string response_body{};
status = GenerateResponseInJson(predictResponse, response_body);
if (!status.ok()) {
GenerateErrorResponse(logger, http::status::internal_server_error, status, context);
return;
if (response_type == SupportedContentType::Json) {
status = GenerateResponseInJson(predict_response, response_body);
if (!status.ok()) {
GenerateErrorResponse(logger, http::status::internal_server_error, status.error_message(), context, context.uuid, client_request_id);
return;
}
context.response.set(http::field::content_type, "application/json");
} else {
response_body = predict_response.SerializeAsString();
context.response.set(http::field::content_type, "application/octet-stream");
}

// Build HTTP response
context.response.insert("x-ms-request-id", context.uuid);
if (!client_request_id.empty()) {
context.response.insert("x-ms-client-request-id", client_request_id);
}
context.response.body() = response_body;
context.response.result(http::status::ok);
context.response.set(http::field::content_type, "application/json");
};

static bool ParseRequestPayload(const HttpContext& context, SupportedContentType request_type, PredictRequest& predictRequest, http::status& error_code, std::string& error_message) {
auto body = context.request.body();
protobufutil::Status status;
switch (request_type) {
case SupportedContentType::Json: {
status = GetRequestFromJson(body, predictRequest);
if (!status.ok()) {
error_code = GetHttpStatusCode(status);
error_message = status.error_message();
return false;
}
break;
}
case SupportedContentType::PbByteArray: {
bool parse_succeeded = predictRequest.ParseFromArray(body.data(), body.size());
if (!parse_succeeded) {
error_code = http::status::bad_request;
error_message = "Invalid payload.";
return false;
}
break;
}
default: {
error_code = http::status::bad_request;
error_message = "Missing or unknown 'Content-Type' header field in the request";
return false;
}
}

return true;
}

} // namespace hosting
} // namespace onnxruntime
33 changes: 33 additions & 0 deletions onnxruntime/hosting/http/util.cc
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <unordered_set>
#include <boost/beast/core.hpp>
#include <boost/beast/http/status.hpp>
#include <google/protobuf/stubs/status.h>

#include "context.h"
#include "util.h"

namespace protobufutil = google::protobuf::util;
namespace onnxruntime {
namespace hosting {

static std::unordered_set<std::string> protobuf_mime_types{
"application/octet-stream",
"application/vnd.google.protobuf",
"application/x-protobuf"};

// Report a failure
void ErrorHandling(beast::error_code ec, char const* what) {
std::cerr << what << ": " << ec.message() << "\n";
Expand Down Expand Up @@ -52,5 +59,31 @@ boost::beast::http::status GetHttpStatusCode(const protobufutil::Status& status)
}
}

SupportedContentType GetRequestContentType(const HttpContext& context) {
if (context.request.find("Content-Type") != context.request.end()) {
if (context.request["Content-Type"] == "application/json") {
return SupportedContentType::Json;
} else if (protobuf_mime_types.find(context.request["Content-Type"].to_string()) != protobuf_mime_types.end()) {
return SupportedContentType::PbByteArray;
}
}

return SupportedContentType::Unknown;
}

SupportedContentType GetResponseContentType(const HttpContext& context) {
if (context.request.find("Accept") != context.request.end()) {
if (context.request["Accept"] == "application/json") {
return SupportedContentType::Json;
} else if (context.request["Accept"] == "*/*" || protobuf_mime_types.find(context.request["Accept"].to_string()) != protobuf_mime_types.end()) {
return SupportedContentType::PbByteArray;
}
} else {
return SupportedContentType::PbByteArray;
}

return SupportedContentType::Unknown;
}

} // namespace hosting
} // namespace onnxruntime
17 changes: 17 additions & 0 deletions onnxruntime/hosting/http/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,33 @@
#include <boost/beast/http/status.hpp>
#include <google/protobuf/stubs/status.h>

#include "context.h"

namespace onnxruntime {
namespace hosting {

namespace beast = boost::beast; // from <boost/beast.hpp>

enum class SupportedContentType : int {
Unknown,
Json,
PbByteArray
};

// Report a failure
void ErrorHandling(beast::error_code ec, char const* what);

// Mapping protobuf status to http status
boost::beast::http::status GetHttpStatusCode(const google::protobuf::util::Status& status);

// "Content-Type" header field in request is MUST-HAVE.
// Currently we only support two types of input content type: application/json and application/octet-stream
SupportedContentType GetRequestContentType(const HttpContext& context);

// "Accept" header field in request is OPTIONAL.
// Currently we only support three types of response content type: */*, application/json and application/octet-stream
SupportedContentType GetResponseContentType(const HttpContext& context);

} // namespace hosting
} // namespace onnxruntime

Expand Down
121 changes: 121 additions & 0 deletions onnxruntime/test/hosting/util_tests.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <google/protobuf/stubs/status.h>
#include "gtest/gtest.h"
#include "context.h"
#include "util.h"

namespace onnxruntime {
namespace hosting {
namespace test {

namespace protobufutil = google::protobuf::util;

TEST(PositiveTests, GetRequestContentTypeJson) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::content_type, "application/json");
context.request = request;

auto result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::Json);
}

TEST(PositiveTests, GetRequestContentTypeRawData) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::content_type, "application/octet-stream");
context.request = request;

auto result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);

context.request.set(http::field::content_type, "application/vnd.google.protobuf");
result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);

context.request.set(http::field::content_type, "application/x-protobuf");
result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
}

TEST(NegativeTests, GetRequestContentTypeUnknown) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::content_type, "text/plain");
context.request = request;

auto result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::Unknown);
}

TEST(NegativeTests, GetRequestContentTypeMissing) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
context.request = request;

auto result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::Unknown);
}

TEST(PositiveTests, GetResponseContentTypeJson) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::accept, "application/json");
context.request = request;

auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::Json);
}

TEST(PositiveTests, GetResponseContentTypeRawData) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::accept, "application/octet-stream");
context.request = request;

auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);

context.request.set(http::field::accept, "application/vnd.google.protobuf");
result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);

context.request.set(http::field::accept, "application/x-protobuf");
result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
}

TEST(NegativeTests, GetResponseContentTypeAny) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::accept, "*/*");
context.request = request;

auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
}

TEST(NegativeTests, GetResponseContentTypeUnknown) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::accept, "text/plain");
context.request = request;

auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::Unknown);
}

TEST(NegativeTests, GetResponseContentTypeMissing) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
context.request = request;

auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
}

} // namespace test
} // namespace hosting
} // namespace onnxruntime

0 comments on commit 6e5f6f8

Please sign in to comment.