Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor session_manager::sessions type #29

Merged
merged 2 commits into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/onnx/execution/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#include "../../onnxruntime_server.hpp"

Orts::onnx::execution::context::context(Orts::onnx::session *session, const json &json_str)
Orts::onnx::execution::context::context(std::shared_ptr<Orts::onnx::session> session, const json &json_str)
: memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)), session(session) {
assert(session != nullptr);

Expand Down
12 changes: 4 additions & 8 deletions src/onnx/session_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,23 @@ Orts::onnx::session_manager::session_manager(model_bin_getter_t model_bin_getter
}

Orts::onnx::session_manager::~session_manager() {
for (auto &it : sessions) {
delete it.second;
}
}

Orts::onnx::session *
std::shared_ptr<Orts::onnx::session>
Orts::onnx::session_manager::get_session(const std::string &model_name, const std::string &model_version) {
auto key = session_key(model_name, model_version);
return get_session(key);
}

Orts::onnx::session *Orts::onnx::session_manager::get_session(const Orts::onnx::session_key &key) {
std::shared_ptr<Orts::onnx::session> Orts::onnx::session_manager::get_session(const Orts::onnx::session_key &key) {
std::lock_guard<std::recursive_mutex> lock(mutex);
auto it = sessions.find(key);
if (it == sessions.end())
return nullptr;
return it->second;
}

Orts::onnx::session *Orts::onnx::session_manager::create_session(
std::shared_ptr<Orts::onnx::session> Orts::onnx::session_manager::create_session(
const std::string &model_name, const std::string &model_version, const json &option, const char *model_data,
size_t model_data_length
) {
Expand All @@ -49,7 +46,7 @@ Orts::onnx::session *Orts::onnx::session_manager::create_session(
if (current_session != nullptr)
throw conflict_error("session already exists");

auto session = new onnx::session(key, model_data, model_data_length, option);
auto session = std::make_shared<onnx::session>(key, model_data, model_data_length, option);
sessions.emplace(key, session);
return session;
}
Expand All @@ -67,6 +64,5 @@ void Orts::onnx::session_manager::remove_session(const Orts::onnx::session_key &
if (it == sessions.end()) {
throw not_found_error("session not found");
}
delete it->second;
sessions.erase(it);
}
14 changes: 7 additions & 7 deletions src/onnxruntime_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,20 @@ namespace onnxruntime_server {
class session_manager {
private:
std::recursive_mutex mutex;
std::map<session_key, session *> sessions;
std::map<session_key, std::shared_ptr<session>> sessions;
model_bin_getter_t model_bin_getter;

public:
explicit session_manager(model_bin_getter_t model_bin_getter);
~session_manager();

std::map<session_key, session *> &get_sessions() {
std::map<session_key, std::shared_ptr<session>> &get_sessions() {
return sessions;
}

session *get_session(const std::string &model_name, const std::string &model_version);
session *get_session(const session_key &key);
session *create_session(
std::shared_ptr<session> get_session(const std::string &model_name, const std::string &model_version);
std::shared_ptr<session> get_session(const session_key &key);
std::shared_ptr<session> create_session(
const std::string &model_name, const std::string &model_version, const json &option,
const char *model_data = nullptr, size_t model_data_length = 0
);
Expand All @@ -161,11 +161,11 @@ namespace onnxruntime_server {
class context {
private:
Ort::MemoryInfo memory_info;
onnxruntime_server::onnx::session *session;
std::shared_ptr<onnxruntime_server::onnx::session> session;
std::map<std::string, input_value *> inputs;

public:
context(class session *session, const json &json_str);
context(std::shared_ptr<class session> session, const json &json_str);
~context();

void flat_json_values(const json::value_type &data, std::vector<json::value_type> *json_values);
Expand Down
2 changes: 1 addition & 1 deletion src/task/create_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Orts::task::create_session::create_session(
}

json Orts::task::create_session::run() {
Orts::onnx::session *session =
std::shared_ptr<Orts::onnx::session> session =
onnx_session_manager->create_session(model_name, model_version, option, model_data, model_data_length);
if (session == nullptr) {
throw not_found_error("session not found");
Expand Down
14 changes: 7 additions & 7 deletions src/test/unit/unit_test_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

TEST(test_onnxruntime_server_context, SimpleModelTest) {
Orts::onnx::session_key key("sample", "1");
Orts::onnx::session session(key, model1_path.string());
auto session = std::make_shared<Orts::onnx::session>(key, model1_path.string());

Orts::onnx::execution::context ctx(&session, R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
Orts::onnx::execution::context ctx(session, R"({"x":[[1]],"y":[[2]],"z":[[3]]})");

TIME_MEASURE_INIT
TIME_MEASURE_START
Expand All @@ -24,9 +24,9 @@ TEST(test_onnxruntime_server_context, SimpleModelTest) {

TEST(test_onnxruntime_server_context, SimpleModelBatchTest) {
Orts::onnx::session_key key("sample", "1");
Orts::onnx::session session(key, model1_path.string());
auto session = std::make_shared<Orts::onnx::session>(key, model1_path.string());

Orts::onnx::execution::context ctx(&session, R"({"x":[[1],[2],[3]],"y":[[2],[3],[4]],"z":[[3],[4],[5]]})");
Orts::onnx::execution::context ctx(session, R"({"x":[[1],[2],[3]],"y":[[2],[3],[4]],"z":[[3],[4],[5]]})");

TIME_MEASURE_INIT
TIME_MEASURE_START
Expand All @@ -42,11 +42,11 @@ TEST(test_onnxruntime_server_context, SimpleModelBatchTest) {

TEST(test_onnxruntime_server_context, BertSquadModelTest) {
Orts::onnx::session_key key("sample", "2");
Orts::onnx::session session(key, model2_path.string());
auto session = std::make_shared<Orts::onnx::session>(key, model2_path.string());

std::cout << session.to_json().dump(2) << "\n";
std::cout << session->to_json().dump(2) << "\n";

Orts::onnx::execution::context ctx(&session, R"({
Orts::onnx::execution::context ctx(session, R"({
"input_ids": [
[101, 11834, 21600, 2102, 9005, 12098, 8566, 5740, 6853, 1999, 1996, 2806, 1997, 15262, 19699, 14663, 1005, 1055, 3203, 8447, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[101, 3795, 3361, 3422, 16168, 19428, 1997, 1520, 2582, 7860, 1998, 28215, 1521, 3805, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
Expand Down
6 changes: 3 additions & 3 deletions src/test/unit/unit_test_context_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

TEST(test_onnxruntime_server_context_cuda, BertSquadModelTest) {
Orts::onnx::session_key key("sample", "2");
Orts::onnx::session session(key, model2_path.string(), json::parse(R"({"cuda": true})"));
auto session = std::make_shared<Orts::onnx::session>(key, model2_path.string(), json::parse(R"({"cuda": true})"));

std::cout << session.to_json().dump(2) << "\n";
std::cout << session->to_json().dump(2) << "\n";

Orts::onnx::execution::context ctx(&session, R"({
Orts::onnx::execution::context ctx(session, R"({
"input_ids": [
[101, 11834, 21600, 2102, 9005, 12098, 8566, 5740, 6853, 1999, 1996, 2806, 1997, 15262, 19699, 14663, 1005, 1055, 3203, 8447, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[101, 3795, 3361, 3422, 16168, 19428, 1997, 1520, 2582, 7860, 1998, 28215, 1521, 3805, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
Expand Down