Skip to content

Commit

Permalink
code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Apr 23, 2021
1 parent cd8d640 commit 48c86d5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
&engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, &fp16_enable_, &int8_enable_, &max_workspace_size_,
trt_node_name_with_precision, engine_cache_enable_, cache_path_, &runtime_, nullptr,
trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), nullptr,
allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_};
*state = p.release();
return 0;
Expand Down Expand Up @@ -1295,9 +1295,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
auto runtime = trt_state->runtime->get();
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (trt_state->engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
Expand Down Expand Up @@ -1326,8 +1325,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
// Deserialize engine
trt_state->context->reset();
trt_state->engine->reset();
auto runtime = trt_state->runtime->get();
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_state->engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct TensorrtFuncState {
std::string trt_node_name_with_precision;
bool engine_cache_enable;
std::string engine_cache_path;
tensorrt_ptr::unique_pointer<nvinfer1::IRuntime>* runtime = nullptr;
nvinfer1::IRuntime* runtime = nullptr;

nvinfer1::IOptimizationProfile* trt_profile = nullptr;
AllocatorPtr scratch_allocator;
Expand Down

0 comments on commit 48c86d5

Please sign in to comment.