diff --git a/src/onnx/execution/context.cpp b/src/onnx/execution/context.cpp index 889413f..5996c59 100644 --- a/src/onnx/execution/context.cpp +++ b/src/onnx/execution/context.cpp @@ -4,7 +4,7 @@ #include "../../onnxruntime_server.hpp" -Orts::onnx::execution::context::context(Orts::onnx::session *session, const json &json_str) +Orts::onnx::execution::context::context(std::shared_ptr session, const json &json_str) : memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)), session(session) { assert(session != nullptr); diff --git a/src/onnx/session_manager.cpp b/src/onnx/session_manager.cpp index e5042aa..ea829f5 100644 --- a/src/onnx/session_manager.cpp +++ b/src/onnx/session_manager.cpp @@ -9,18 +9,15 @@ Orts::onnx::session_manager::session_manager(model_bin_getter_t model_bin_getter } Orts::onnx::session_manager::~session_manager() { - for (auto &it : sessions) { - delete it.second; - } } -Orts::onnx::session * +std::shared_ptr Orts::onnx::session_manager::get_session(const std::string &model_name, const std::string &model_version) { auto key = session_key(model_name, model_version); return get_session(key); } -Orts::onnx::session *Orts::onnx::session_manager::get_session(const Orts::onnx::session_key &key) { +std::shared_ptr Orts::onnx::session_manager::get_session(const Orts::onnx::session_key &key) { std::lock_guard lock(mutex); auto it = sessions.find(key); if (it == sessions.end()) @@ -28,7 +25,7 @@ Orts::onnx::session *Orts::onnx::session_manager::get_session(const Orts::onnx:: return it->second; } -Orts::onnx::session *Orts::onnx::session_manager::create_session( +std::shared_ptr Orts::onnx::session_manager::create_session( const std::string &model_name, const std::string &model_version, const json &option, const char *model_data, size_t model_data_length ) { @@ -49,7 +46,7 @@ Orts::onnx::session *Orts::onnx::session_manager::create_session( if (current_session != nullptr) throw conflict_error("session already exists"); - auto session = new onnx::session(key, model_data, model_data_length, option); + auto session = std::make_shared(key, model_data, model_data_length, option); sessions.emplace(key, session); return session; } @@ -67,6 +64,5 @@ void Orts::onnx::session_manager::remove_session(const Orts::onnx::session_key & if (it == sessions.end()) { throw not_found_error("session not found"); } - delete it->second; sessions.erase(it); } diff --git a/src/onnxruntime_server.hpp b/src/onnxruntime_server.hpp index 9cfe2bf..82ce47e 100644 --- a/src/onnxruntime_server.hpp +++ b/src/onnxruntime_server.hpp @@ -123,20 +123,20 @@ namespace onnxruntime_server { class session_manager { private: std::recursive_mutex mutex; - std::map sessions; + std::map> sessions; model_bin_getter_t model_bin_getter; public: explicit session_manager(model_bin_getter_t model_bin_getter); ~session_manager(); - std::map &get_sessions() { + std::map> &get_sessions() { return sessions; } - session *get_session(const std::string &model_name, const std::string &model_version); - session *get_session(const session_key &key); - session *create_session( + std::shared_ptr get_session(const std::string &model_name, const std::string &model_version); + std::shared_ptr get_session(const session_key &key); + std::shared_ptr create_session( const std::string &model_name, const std::string &model_version, const json &option, const char *model_data = nullptr, size_t model_data_length = 0 ); @@ -161,11 +161,11 @@ namespace onnxruntime_server { class context { private: Ort::MemoryInfo memory_info; - onnxruntime_server::onnx::session *session; + std::shared_ptr session; std::map inputs; public: - context(class session *session, const json &json_str); + context(std::shared_ptr session, const json &json_str); ~context(); void flat_json_values(const json::value_type &data, std::vector *json_values); diff --git a/src/task/create_session.cpp b/src/task/create_session.cpp index 4eba673..10a9df6 100644 --- a/src/task/create_session.cpp +++ b/src/task/create_session.cpp @@ -32,7 +32,7 @@ Orts::task::create_session::create_session( } json Orts::task::create_session::run() { - Orts::onnx::session *session = + std::shared_ptr session = onnx_session_manager->create_session(model_name, model_version, option, model_data, model_data_length); if (session == nullptr) { throw not_found_error("session not found"); diff --git a/src/test/unit/unit_test_context.cpp b/src/test/unit/unit_test_context.cpp index 4122d8b..45728f6 100644 --- a/src/test/unit/unit_test_context.cpp +++ b/src/test/unit/unit_test_context.cpp @@ -6,9 +6,9 @@ TEST(test_onnxruntime_server_context, SimpleModelTest) { Orts::onnx::session_key key("sample", "1"); - Orts::onnx::session session(key, model1_path.string()); + auto session = std::make_shared(key, model1_path.string()); - Orts::onnx::execution::context ctx(&session, R"({"x":[[1]],"y":[[2]],"z":[[3]]})"); + Orts::onnx::execution::context ctx(session, R"({"x":[[1]],"y":[[2]],"z":[[3]]})"); TIME_MEASURE_INIT TIME_MEASURE_START @@ -24,9 +24,9 @@ TEST(test_onnxruntime_server_context, SimpleModelTest) { TEST(test_onnxruntime_server_context, SimpleModelBatchTest) { Orts::onnx::session_key key("sample", "1"); - Orts::onnx::session session(key, model1_path.string()); + auto session = std::make_shared(key, model1_path.string()); - Orts::onnx::execution::context ctx(&session, R"({"x":[[1],[2],[3]],"y":[[2],[3],[4]],"z":[[3],[4],[5]]})"); + Orts::onnx::execution::context ctx(session, R"({"x":[[1],[2],[3]],"y":[[2],[3],[4]],"z":[[3],[4],[5]]})"); TIME_MEASURE_INIT TIME_MEASURE_START @@ -42,11 +42,11 @@ TEST(test_onnxruntime_server_context, SimpleModelBatchTest) { TEST(test_onnxruntime_server_context, BertSquadModelTest) { Orts::onnx::session_key key("sample", "2"); - Orts::onnx::session session(key, model2_path.string()); + auto session = std::make_shared(key, model2_path.string()); - std::cout << session.to_json().dump(2) << "\n"; + std::cout << session->to_json().dump(2) << "\n"; - Orts::onnx::execution::context ctx(&session, R"({ + Orts::onnx::execution::context ctx(session, R"({ "input_ids": [ [101, 11834, 21600, 2102, 9005, 12098, 8566, 5740, 6853, 1999, 1996, 2806, 1997, 15262, 19699, 14663, 1005, 1055, 3203, 8447, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 3795, 3361, 3422, 16168, 19428, 1997, 1520, 2582, 7860, 1998, 28215, 1521, 3805, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], diff --git a/src/test/unit/unit_test_context_cuda.cpp b/src/test/unit/unit_test_context_cuda.cpp index 944ef04..077e401 100644 --- a/src/test/unit/unit_test_context_cuda.cpp +++ b/src/test/unit/unit_test_context_cuda.cpp @@ -6,11 +6,11 @@ TEST(test_onnxruntime_server_context_cuda, BertSquadModelTest) { Orts::onnx::session_key key("sample", "2"); - Orts::onnx::session session(key, model2_path.string(), json::parse(R"({"cuda": true})")); + auto session = std::make_shared(key, model2_path.string(), json::parse(R"({"cuda": true})")); - std::cout << session.to_json().dump(2) << "\n"; + std::cout << session->to_json().dump(2) << "\n"; - Orts::onnx::execution::context ctx(&session, R"({ + Orts::onnx::execution::context ctx(session, R"({ "input_ids": [ [101, 11834, 21600, 2102, 9005, 12098, 8566, 5740, 6853, 1999, 1996, 2806, 1997, 15262, 19699, 14663, 1005, 1055, 3203, 8447, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 3795, 3361, 3422, 16168, 19428, 1997, 1520, 2582, 7860, 1998, 28215, 1521, 3805, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],