Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes #54 Added --request-payload-limit option for large request support #55

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ sudo cmake --install build --prefix /usr/local/onnxruntime-server

# Install via a package manager

| OS | Method | Command |
|-------------------------------|------------|--------------------------------------------|
| Arch Linux | AUR | `yay -S onnxruntime-server` |
| OS | Method | Command |
|------------|--------|-----------------------------|
| Arch Linux | AUR | `yay -S onnxruntime-server` |

----

Expand Down Expand Up @@ -127,11 +127,12 @@ sudo cmake --install build --prefix /usr/local/onnxruntime-server

## Options

| Option | Environment | Description |
|-------------------|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `--workers` | `ONNX_SERVER_WORKERS` | Worker thread pool size.<br/>Default: `4` |
| `--model-dir` | `ONNX_SERVER_MODEL_DIR` | Model directory path<br/>The onnx model files must be located in the following path:<br/>`${model_dir}/${model_name}/${model_version}/model.onnx`<br/>Default: `models` |
| `--prepare-model` | `ONNX_SERVER_PREPARE_MODEL` | Pre-create some model sessions at server startup.<br/><br/>Format as a space-separated list of `model_name:model_version` or `model_name:model_version(session_options, ...)`.<br/><br/>Available session_options are<br/>- cuda=device_id`[ or true or false]`<br/><br/>eg) `model1:v1 model2:v9`<br/>`model1:v1(cuda=true) model2:v9(cuda=1)` |
| Option | Environment | Description |
|---------------------------|-------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `--workers` | `ONNX_SERVER_WORKERS` | Worker thread pool size.<br/>Default: `4` |
| `--request-payload-limit` | `ONNX_SERVER_REQUEST_PAYLOAD_LIMIT` | HTTP/HTTPS request payload size limit.<br/>Default: `1024 * 1024 * 10` (10MB)                                                                                                                                                                                                                                                     |
| `--model-dir` | `ONNX_SERVER_MODEL_DIR` | Model directory path<br/>The onnx model files must be located in the following path:<br/>`${model_dir}/${model_name}/${model_version}/model.onnx`<br/>Default: `models` |
| `--prepare-model` | `ONNX_SERVER_PREPARE_MODEL` | Pre-create some model sessions at server startup.<br/><br/>Format as a space-separated list of `model_name:model_version` or `model_name:model_version(session_options, ...)`.<br/><br/>Available session_options are<br/>- cuda=device_id`[ or true or false]`<br/><br/>eg) `model1:v1 model2:v9`<br/>`model1:v1(cuda=true) model2:v9(cuda=1)` |

### Backend options

Expand Down
5 changes: 4 additions & 1 deletion src/onnxruntime_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ namespace onnxruntime_server {
std::string model_dir;
std::string prepare_model;
model_bin_getter_t model_bin_getter{};
long request_payload_limit = 1024 * 1024 * 10;
};

namespace transport {
Expand All @@ -322,6 +323,7 @@ namespace onnxruntime_server {
asio::socket socket;
asio::acceptor acceptor;
uint_least16_t assigned_port = 0;
long request_payload_limit_;

onnx::session_manager *onnx_session_manager;

Expand All @@ -331,12 +333,13 @@ namespace onnxruntime_server {
public:
server(
boost::asio::io_context &io_context, onnx::session_manager *onnx_session_manager,
builtin_thread_pool *worker_pool, int port
builtin_thread_pool *worker_pool, int port, long request_payload_limit
);
~server();

builtin_thread_pool *get_worker_pool();
onnx::session_manager *get_onnx_session_manager();
[[nodiscard]] long request_payload_limit() const;
[[nodiscard]] uint_least16_t port() const;
};

Expand Down
9 changes: 8 additions & 1 deletion src/standalone/standalone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ int onnxruntime_server::standalone::init_config(int argc, char **argv) {
"workers", po::value<int>()->default_value(4),
"env: ONNX_SERVER_WORKERS\nWorker thread pool size.\nDefault: 4"
);
po_desc.add_options()(
"request-payload-limit", po::value<int>()->default_value(1024 * 1024 * 10),
"env: ONNX_SERVER_REQUEST_PAYLOAD_LIMIT\nHTTP/HTTPS request payload size limit.\nDefault: 1024 * 1024 * 10(10MB)"
);
po_desc.add_options()(
"model-dir", po::value<std::string>()->default_value("models"),
"env: ONNX_SERVER_MODEL_DIR\nModel directory path.\nThe onnx model files must be located in the "
Expand Down Expand Up @@ -156,7 +160,10 @@ int onnxruntime_server::standalone::init_config(int argc, char **argv) {
AixLog::Log::init({log_file, log_access_file});

if (vm.count("workers"))
config.num_threads = vm["workers"].as<int>();
config.num_threads = vm["workers"].as<long>();

if (vm.count("request-payload-limit"))
config.request_payload_limit = vm["request-payload-limit"].as<long>();

if (vm.count("model-dir"))
config.model_dir = vm["model-dir"].as<std::string>();
Expand Down
132 changes: 131 additions & 1 deletion src/test/e2e/e2e_test_http_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,137 @@ TEST(test_onnxruntime_server_http, HttpServerTest) {
ASSERT_GT(res_json["output"][0], 0);
}

{ // API: Execute session large request
auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
int size = 1000000;
for (int i = 0; i < size; i++) {
input["x"].push_back(input["x"][0]);
input["y"].push_back(input["y"][0]);
input["z"].push_back(input["z"][0]);
}
std::cout << input.dump().length() << " bytes\n";

bool exception = false;
try {
TIME_MEASURE_START
auto res =
http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
TIME_MEASURE_STOP
} catch (std::exception &e) {
exception = true;
std::cout << e.what() << std::endl;
}
ASSERT_TRUE(exception);
}

{ // API: Destroy session
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), "");
TIME_MEASURE_STOP
ASSERT_EQ(res.result(), boost::beast::http::status::ok);
json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
std::cout << "API: Destroy sessions\n" << res_json.dump(2) << "\n";
ASSERT_TRUE(res_json);
}

{ // API: List session
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::get, "/api/sessions", server.port(), "");
TIME_MEASURE_STOP
ASSERT_EQ(res.result(), boost::beast::http::status::ok);
json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
std::cout << "API: List sessions\n" << res_json.dump(2) << "\n";
ASSERT_EQ(res_json.size(), 0);
}

running = false;
server_thread.join();
}

TEST(test_onnxruntime_server_http, HttpServerLargeRequestTest) {
Orts::config config;
config.http_port = 0;
config.model_bin_getter = test_model_bin_getter;
config.request_payload_limit = 1024 * 1024 * 1024;

boost::asio::io_context io_context;
Orts::onnx::session_manager manager(config.model_bin_getter);
Orts::builtin_thread_pool worker_pool(config.num_threads);
Orts::transport::http::http_server server(io_context, config, &manager, &worker_pool);

bool running = true;
std::thread server_thread([&io_context, &running]() { test_server_run(io_context, &running); });

TIME_MEASURE_INIT

{ // API: Create session
json body = json::parse(R"({"model":"sample","version":"1"})");
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::post, "/api/sessions", server.port(), body.dump());
TIME_MEASURE_STOP
ASSERT_EQ(res.result(), boost::beast::http::status::ok);
json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
std::cout << "API: Create session\n" << res_json.dump(2) << "\n";
ASSERT_EQ(res_json["model"], "sample");
ASSERT_EQ(res_json["version"], "1");
}

{ // API: Get session
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::get, "/api/sessions/sample/1", server.port(), "");
TIME_MEASURE_STOP
ASSERT_EQ(res.result(), boost::beast::http::status::ok);
json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
std::cout << "API: Get session\n" << res_json.dump(2) << "\n";
ASSERT_EQ(res_json["model"], "sample");
ASSERT_EQ(res_json["version"], "1");
}

{ // API: List session
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::get, "/api/sessions", server.port(), "");
TIME_MEASURE_STOP
ASSERT_EQ(res.result(), boost::beast::http::status::ok);
json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
std::cout << "API: List sessions\n" << res_json.dump(2) << "\n";
ASSERT_EQ(res_json.size(), 1);
ASSERT_EQ(res_json[0]["model"], "sample");
ASSERT_EQ(res_json[0]["version"], "1");
}

{ // API: Execute session
auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
TIME_MEASURE_STOP
ASSERT_EQ(res.result(), boost::beast::http::status::ok);
json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
std::cout << "API: Execute sessions\n" << res_json.dump(2) << "\n";
ASSERT_TRUE(res_json.contains("output"));
ASSERT_EQ(res_json["output"].size(), 1);
ASSERT_GT(res_json["output"][0], 0);
}

{ // API: Execute session large request
auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
int size = 1000000;
for (int i = 0; i < size; i++) {
input["x"].push_back(input["x"][0]);
input["y"].push_back(input["y"][0]);
input["z"].push_back(input["z"][0]);
}
std::cout << input.dump().length() << " bytes\n";

TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
TIME_MEASURE_STOP
ASSERT_EQ(res.result(), boost::beast::http::status::ok);
json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
ASSERT_TRUE(res_json.contains("output"));
ASSERT_EQ(res_json["output"].size(), size + 1);
ASSERT_GT(res_json["output"][0], 0);
}

{ // API: Destroy session
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), "");
Expand Down Expand Up @@ -132,6 +263,5 @@ http_request(beast::http::verb method, const std::string &target, short port, st

beast::http::read(socket, buffer, res);

socket.close();
return res;
}
116 changes: 114 additions & 2 deletions src/test/e2e/e2e_test_https_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,30 @@ TEST(test_onnxruntime_server_http, HttpsServerTest) {
ASSERT_GT(res_json["output"][0], 0);
}

{ // API: Execute session large request
auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
int size = 1000000;
for (int i = 0; i < size; i++) {
input["x"].push_back(input["x"][0]);
input["y"].push_back(input["y"][0]);
input["z"].push_back(input["z"][0]);
}
std::cout << input.dump().length() << " bytes\n";

bool exception = false;
try {

TIME_MEASURE_START
auto res =
http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
TIME_MEASURE_STOP
} catch (std::exception &e) {
exception = true;
std::cout << e.what() << std::endl;
}
ASSERT_TRUE(exception);
}

{ // API: Destroy session
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), "");
Expand All @@ -113,6 +137,96 @@ TEST(test_onnxruntime_server_http, HttpsServerTest) {
server_thread.join();
}

// End-to-end check that the HTTPS server accepts a request body far larger than
// the default 10MB cap once config.request_payload_limit has been raised.
TEST(test_onnxruntime_server_http, HttpsServerLargeRequestTest) {
    // TLS server on an ephemeral port (0), payload limit lifted to 1GB.
    Orts::config config;
    config.https_port = 0;
    config.https_cert = (test_dir / "ssl" / "server-cert.pem").string();
    config.https_key = (test_dir / "ssl" / "server-key.pem").string();
    config.model_bin_getter = test_model_bin_getter;
    config.request_payload_limit = 1024 * 1024 * 1024;

    boost::asio::io_context io_context;
    Orts::onnx::session_manager manager(config.model_bin_getter);
    Orts::builtin_thread_pool worker_pool(config.num_threads);
    Orts::transport::http::https_server server(io_context, config, &manager, &worker_pool);

    // Run the io_context on a background thread until the test flips `running`.
    bool running = true;
    std::thread server_thread([&io_context, &running]() { test_server_run(io_context, &running); });

    TIME_MEASURE_INIT

    // Shared helper: decode a Beast dynamic-body response into JSON.
    auto parse_body = [](const auto &reply) {
        return json::parse(boost::beast::buffers_to_string(reply.body().data()));
    };

    { // API: Create session
        json body = json::parse(R"({"model":"sample","version":"1"})");
        TIME_MEASURE_START
        auto reply = http_request(boost::beast::http::verb::post, "/api/sessions", server.port(), body.dump());
        TIME_MEASURE_STOP
        ASSERT_EQ(reply.result(), boost::beast::http::status::ok);
        json payload = parse_body(reply);
        std::cout << "API: Create session\n" << payload.dump(2) << "\n";
        ASSERT_EQ(payload["model"], "sample");
        ASSERT_EQ(payload["version"], "1");
    }

    { // API: Get session
        TIME_MEASURE_START
        auto reply = http_request(boost::beast::http::verb::get, "/api/sessions/sample/1", server.port(), "");
        TIME_MEASURE_STOP
        ASSERT_EQ(reply.result(), boost::beast::http::status::ok);
        json payload = parse_body(reply);
        std::cout << "API: Get session\n" << payload.dump(2) << "\n";
        ASSERT_EQ(payload["model"], "sample");
        ASSERT_EQ(payload["version"], "1");
    }

    { // API: List session
        TIME_MEASURE_START
        auto reply = http_request(boost::beast::http::verb::get, "/api/sessions", server.port(), "");
        TIME_MEASURE_STOP
        ASSERT_EQ(reply.result(), boost::beast::http::status::ok);
        json payload = parse_body(reply);
        std::cout << "API: List sessions\n" << payload.dump(2) << "\n";
        ASSERT_EQ(payload.size(), 1);
        ASSERT_EQ(payload[0]["model"], "sample");
        ASSERT_EQ(payload[0]["version"], "1");
    }

    { // API: Execute session
        auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
        TIME_MEASURE_START
        auto reply = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
        TIME_MEASURE_STOP
        ASSERT_EQ(reply.result(), boost::beast::http::status::ok);
        json payload = parse_body(reply);
        std::cout << "API: Execute sessions\n" << payload.dump(2) << "\n";
        ASSERT_TRUE(payload.contains("output"));
        ASSERT_EQ(payload["output"].size(), 1);
        ASSERT_GT(payload["output"][0], 0);
    }

    { // API: Execute session large request
        // Grow the input to ~1e6 rows per tensor so the serialized JSON body
        // comfortably exceeds the default 10MB limit.
        auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
        const int extra_rows = 1000000;
        for (int row = 0; row < extra_rows; ++row) {
            input["x"].push_back(input["x"][0]);
            input["y"].push_back(input["y"][0]);
            input["z"].push_back(input["z"][0]);
        }
        std::cout << input.dump().length() << " bytes\n";

        TIME_MEASURE_START
        auto reply = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
        TIME_MEASURE_STOP
        // With the raised limit this must succeed rather than be rejected.
        ASSERT_EQ(reply.result(), boost::beast::http::status::ok);
        json payload = parse_body(reply);
        ASSERT_TRUE(payload.contains("output"));
        ASSERT_EQ(payload["output"].size(), extra_rows + 1);
        ASSERT_GT(payload["output"][0], 0);
    }

    // Signal the io_context loop to stop and wait for the server thread.
    running = false;
    server_thread.join();
}

beast::http::response<beast::http::dynamic_body>
http_request(beast::http::verb method, const std::string &target, short port, std::string body) {
boost::asio::io_context ioc;
Expand Down Expand Up @@ -143,7 +257,5 @@ http_request(beast::http::verb method, const std::string &target, short port, st

beast::http::read(stream, buffer, res);

stream.shutdown();
stream.lowest_layer().close();
return res;
}
2 changes: 1 addition & 1 deletion src/transport/http/http_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ onnxruntime_server::transport::http::http_server::http_server(
onnxruntime_server::onnx::session_manager *onnx_session_manager,
onnxruntime_server::builtin_thread_pool *worker_pool
)
: server(io_context, onnx_session_manager, worker_pool, config.http_port), swagger(config.swagger_url_path) {
: server(io_context, onnx_session_manager, worker_pool, config.http_port, config.request_payload_limit), swagger(config.swagger_url_path) {
acceptor.set_option(boost::asio::socket_base::reuse_address(true));
}

Expand Down
2 changes: 1 addition & 1 deletion src/transport/http/http_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace onnxruntime_server::transport::http {
template <class Session> class http_session_base : public std::enable_shared_from_this<Session> {
protected:
beast::flat_buffer buffer;
beast::http::request<beast::http::string_body> req;
std::shared_ptr<beast::http::request_parser<beast::http::string_body>> req_parser;

virtual onnx::session_manager *get_onnx_session_manager() = 0;
std::shared_ptr<beast::http::response<beast::http::string_body>> handle_request();
Expand Down
Loading
Loading