From 97f2e23e4bc0c455d151c6156fe1cab1ecec0e0c Mon Sep 17 00:00:00 2001 From: Kibae Shin Date: Fri, 28 Jun 2024 22:37:16 +0900 Subject: [PATCH] fixes #54: Added --request-payload-limit option for large request support (#58) * fixes #54: Added --request-payload-limit option for large request support --- README.md | 17 +-- src/onnxruntime_server.hpp | 5 +- src/standalone/standalone.cpp | 11 +- src/test/e2e/e2e_test_http_server.cpp | 132 ++++++++++++++++++++++- src/test/e2e/e2e_test_https_server.cpp | 116 +++++++++++++++++++- src/transport/http/http_server.cpp | 2 +- src/transport/http/http_server.hpp | 2 +- src/transport/http/http_session.cpp | 11 +- src/transport/http/http_session_base.cpp | 6 +- src/transport/http/https_server.cpp | 2 +- src/transport/http/https_session.cpp | 7 +- src/transport/server.cpp | 8 +- src/transport/tcp/tcp_server.cpp | 2 +- 13 files changed, 293 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 7a8bbe2..406fb93 100644 --- a/README.md +++ b/README.md @@ -95,9 +95,9 @@ sudo cmake --install build --prefix /usr/local/onnxruntime-server # Install via a package manager -| OS | Method | Command | -|-------------------------------|------------|--------------------------------------------| -| Arch Linux | AUR | `yay -S onnxruntime-server` | +| OS | Method | Command | +|------------|--------|-----------------------------| +| Arch Linux | AUR | `yay -S onnxruntime-server` | ---- @@ -127,11 +127,12 @@ sudo cmake --install build --prefix /usr/local/onnxruntime-server ## Options -| Option | Environment | Description | -|-------------------|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `--workers` | 
`ONNX_SERVER_WORKERS` | Worker thread pool size.
Default: `4` | -| `--model-dir` | `ONNX_SERVER_MODEL_DIR` | Model directory path
The onnx model files must be located in the following path:
`${model_dir}/${model_name}/${model_version}/model.onnx`
Default: `models` | -| `--prepare-model` | `ONNX_SERVER_PREPARE_MODEL` | Pre-create some model sessions at server startup.

Format as a space-separated list of `model_name:model_version` or `model_name:model_version(session_options, ...)`.

Available session_options are
- cuda=device_id`[ or true or false]`

eg) `model1:v1 model2:v9`
`model1:v1(cuda=true) model2:v9(cuda=1)` | +| Option | Environment | Description | +|---------------------------|-------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `--workers` | `ONNX_SERVER_WORKERS` | Worker thread pool size.
Default: `4` | +| `--request-payload-limit` | `ONNX_SERVER_REQUEST_PAYLOAD_LIMIT` | HTTP/HTTPS request payload size limit, in bytes.
Default: `1024 * 1024 * 10` (10MB) | +| `--model-dir` | `ONNX_SERVER_MODEL_DIR` | Model directory path
The onnx model files must be located in the following path:
`${model_dir}/${model_name}/${model_version}/model.onnx`
Default: `models` | +| `--prepare-model` | `ONNX_SERVER_PREPARE_MODEL` | Pre-create some model sessions at server startup.

Format as a space-separated list of `model_name:model_version` or `model_name:model_version(session_options, ...)`.

Available session_options are
- cuda=device_id`[ or true or false]`

eg) `model1:v1 model2:v9`
`model1:v1(cuda=true) model2:v9(cuda=1)` | ### Backend options diff --git a/src/onnxruntime_server.hpp b/src/onnxruntime_server.hpp index d0e8743..0a85e39 100644 --- a/src/onnxruntime_server.hpp +++ b/src/onnxruntime_server.hpp @@ -311,6 +311,7 @@ namespace onnxruntime_server { std::string model_dir; std::string prepare_model; model_bin_getter_t model_bin_getter{}; + long request_payload_limit = 1024 * 1024 * 10; }; namespace transport { @@ -322,6 +323,7 @@ namespace onnxruntime_server { asio::socket socket; asio::acceptor acceptor; uint_least16_t assigned_port = 0; + long request_payload_limit_; onnx::session_manager *onnx_session_manager; @@ -331,12 +333,13 @@ namespace onnxruntime_server { public: server( boost::asio::io_context &io_context, onnx::session_manager *onnx_session_manager, - builtin_thread_pool *worker_pool, int port + builtin_thread_pool *worker_pool, int port, long request_payload_limit ); ~server(); builtin_thread_pool *get_worker_pool(); onnx::session_manager *get_onnx_session_manager(); + [[nodiscard]] long request_payload_limit() const; [[nodiscard]] uint_least16_t port() const; }; diff --git a/src/standalone/standalone.cpp b/src/standalone/standalone.cpp index e41abbc..df7db1c 100644 --- a/src/standalone/standalone.cpp +++ b/src/standalone/standalone.cpp @@ -15,9 +15,13 @@ int onnxruntime_server::standalone::init_config(int argc, char **argv) { po_desc.add_options()("help,h", "Produce help message\n"); // env: ONNX_WORKERS po_desc.add_options()( - "workers", po::value()->default_value(4), + "workers", po::value()->default_value(4), "env: ONNX_SERVER_WORKERS\nWorker thread pool size.\nDefault: 4" ); + po_desc.add_options()( + "request-payload-limit", po::value()->default_value(1024 * 1024 * 10), + "env: ONNX_SERVER_REQUEST_PAYLOAD_LIMIT\nHTTP/HTTPS request payload size limit.\nDefault: 1024 * 1024 * 10(10MB)" + ); po_desc.add_options()( "model-dir", po::value()->default_value("models"), "env: ONNX_SERVER_MODEL_DIR\nModel directory path.\nThe 
onnx model files must be located in the " @@ -156,7 +160,10 @@ int onnxruntime_server::standalone::init_config(int argc, char **argv) { AixLog::Log::init({log_file, log_access_file}); if (vm.count("workers")) - config.num_threads = vm["workers"].as(); + config.num_threads = vm["workers"].as(); + + if (vm.count("request-payload-limit")) + config.request_payload_limit = vm["request-payload-limit"].as(); if (vm.count("model-dir")) config.model_dir = vm["model-dir"].as(); diff --git a/src/test/e2e/e2e_test_http_server.cpp b/src/test/e2e/e2e_test_http_server.cpp index c0f84a5..6a4aa58 100644 --- a/src/test/e2e/e2e_test_http_server.cpp +++ b/src/test/e2e/e2e_test_http_server.cpp @@ -87,6 +87,137 @@ TEST(test_onnxruntime_server_http, HttpServerTest) { ASSERT_GT(res_json["output"][0], 0); } + { // API: Execute session large request + auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})"); + int size = 1000000; + for (int i = 0; i < size; i++) { + input["x"].push_back(input["x"][0]); + input["y"].push_back(input["y"][0]); + input["z"].push_back(input["z"][0]); + } + std::cout << input.dump().length() << " bytes\n"; + + bool exception = false; + try { + TIME_MEASURE_START + auto res = + http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump()); + TIME_MEASURE_STOP + } catch (std::exception &e) { + exception = true; + std::cout << e.what() << std::endl; + } + ASSERT_TRUE(exception); + } + + { // API: Destroy session + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), ""); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: Destroy sessions\n" << res_json.dump(2) << "\n"; + ASSERT_TRUE(res_json); + } + + { // API: List session + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::get, "/api/sessions", 
server.port(), ""); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: List sessions\n" << res_json.dump(2) << "\n"; + ASSERT_EQ(res_json.size(), 0); + } + + running = false; + server_thread.join(); +} + +TEST(test_onnxruntime_server_http, HttpServerLargeRequestTest) { + Orts::config config; + config.http_port = 0; + config.model_bin_getter = test_model_bin_getter; + config.request_payload_limit = 1024 * 1024 * 1024; + + boost::asio::io_context io_context; + Orts::onnx::session_manager manager(config.model_bin_getter); + Orts::builtin_thread_pool worker_pool(config.num_threads); + Orts::transport::http::http_server server(io_context, config, &manager, &worker_pool); + + bool running = true; + std::thread server_thread([&io_context, &running]() { test_server_run(io_context, &running); }); + + TIME_MEASURE_INIT + + { // API: Create session + json body = json::parse(R"({"model":"sample","version":"1"})"); + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::post, "/api/sessions", server.port(), body.dump()); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: Create session\n" << res_json.dump(2) << "\n"; + ASSERT_EQ(res_json["model"], "sample"); + ASSERT_EQ(res_json["version"], "1"); + } + + { // API: Get session + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::get, "/api/sessions/sample/1", server.port(), ""); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: Get session\n" << res_json.dump(2) << "\n"; + ASSERT_EQ(res_json["model"], "sample"); + ASSERT_EQ(res_json["version"], "1"); + } + + { // API: List session + TIME_MEASURE_START + 
auto res = http_request(boost::beast::http::verb::get, "/api/sessions", server.port(), ""); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: List sessions\n" << res_json.dump(2) << "\n"; + ASSERT_EQ(res_json.size(), 1); + ASSERT_EQ(res_json[0]["model"], "sample"); + ASSERT_EQ(res_json[0]["version"], "1"); + } + + { // API: Execute session + auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})"); + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump()); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: Execute sessions\n" << res_json.dump(2) << "\n"; + ASSERT_TRUE(res_json.contains("output")); + ASSERT_EQ(res_json["output"].size(), 1); + ASSERT_GT(res_json["output"][0], 0); + } + + { // API: Execute session large request + auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})"); + int size = 1000000; + for (int i = 0; i < size; i++) { + input["x"].push_back(input["x"][0]); + input["y"].push_back(input["y"][0]); + input["z"].push_back(input["z"][0]); + } + std::cout << input.dump().length() << " bytes\n"; + + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump()); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + ASSERT_TRUE(res_json.contains("output")); + ASSERT_EQ(res_json["output"].size(), size + 1); + ASSERT_GT(res_json["output"][0], 0); + } + { // API: Destroy session TIME_MEASURE_START auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), ""); @@ -132,6 +263,5 @@ 
http_request(beast::http::verb method, const std::string &target, short port, st beast::http::read(socket, buffer, res); - socket.close(); return res; } diff --git a/src/test/e2e/e2e_test_https_server.cpp b/src/test/e2e/e2e_test_https_server.cpp index 28d886f..4f0b1b4 100644 --- a/src/test/e2e/e2e_test_https_server.cpp +++ b/src/test/e2e/e2e_test_https_server.cpp @@ -89,6 +89,30 @@ TEST(test_onnxruntime_server_http, HttpsServerTest) { ASSERT_GT(res_json["output"][0], 0); } + { // API: Execute session large request + auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})"); + int size = 1000000; + for (int i = 0; i < size; i++) { + input["x"].push_back(input["x"][0]); + input["y"].push_back(input["y"][0]); + input["z"].push_back(input["z"][0]); + } + std::cout << input.dump().length() << " bytes\n"; + + bool exception = false; + try { + + TIME_MEASURE_START + auto res = + http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump()); + TIME_MEASURE_STOP + } catch (std::exception &e) { + exception = true; + std::cout << e.what() << std::endl; + } + ASSERT_TRUE(exception); + } + { // API: Destroy session TIME_MEASURE_START auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), ""); @@ -113,6 +137,96 @@ TEST(test_onnxruntime_server_http, HttpsServerTest) { server_thread.join(); } +TEST(test_onnxruntime_server_http, HttpsServerLargeRequestTest) { + Orts::config config; + config.https_port = 0; + config.https_cert = (test_dir / "ssl" / "server-cert.pem").string(); + config.https_key = (test_dir / "ssl" / "server-key.pem").string(); + config.model_bin_getter = test_model_bin_getter; + config.request_payload_limit = 1024 * 1024 * 1024; + + boost::asio::io_context io_context; + Orts::onnx::session_manager manager(config.model_bin_getter); + Orts::builtin_thread_pool worker_pool(config.num_threads); + Orts::transport::http::https_server server(io_context, config, &manager, &worker_pool); 
+ + bool running = true; + std::thread server_thread([&io_context, &running]() { test_server_run(io_context, &running); }); + + TIME_MEASURE_INIT + + { // API: Create session + json body = json::parse(R"({"model":"sample","version":"1"})"); + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::post, "/api/sessions", server.port(), body.dump()); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: Create session\n" << res_json.dump(2) << "\n"; + ASSERT_EQ(res_json["model"], "sample"); + ASSERT_EQ(res_json["version"], "1"); + } + + { // API: Get session + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::get, "/api/sessions/sample/1", server.port(), ""); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: Get session\n" << res_json.dump(2) << "\n"; + ASSERT_EQ(res_json["model"], "sample"); + ASSERT_EQ(res_json["version"], "1"); + } + + { // API: List session + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::get, "/api/sessions", server.port(), ""); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: List sessions\n" << res_json.dump(2) << "\n"; + ASSERT_EQ(res_json.size(), 1); + ASSERT_EQ(res_json[0]["model"], "sample"); + ASSERT_EQ(res_json[0]["version"], "1"); + } + + { // API: Execute session + auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})"); + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump()); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = 
json::parse(boost::beast::buffers_to_string(res.body().data())); + std::cout << "API: Execute sessions\n" << res_json.dump(2) << "\n"; + ASSERT_TRUE(res_json.contains("output")); + ASSERT_EQ(res_json["output"].size(), 1); + ASSERT_GT(res_json["output"][0], 0); + } + + { // API: Execute session large request + auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})"); + int size = 1000000; + for (int i = 0; i < size; i++) { + input["x"].push_back(input["x"][0]); + input["y"].push_back(input["y"][0]); + input["z"].push_back(input["z"][0]); + } + std::cout << input.dump().length() << " bytes\n"; + + TIME_MEASURE_START + auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump()); + TIME_MEASURE_STOP + ASSERT_EQ(res.result(), boost::beast::http::status::ok); + json res_json = json::parse(boost::beast::buffers_to_string(res.body().data())); + ASSERT_TRUE(res_json.contains("output")); + ASSERT_EQ(res_json["output"].size(), size + 1); + ASSERT_GT(res_json["output"][0], 0); + } + + running = false; + server_thread.join(); +} + beast::http::response http_request(beast::http::verb method, const std::string &target, short port, std::string body) { boost::asio::io_context ioc; @@ -143,7 +257,5 @@ http_request(beast::http::verb method, const std::string &target, short port, st beast::http::read(stream, buffer, res); - stream.shutdown(); - stream.lowest_layer().close(); return res; } diff --git a/src/transport/http/http_server.cpp b/src/transport/http/http_server.cpp index c0d2efb..606ea4f 100644 --- a/src/transport/http/http_server.cpp +++ b/src/transport/http/http_server.cpp @@ -9,7 +9,7 @@ onnxruntime_server::transport::http::http_server::http_server( onnxruntime_server::onnx::session_manager *onnx_session_manager, onnxruntime_server::builtin_thread_pool *worker_pool ) - : server(io_context, onnx_session_manager, worker_pool, config.http_port), swagger(config.swagger_url_path) { + : server(io_context, 
onnx_session_manager, worker_pool, config.http_port, config.request_payload_limit), swagger(config.swagger_url_path) { acceptor.set_option(boost::asio::socket_base::reuse_address(true)); } diff --git a/src/transport/http/http_server.hpp b/src/transport/http/http_server.hpp index e3e1f72..03ef00d 100644 --- a/src/transport/http/http_server.hpp +++ b/src/transport/http/http_server.hpp @@ -36,7 +36,7 @@ namespace onnxruntime_server::transport::http { template class http_session_base : public std::enable_shared_from_this { protected: beast::flat_buffer buffer; - beast::http::request req; + std::shared_ptr> req_parser; virtual onnx::session_manager *get_onnx_session_manager() = 0; std::shared_ptr> handle_request(); diff --git a/src/transport/http/http_session.cpp b/src/transport/http/http_session.cpp index 5281eb4..1a75fc7 100644 --- a/src/transport/http/http_session.cpp +++ b/src/transport/http/http_session.cpp @@ -30,10 +30,12 @@ void onnxruntime_server::transport::http::http_session::do_read() { _remote_endpoint = stream.socket().remote_endpoint().address().to_string() + ":" + std::to_string(stream.socket().remote_endpoint().port()); - req = {}; - stream.expires_after(std::chrono::seconds(30)); + // stream.expires_after(std::chrono::seconds(300)); + req_parser->body_limit(server->request_payload_limit()); - beast::http::async_read(stream, buffer, req, beast::bind_front_handler(&http_session::on_read, shared_from_this())); + beast::http::async_read( + stream, buffer, *req_parser, beast::bind_front_handler(&http_session::on_read, shared_from_this()) + ); } void onnxruntime_server::transport::http::http_session::on_read(beast::error_code ec, std::size_t bytes_transferred) { @@ -55,6 +57,7 @@ void onnxruntime_server::transport::http::http_session::on_read(beast::error_cod void onnxruntime_server::transport::http::http_session::do_write( std::shared_ptr> msg ) { + auto req = req_parser->get(); PLOG(L_INFO, "ACCESS") << get_remote_endpoint() << " task: " << 
req.method_string() << " " << req.target() << " status: " << msg->result_int() << " duration: " << request_time.get_duration() << std::endl; @@ -68,7 +71,7 @@ void onnxruntime_server::transport::http::http_session::do_write( return self->close(); } - if (!self->req.keep_alive()) + if (!self->req_parser->get().keep_alive()) return self->close(); self->do_read(); diff --git a/src/transport/http/http_session_base.cpp b/src/transport/http/http_session_base.cpp index 8000249..6cab8f1 100644 --- a/src/transport/http/http_session_base.cpp +++ b/src/transport/http/http_session_base.cpp @@ -1,7 +1,8 @@ #include "http_server.hpp" template -onnxruntime_server::transport::http::http_session_base::http_session_base() : buffer(), req() { +onnxruntime_server::transport::http::http_session_base::http_session_base() : buffer() { + req_parser = std::make_shared>(); } #define CONTENT_TYPE_PLAIN_TEXT "text/plain" @@ -10,8 +11,11 @@ onnxruntime_server::transport::http::http_session_base::http_session_ba template std::shared_ptr> onnxruntime_server::transport::http::http_session_base::handle_request() { + auto req = req_parser->get(); + auto const simple_response = [this](beast::http::status method, beast::string_view content_type, beast::string_view body) { + auto req = req_parser->get(); auto res = std::make_shared>(method, req.version()); res->set(beast::http::field::content_type, content_type); res->keep_alive(req.keep_alive()); diff --git a/src/transport/http/https_server.cpp b/src/transport/http/https_server.cpp index 34ac785..f45e718 100644 --- a/src/transport/http/https_server.cpp +++ b/src/transport/http/https_server.cpp @@ -9,7 +9,7 @@ onnxruntime_server::transport::http::https_server::https_server( onnxruntime_server::onnx::session_manager *onnx_session_manager, onnxruntime_server::builtin_thread_pool *worker_pool ) - : server(io_context, onnx_session_manager, worker_pool, config.https_port), ctx(boost::asio::ssl::context::sslv23), + : server(io_context, onnx_session_manager, 
worker_pool, config.https_port, config.request_payload_limit), ctx(boost::asio::ssl::context::sslv23), swagger(config.swagger_url_path) { boost::system::error_code ec; ctx.set_options( diff --git a/src/transport/http/https_session.cpp b/src/transport/http/https_session.cpp index 2eb33ee..6f39227 100644 --- a/src/transport/http/https_session.cpp +++ b/src/transport/http/https_session.cpp @@ -35,10 +35,10 @@ void onnxruntime_server::transport::http::https_session::do_read() { _remote_endpoint = stream.lowest_layer().remote_endpoint().address().to_string() + ":" + std::to_string(stream.lowest_layer().remote_endpoint().port()); - req = {}; + req_parser->body_limit(server->request_payload_limit()); beast::http::async_read( - stream, buffer, req, beast::bind_front_handler(&https_session::on_read, shared_from_this()) + stream, buffer, *req_parser, beast::bind_front_handler(&https_session::on_read, shared_from_this()) ); } @@ -59,6 +59,7 @@ void onnxruntime_server::transport::http::https_session::on_read(beast::error_co void onnxruntime_server::transport::http::https_session::do_write( std::shared_ptr> msg ) { + auto req = req_parser->get(); PLOG(L_INFO, "ACCESS") << get_remote_endpoint() << " task: " << req.method_string() << " " << req.target() << " status: " << msg->result_int() << " duration: " << request_time.get_duration() << std::endl; @@ -72,7 +73,7 @@ void onnxruntime_server::transport::http::https_session::do_write( return self->close(); } - if (!self->req.keep_alive()) + if (!self->req_parser->get().keep_alive()) return self->close(); self->do_read(); diff --git a/src/transport/server.cpp b/src/transport/server.cpp index cd97aef..6f95f2c 100644 --- a/src/transport/server.cpp +++ b/src/transport/server.cpp @@ -6,10 +6,10 @@ Orts::transport::server::server( boost::asio::io_context &io_context, Orts::onnx::session_manager *onnx_session_manager, - Orts::builtin_thread_pool *worker_pool, int port + Orts::builtin_thread_pool *worker_pool, int port, long 
request_payload_limit ) : io_context(io_context), acceptor(io_context, asio::endpoint(asio::v4(), port)), socket(io_context), - onnx_session_manager(onnx_session_manager), worker_pool(worker_pool) { + onnx_session_manager(onnx_session_manager), worker_pool(worker_pool), request_payload_limit_(request_payload_limit) { assigned_port = acceptor.local_endpoint().port(); @@ -37,6 +37,10 @@ Orts::onnx::session_manager *Orts::transport::server::get_onnx_session_manager() return onnx_session_manager; } +long Orts::transport::server::request_payload_limit() const { + return request_payload_limit_; +} + uint_least16_t Orts::transport::server::port() const { return assigned_port; } diff --git a/src/transport/tcp/tcp_server.cpp b/src/transport/tcp/tcp_server.cpp index 9cb3fee..42baef8 100644 --- a/src/transport/tcp/tcp_server.cpp +++ b/src/transport/tcp/tcp_server.cpp @@ -8,7 +8,7 @@ onnxruntime_server::transport::tcp::tcp_server::tcp_server( onnxruntime_server::onnx::session_manager *onnx_session_manager, onnxruntime_server::builtin_thread_pool *worker_pool ) - : server(io_context, onnx_session_manager, worker_pool, config.tcp_port) { + : server(io_context, onnx_session_manager, worker_pool, config.tcp_port, config.request_payload_limit) { acceptor.set_option(boost::asio::socket_base::reuse_address(true)); }