Skip to content

Commit

Permalink
feat: session from local file (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
kibae authored Sep 17, 2023
1 parent c0d12ee commit a726dcd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
33 changes: 16 additions & 17 deletions src/onnx/session_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,25 @@ std::shared_ptr<Orts::onnx::session> Orts::onnx::session_manager::create_session
) {
auto key = session_key(model_name, model_version);

static std::string model_bin;
if (model_data == nullptr) {
// get model binary from model_bin_getter
model_bin = model_bin_getter(model_name, model_version);
model_data = model_bin.data();
model_data_length = model_bin.size();
}

{
std::lock_guard<std::recursive_mutex> lock(mutex);
std::shared_ptr<Orts::onnx::session> session = nullptr;
std::lock_guard<std::recursive_mutex> lock(mutex);

auto current_session = get_session(key);
if (current_session != nullptr)
throw conflict_error("session already exists");
auto current_session = get_session(key);
if (current_session != nullptr)
throw conflict_error("session already exists");

auto session = std::make_shared<onnx::session>(key, model_data, model_data_length, option);
sessions.emplace(key, session);
return session;
if (model_data != nullptr && model_data_length > 0) {
session = std::make_shared<onnx::session>(key, model_data, model_data_length, option);
} else if (option.contains("path") && option["path"].is_string()) {
session = std::make_shared<onnx::session>(key, option["path"].get<std::string>(), option);
} else {
auto model_bin = model_bin_getter(model_name, model_version);
model_data = model_bin.data();
model_data_length = model_bin.size();
session = std::make_shared<onnx::session>(key, model_data, model_data_length, option);
}
return nullptr;
sessions.emplace(key, session);
return session;
}

void Orts::onnx::session_manager::remove_session(const std::string &model_name, const std::string &model_version) {
Expand Down
25 changes: 23 additions & 2 deletions src/test/e2e/e2e_test_tcp_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,33 @@ TEST(test_onnxruntime_server_tcp, TcpServerTest) {
ASSERT_EQ(res_json["version"], "1");
}

{ // API: Create session2
json body = json::parse(R"({"model":"sample","version":"path"})");
body["option"]["path"] = model2_path.string();
TIME_MEASURE_START
auto res_json = tcp_request(server.port(), Orts::task::type::CREATE_SESSION, body);
TIME_MEASURE_STOP
std::cout << "API: Create session\n" << res_json.dump(2) << "\n";
ASSERT_EQ(res_json["model"], "sample");
ASSERT_EQ(res_json["version"], "path");
}

{ // API: Get session
json body = json::parse(R"({"model":"sample","version":"path"})");
TIME_MEASURE_START
auto res_json = tcp_request(server.port(), Orts::task::type::GET_SESSION, body);
TIME_MEASURE_STOP
std::cout << "API: Get session\n" << res_json.dump(2) << "\n";
ASSERT_EQ(res_json["model"], "sample");
ASSERT_EQ(res_json["version"], "path");
}

{ // API: List session
TIME_MEASURE_START
auto res_json = tcp_request(server.port(), Orts::task::type::LIST_SESSION, "");
TIME_MEASURE_STOP
std::cout << "API: List sessions\n" << res_json.dump(2) << "\n";
ASSERT_EQ(res_json.size(), 1);
ASSERT_EQ(res_json.size(), 2);
ASSERT_EQ(res_json[0]["model"], "sample");
ASSERT_EQ(res_json[0]["version"], "1");
}
Expand Down Expand Up @@ -126,7 +147,7 @@ TEST(test_onnxruntime_server_tcp, TcpServerTest) {
auto res_json = tcp_request(server.port(), Orts::task::type::LIST_SESSION, "");
TIME_MEASURE_STOP
std::cout << "API: List sessions\n" << res_json.dump(2) << "\n";
ASSERT_EQ(res_json.size(), 0);
ASSERT_EQ(res_json.size(), 1);
}

running = false;
Expand Down

0 comments on commit a726dcd

Please sign in to comment.