Skip to content

Commit

Permalink
[runtime] Upgrade onnx runtime version to v1.13.1 to address potentia…
Browse files Browse the repository at this point in the history
…l memory leakage issues (#1964)
  • Loading branch information
yangzhengzhe committed Sep 25, 2024
1 parent 2d0da71 commit 5cafbee
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 43 deletions.
12 changes: 6 additions & 6 deletions runtime/core/cmake/onnx.cmake
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
if(ONNX)
set(ONNX_VERSION "1.12.0")
set(ONNX_VERSION "1.13.1")
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip")
set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176")
set(URL_HASH "SHA256=cd8318dc30352e0d615f809bd544bfd18b578289ec16621252b5db1994f09e43")
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac")
set(URL_HASH "SHA256=18e441585de69ef8aab263e2e96f0325729537ebfbd17cdcee78b2eabf0594d2")
else()
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5")
set(URL_HASH "SHA256=2c7fdcfa8131b52167b1870747758cb24265952eba975318a67cc840c04ca73e")
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-arm64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=23117b6f5d7324d4a7c51184e5f808dd952aec411a6b99a1b6fd1011de06e300")
set(URL_HASH "SHA256=10ce30925c789715f29424a7658b41c601dfbde5d58fe21cb53ad418cde3c215")
else()
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x86_64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=09b17f712f8c6f19bb63da35d508815b443cbb473e16c6192abfaa297c02f600")
set(URL_HASH "SHA256=32f3fff17b01db779e9e3cbe32f27adba40460e6202a79dfd1ac76b4f20588ef")
endif()
else()
message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
Expand Down
88 changes: 56 additions & 32 deletions runtime/core/decoder/onnx_asr_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ void OnnxAsrModel::InitEngineThreads(int num_threads) {

void OnnxAsrModel::GetInputOutputInfo(
const std::shared_ptr<Ort::Session>& session,
std::vector<const char*>* in_names, std::vector<const char*>* out_names) {
std::vector<std::string>* in_names, std::vector<std::string>* out_names) {
Ort::AllocatorWithDefaultOptions allocator;
// Input info
int num_nodes = session->GetInputCount();
in_names->resize(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
char* name = session->GetInputName(i, allocator);
Ort::AllocatedStringPtr in_name_ptr = session->GetInputNameAllocated(i, allocator);
Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
Expand All @@ -50,15 +50,15 @@ void OnnxAsrModel::GetInputOutputInfo(
shape << j;
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type
LOG(INFO) << "\tInput " << i << " : name=" << in_name_ptr.get() << " type=" << type
<< " dims=" << shape.str();
(*in_names)[i] = name;
(*in_names)[i] = std::string(in_name_ptr.get());;
}
// Output info
num_nodes = session->GetOutputCount();
out_names->resize(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
char* name = session->GetOutputName(i, allocator);
Ort::AllocatedStringPtr out_name_ptr= session->GetOutputNameAllocated(i, allocator);
Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
Expand All @@ -68,9 +68,9 @@ void OnnxAsrModel::GetInputOutputInfo(
shape << j;
shape << " ";
}
LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type
LOG(INFO) << "\tOutput " << i << " : name=" << out_name_ptr.get() << " type=" << type
<< " dims=" << shape.str();
(*out_names)[i] = name;
(*out_names)[i] = std::string(out_name_ptr.get());
}
}

Expand Down Expand Up @@ -106,24 +106,24 @@ void OnnxAsrModel::Read(const std::string& model_dir) {

Ort::AllocatorWithDefaultOptions allocator;
encoder_output_size_ =
atoi(model_metadata.LookupCustomMetadataMap("output_size", allocator));
atoi(model_metadata.LookupCustomMetadataMapAllocated("output_size", allocator).get());
num_blocks_ =
atoi(model_metadata.LookupCustomMetadataMap("num_blocks", allocator));
head_ = atoi(model_metadata.LookupCustomMetadataMap("head", allocator));
atoi(model_metadata.LookupCustomMetadataMapAllocated("num_blocks", allocator).get());
head_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("head", allocator).get());
cnn_module_kernel_ = atoi(
model_metadata.LookupCustomMetadataMap("cnn_module_kernel", allocator));
model_metadata.LookupCustomMetadataMapAllocated("cnn_module_kernel", allocator).get());
subsampling_rate_ = atoi(
model_metadata.LookupCustomMetadataMap("subsampling_rate", allocator));
model_metadata.LookupCustomMetadataMapAllocated("subsampling_rate", allocator).get());
right_context_ =
atoi(model_metadata.LookupCustomMetadataMap("right_context", allocator));
sos_ = atoi(model_metadata.LookupCustomMetadataMap("sos_symbol", allocator));
eos_ = atoi(model_metadata.LookupCustomMetadataMap("eos_symbol", allocator));
is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMap(
"is_bidirectional_decoder", allocator));
atoi(model_metadata.LookupCustomMetadataMapAllocated("right_context", allocator).get());
sos_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("sos_symbol", allocator).get());
eos_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("eos_symbol", allocator).get());
is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMapAllocated(
"is_bidirectional_decoder", allocator).get());
chunk_size_ =
atoi(model_metadata.LookupCustomMetadataMap("chunk_size", allocator));
atoi(model_metadata.LookupCustomMetadataMapAllocated("chunk_size", allocator).get());
num_left_chunks_ =
atoi(model_metadata.LookupCustomMetadataMap("left_chunks", allocator));
atoi(model_metadata.LookupCustomMetadataMapAllocated("left_chunks", allocator).get());

LOG(INFO) << "Onnx Model Info:";
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
Expand Down Expand Up @@ -264,24 +264,32 @@ void OnnxAsrModel::ForwardEncoderFunc(
// 2. Encoder chunk forward
std::vector<Ort::Value> inputs;
for (auto name : encoder_in_names_) {
if (!strcmp(name, "chunk")) {
if (!strcmp(name.c_str(), "chunk")) {
inputs.emplace_back(std::move(feats_ort));
} else if (!strcmp(name, "offset")) {
} else if (!strcmp(name.c_str(), "offset")) {
inputs.emplace_back(std::move(offset_ort));
} else if (!strcmp(name, "required_cache_size")) {
} else if (!strcmp(name.c_str(), "required_cache_size")) {
inputs.emplace_back(std::move(required_cache_size_ort));
} else if (!strcmp(name, "att_cache")) {
} else if (!strcmp(name.c_str(), "att_cache")) {
inputs.emplace_back(std::move(att_cache_ort_));
} else if (!strcmp(name, "cnn_cache")) {
} else if (!strcmp(name.c_str(), "cnn_cache")) {
inputs.emplace_back(std::move(cnn_cache_ort_));
} else if (!strcmp(name, "att_mask")) {
} else if (!strcmp(name.c_str(), "att_mask")) {
inputs.emplace_back(std::move(att_mask_ort));
}
}

// Convert std::vector<std::string> to std::vector<const char*> for using C-style strings
std::vector<const char*> encoder_in_names(encoder_in_names_.size());
std::vector<const char*> encoder_out_names(encoder_out_names_.size());
std::transform(encoder_in_names_.begin(), encoder_in_names_.end(), encoder_in_names.begin(),
[](const std::string& name) { return name.c_str(); });
std::transform(encoder_out_names_.begin(), encoder_out_names_.end(), encoder_out_names.begin(),
[](const std::string& name) { return name.c_str(); });

std::vector<Ort::Value> ort_outputs = encoder_session_->Run(
Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(),
inputs.size(), encoder_out_names_.data(), encoder_out_names_.size());
Ort::RunOptions{nullptr}, encoder_in_names.data(), inputs.data(),
inputs.size(), encoder_out_names.data(), encoder_out_names.size());

offset_ += static_cast<int>(
ort_outputs[0].GetTensorTypeAndShapeInfo().GetShape()[1]);
Expand All @@ -291,9 +299,17 @@ void OnnxAsrModel::ForwardEncoderFunc(
std::vector<Ort::Value> ctc_inputs;
ctc_inputs.emplace_back(std::move(ort_outputs[0]));

// Convert std::vector<std::string> to std::vector<const char*> for using C-style strings
std::vector<const char*> ctc_in_names(ctc_in_names_.size());
std::vector<const char*> ctc_out_names(ctc_out_names_.size());
std::transform(ctc_in_names_.begin(), ctc_in_names_.end(), ctc_in_names.begin(),
[](const std::string& name) { return name.c_str(); });
std::transform(ctc_out_names_.begin(), ctc_out_names_.end(), ctc_out_names.begin(),
[](const std::string& name) { return name.c_str(); });

std::vector<Ort::Value> ctc_ort_outputs = ctc_session_->Run(
Ort::RunOptions{nullptr}, ctc_in_names_.data(), ctc_inputs.data(),
ctc_inputs.size(), ctc_out_names_.data(), ctc_out_names_.size());
Ort::RunOptions{nullptr}, ctc_in_names.data(), ctc_inputs.data(),
ctc_inputs.size(), ctc_out_names.data(), ctc_out_names.size());
encoder_outs_.push_back(std::move(ctc_inputs[0]));

float* logp_data = ctc_ort_outputs[0].GetTensorMutableData<float>();
Expand Down Expand Up @@ -393,10 +409,18 @@ void OnnxAsrModel::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
rescore_inputs.emplace_back(std::move(hyps_lens_tensor_));
rescore_inputs.emplace_back(std::move(decode_input_tensor_));

// Convert std::vector<std::string> to std::vector<const char*> for using C-style strings
std::vector<const char*> rescore_in_names(rescore_in_names_.size());
std::vector<const char*> rescore_out_names(rescore_out_names_.size());
std::transform(rescore_in_names_.begin(), rescore_in_names_.end(), rescore_in_names.begin(),
[](const std::string& name) { return name.c_str(); });
std::transform(rescore_out_names_.begin(), rescore_out_names_.end(), rescore_out_names.begin(),
[](const std::string& name) { return name.c_str(); });

std::vector<Ort::Value> rescore_outputs = rescore_session_->Run(
Ort::RunOptions{nullptr}, rescore_in_names_.data(), rescore_inputs.data(),
rescore_inputs.size(), rescore_out_names_.data(),
rescore_out_names_.size());
Ort::RunOptions{nullptr}, rescore_in_names.data(), rescore_inputs.data(),
rescore_inputs.size(), rescore_out_names.data(),
rescore_out_names.size());

float* decoder_outs_data = rescore_outputs[0].GetTensorMutableData<float>();
float* r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData<float>();
Expand Down
10 changes: 5 additions & 5 deletions runtime/core/decoder/onnx_asr_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class OnnxAsrModel : public AsrModel {
std::vector<float>* rescoring_score) override;
std::shared_ptr<AsrModel> Copy() const override;
void GetInputOutputInfo(const std::shared_ptr<Ort::Session>& session,
std::vector<const char*>* in_names,
std::vector<const char*>* out_names);
std::vector<std::string>* in_names,
std::vector<std::string>* out_names);

protected:
void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
Expand All @@ -70,9 +70,9 @@ class OnnxAsrModel : public AsrModel {
std::shared_ptr<Ort::Session> ctc_session_ = nullptr;

// node names
std::vector<const char*> encoder_in_names_, encoder_out_names_;
std::vector<const char*> ctc_in_names_, ctc_out_names_;
std::vector<const char*> rescore_in_names_, rescore_out_names_;
std::vector<std::string> encoder_in_names_, encoder_out_names_;
std::vector<std::string> ctc_in_names_, ctc_out_names_;
std::vector<std::string> rescore_in_names_, rescore_out_names_;

// caches
Ort::Value att_cache_ort_{nullptr};
Expand Down

0 comments on commit 5cafbee

Please sign in to comment.