Skip to content

Commit

Permalink
fix trt problem (PaddlePaddle#35938)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo authored and AnnaTrainingG committed Sep 29, 2021
1 parent 87049f3 commit c759259
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cmake/external/lite.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)

if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG d3a3a6931b6d22d504d21ba32b3ae972770e9204)
set(LITE_GIT_TAG 4ab64daecc11fbf74fffdc6a4733f388472e7d5d)
endif()

if(NOT CUDA_ARCH_NAME)
Expand Down
21 changes: 18 additions & 3 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -686,9 +686,24 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
// Note, please do NOT use any member variables, because member variables may
// have been destructed in multiple threads.
#if PADDLE_WITH_TENSORRT
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.DeleteAll();
auto &block = prog->Block(0);
for (auto &op_desc : block.AllOps()) {
if (op_desc->Type() == "tensorrt_engine") {
std::string engine_key =
BOOST_GET_CONST(std::string, op_desc->GetAttr("engine_key"));
int engine_predictor_id =
BOOST_GET_CONST(int, op_desc->GetAttr("predictor_id"));
std::string engine_name =
engine_key + std::to_string(engine_predictor_id);
if (paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Has(engine_name)) {
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.DeleteKey(engine_name);
}
}
}
#endif
delete prog;
});
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/inference/tensorrt/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,14 @@ class TRTEngineManager {
}
}

void DeleteKey(const std::string& key) {
auto iter = engines_.find(key);
if (iter != engines_.end()) {
iter->second.reset(nullptr);
engines_.erase(iter);
}
}

private:
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
};
Expand Down

0 comments on commit c759259

Please sign in to comment.