diff --git a/runtime/core/cmake/onnx.cmake b/runtime/core/cmake/onnx.cmake index 919511d45..b86c3dcdc 100644 --- a/runtime/core/cmake/onnx.cmake +++ b/runtime/core/cmake/onnx.cmake @@ -1,23 +1,23 @@ if(ONNX) - set(ONNX_VERSION "1.12.0") + set(ONNX_VERSION "1.13.1") if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip") - set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176") + set(URL_HASH "SHA256=cd8318dc30352e0d615f809bd544bfd18b578289ec16621252b5db1994f09e43") elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac") + set(URL_HASH "SHA256=18e441585de69ef8aab263e2e96f0325729537ebfbd17cdcee78b2eabf0594d2") else() set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5") + set(URL_HASH "SHA256=2c7fdcfa8131b52167b1870747758cb24265952eba975318a67cc840c04ca73e") endif() elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-arm64-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=23117b6f5d7324d4a7c51184e5f808dd952aec411a6b99a1b6fd1011de06e300") + set(URL_HASH "SHA256=10ce30925c789715f29424a7658b41c601dfbde5d58fe21cb53ad418cde3c215") else() set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x86_64-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=09b17f712f8c6f19bb63da35d508815b443cbb473e16c6192abfaa297c02f600") + set(URL_HASH "SHA256=32f3fff17b01db779e9e3cbe32f27adba40460e6202a79dfd1ac76b4f20588ef") endif() else() message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')") diff --git a/runtime/core/decoder/onnx_asr_model.cc b/runtime/core/decoder/onnx_asr_model.cc index fc7afc704..667798c90 100644 --- a/runtime/core/decoder/onnx_asr_model.cc +++ b/runtime/core/decoder/onnx_asr_model.cc @@ -34,13 +34,13 @@ void OnnxAsrModel::InitEngineThreads(int num_threads) { void OnnxAsrModel::GetInputOutputInfo( const std::shared_ptr& session, - std::vector* in_names, std::vector* out_names) { + std::vector* in_names, std::vector* out_names) { Ort::AllocatorWithDefaultOptions allocator; // Input info int num_nodes = session->GetInputCount(); in_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { - char* name = session->GetInputName(i, allocator); + Ort::AllocatedStringPtr in_name_ptr = session->GetInputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetInputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); @@ -50,15 +50,15 @@ void OnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type + LOG(INFO) << "\tInput " << i << " : name=" << in_name_ptr.get() << " type=" << type << " dims=" << shape.str(); - (*in_names)[i] = name; + (*in_names)[i] = std::string(in_name_ptr.get());; } // Output info num_nodes = session->GetOutputCount(); out_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { - char* name = session->GetOutputName(i, allocator); + Ort::AllocatedStringPtr out_name_ptr= session->GetOutputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); @@ -68,9 +68,9 @@ void OnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type + LOG(INFO) << "\tOutput " << i << " : name=" << out_name_ptr.get() << " type=" << type << " dims=" << shape.str(); - (*out_names)[i] = name; + (*out_names)[i] = std::string(out_name_ptr.get()); } } @@ -106,24 +106,24 @@ void OnnxAsrModel::Read(const std::string& model_dir) { Ort::AllocatorWithDefaultOptions allocator; encoder_output_size_ = - atoi(model_metadata.LookupCustomMetadataMap("output_size", allocator)); + atoi(model_metadata.LookupCustomMetadataMapAllocated("output_size", allocator).get()); num_blocks_ = - atoi(model_metadata.LookupCustomMetadataMap("num_blocks", allocator)); - head_ = atoi(model_metadata.LookupCustomMetadataMap("head", allocator)); + atoi(model_metadata.LookupCustomMetadataMapAllocated("num_blocks", allocator).get()); + head_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("head", allocator).get()); cnn_module_kernel_ = atoi( - model_metadata.LookupCustomMetadataMap("cnn_module_kernel", allocator)); + model_metadata.LookupCustomMetadataMapAllocated("cnn_module_kernel", allocator).get()); subsampling_rate_ = atoi( - model_metadata.LookupCustomMetadataMap("subsampling_rate", allocator)); + model_metadata.LookupCustomMetadataMapAllocated("subsampling_rate", allocator).get()); right_context_ = - atoi(model_metadata.LookupCustomMetadataMap("right_context", allocator)); - sos_ = atoi(model_metadata.LookupCustomMetadataMap("sos_symbol", allocator)); - eos_ = atoi(model_metadata.LookupCustomMetadataMap("eos_symbol", allocator)); - is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMap( - "is_bidirectional_decoder", allocator)); + atoi(model_metadata.LookupCustomMetadataMapAllocated("right_context", allocator).get()); + sos_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("sos_symbol", allocator).get()); + eos_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("eos_symbol", allocator).get()); + is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMapAllocated( + "is_bidirectional_decoder", allocator).get()); chunk_size_ = - atoi(model_metadata.LookupCustomMetadataMap("chunk_size", allocator)); + atoi(model_metadata.LookupCustomMetadataMapAllocated("chunk_size", allocator).get()); num_left_chunks_ = - atoi(model_metadata.LookupCustomMetadataMap("left_chunks", allocator)); + atoi(model_metadata.LookupCustomMetadataMapAllocated("left_chunks", allocator).get()); LOG(INFO) << "Onnx Model Info:"; LOG(INFO) << "\tencoder_output_size " << encoder_output_size_; @@ -264,24 +264,32 @@ void OnnxAsrModel::ForwardEncoderFunc( // 2. Encoder chunk forward std::vector inputs; for (auto name : encoder_in_names_) { - if (!strcmp(name, "chunk")) { + if (!strcmp(name.c_str(), "chunk")) { inputs.emplace_back(std::move(feats_ort)); - } else if (!strcmp(name, "offset")) { + } else if (!strcmp(name.c_str(), "offset")) { inputs.emplace_back(std::move(offset_ort)); - } else if (!strcmp(name, "required_cache_size")) { + } else if (!strcmp(name.c_str(), "required_cache_size")) { inputs.emplace_back(std::move(required_cache_size_ort)); - } else if (!strcmp(name, "att_cache")) { + } else if (!strcmp(name.c_str(), "att_cache")) { inputs.emplace_back(std::move(att_cache_ort_)); - } else if (!strcmp(name, "cnn_cache")) { + } else if (!strcmp(name.c_str(), "cnn_cache")) { inputs.emplace_back(std::move(cnn_cache_ort_)); - } else if (!strcmp(name, "att_mask")) { + } else if (!strcmp(name.c_str(), "att_mask")) { inputs.emplace_back(std::move(att_mask_ort)); } } + // Convert std::vector to std::vector for using C-style strings + std::vector encoder_in_names(encoder_in_names_.size()); + std::vector encoder_out_names(encoder_out_names_.size()); + std::transform(encoder_in_names_.begin(), encoder_in_names_.end(), encoder_in_names.begin(), + [](const std::string& name) { return name.c_str(); }); + std::transform(encoder_out_names_.begin(), encoder_out_names_.end(), encoder_out_names.begin(), + [](const std::string& name) { return name.c_str(); }); + std::vector ort_outputs = encoder_session_->Run( - Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(), - inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); + Ort::RunOptions{nullptr}, encoder_in_names.data(), inputs.data(), + inputs.size(), encoder_out_names.data(), encoder_out_names.size()); offset_ += static_cast( ort_outputs[0].GetTensorTypeAndShapeInfo().GetShape()[1]); @@ -291,9 +299,17 @@ void OnnxAsrModel::ForwardEncoderFunc( std::vector ctc_inputs; ctc_inputs.emplace_back(std::move(ort_outputs[0])); + // Convert std::vector to std::vector for using C-style strings + std::vector ctc_in_names(ctc_in_names_.size()); + std::vector ctc_out_names(ctc_out_names_.size()); + std::transform(ctc_in_names_.begin(), ctc_in_names_.end(), ctc_in_names.begin(), + [](const std::string& name) { return name.c_str(); }); + std::transform(ctc_out_names_.begin(), ctc_out_names_.end(), ctc_out_names.begin(), + [](const std::string& name) { return name.c_str(); }); + std::vector ctc_ort_outputs = ctc_session_->Run( - Ort::RunOptions{nullptr}, ctc_in_names_.data(), ctc_inputs.data(), - ctc_inputs.size(), ctc_out_names_.data(), ctc_out_names_.size()); + Ort::RunOptions{nullptr}, ctc_in_names.data(), ctc_inputs.data(), + ctc_inputs.size(), ctc_out_names.data(), ctc_out_names.size()); encoder_outs_.push_back(std::move(ctc_inputs[0])); float* logp_data = ctc_ort_outputs[0].GetTensorMutableData(); @@ -393,10 +409,18 @@ void OnnxAsrModel::AttentionRescoring(const std::vector>& hyps, rescore_inputs.emplace_back(std::move(hyps_lens_tensor_)); rescore_inputs.emplace_back(std::move(decode_input_tensor_)); + // Convert std::vector to std::vector for using C-style strings + std::vector rescore_in_names(rescore_in_names_.size()); + std::vector rescore_out_names(rescore_out_names_.size()); + std::transform(rescore_in_names_.begin(), rescore_in_names_.end(), rescore_in_names.begin(), + [](const std::string& name) { return name.c_str(); }); + std::transform(rescore_out_names_.begin(), rescore_out_names_.end(), rescore_out_names.begin(), + [](const std::string& name) { return name.c_str(); }); + std::vector rescore_outputs = rescore_session_->Run( - Ort::RunOptions{nullptr}, rescore_in_names_.data(), rescore_inputs.data(), - rescore_inputs.size(), rescore_out_names_.data(), - rescore_out_names_.size()); + Ort::RunOptions{nullptr}, rescore_in_names.data(), rescore_inputs.data(), + rescore_inputs.size(), rescore_out_names.data(), + rescore_out_names.size()); float* decoder_outs_data = rescore_outputs[0].GetTensorMutableData(); float* r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData(); diff --git a/runtime/core/decoder/onnx_asr_model.h b/runtime/core/decoder/onnx_asr_model.h index f5d9e9a0c..30c82df40 100644 --- a/runtime/core/decoder/onnx_asr_model.h +++ b/runtime/core/decoder/onnx_asr_model.h @@ -44,8 +44,8 @@ class OnnxAsrModel : public AsrModel { std::vector* rescoring_score) override; std::shared_ptr Copy() const override; void GetInputOutputInfo(const std::shared_ptr& session, - std::vector* in_names, - std::vector* out_names); + std::vector* in_names, + std::vector* out_names); protected: void ForwardEncoderFunc(const std::vector>& chunk_feats, @@ -70,9 +70,9 @@ class OnnxAsrModel : public AsrModel { std::shared_ptr ctc_session_ = nullptr; // node names - std::vector encoder_in_names_, encoder_out_names_; - std::vector ctc_in_names_, ctc_out_names_; - std::vector rescore_in_names_, rescore_out_names_; + std::vector encoder_in_names_, encoder_out_names_; + std::vector ctc_in_names_, ctc_out_names_; + std::vector rescore_in_names_, rescore_out_names_; // caches Ort::Value att_cache_ort_{nullptr};