Skip to content

Commit

Permalink
fix(trt): detect architecture and rebuild model if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Jul 19, 2021
1 parent 5d34a39 commit 5c9ff89
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
31 changes: 27 additions & 4 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,20 @@ namespace dd

if (predict_dto->parameters->mllib->gpuid->_ids.size() == 0)
throw MLLibBadParamException("empty gpuid vector");
if (predict_dto->parameters->mllib->gpuid->_ids.size() > 1)
throw MLLibBadParamException(
"TensorRT: Multi-GPU inference is not applicable");

_gpuid = predict_dto->parameters->mllib->gpuid->_ids[0];
cudaSetDevice(_gpuid);

// detect architecture
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, _gpuid);
std::string arch = std::to_string(prop.major) + std::to_string(prop.minor);
if (_first_predict)
this->_logger->info("GPU {} architecture = compute_{}", _gpuid, arch);

auto output_params = predict_dto->parameters->output;

std::string out_blob = "prob";
Expand Down Expand Up @@ -500,7 +510,8 @@ namespace dd
+ std::to_string(bs));
}
std::ifstream file(this->_mlmodel._repo + "/" + _engineFileName
+ "_bs" + std::to_string(bs),
+ "_arch" + arch + "_bs"
+ std::to_string(bs),
std::ios::binary);
if (file.good())
{
Expand All @@ -517,8 +528,16 @@ namespace dd
_engine = std::shared_ptr<nvinfer1::ICudaEngine>(
runtime->deserializeCudaEngine(
trtModelStream.data(), trtModelStream.size(), nullptr),
[=](nvinfer1::ICudaEngine *e) { e->destroy(); });
[=](nvinfer1::ICudaEngine *e) {
if (e != nullptr)
e->destroy();
});
runtime->destroy();

if (_engine == nullptr)
throw MLLibInternalException(
"Engine could not be deserialized");

engineRead = true;
}
}
Expand Down Expand Up @@ -549,7 +568,8 @@ namespace dd
if (_writeEngine)
{
std::ofstream p(this->_mlmodel._repo + "/" + _engineFileName
+ "_bs" + std::to_string(_max_batch_size),
+ "_arch" + arch + "_bs"
+ std::to_string(_max_batch_size),
std::ios::binary);
nvinfer1::IHostMemory *trtModelStream = _engine->serialize();
p.write(reinterpret_cast<const char *>(trtModelStream->data()),
Expand All @@ -566,7 +586,10 @@ namespace dd

_context = std::shared_ptr<nvinfer1::IExecutionContext>(
_engine->createExecutionContext(),
[=](nvinfer1::IExecutionContext *e) { e->destroy(); });
[=](nvinfer1::IExecutionContext *e) {
if (e != nullptr)
e->destroy();
});
_TRTContextReady = true;

try
Expand Down
31 changes: 26 additions & 5 deletions tests/ut-tensorrtapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <gtest/gtest.h>
#include <stdio.h>
#include <iostream>
#include <cuda_runtime_api.h>

using namespace dd;

Expand All @@ -42,6 +43,13 @@ static std::string resnet_onnx_repo = "../examples/trt/resnet_onnx_trt/";
static std::string cyclegan_onnx_repo
= "../examples/trt/cyclegan_resnet_attn_onnx_trt/";

inline std::string get_trt_archi()
{
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
return std::to_string(prop.major) + std::to_string(prop.minor);
}

TEST(tensorrtapi, service_predict)
{
// create service
Expand Down Expand Up @@ -73,13 +81,16 @@ TEST(tensorrtapi, service_predict)
ASSERT_TRUE(cl1 == "15");
ASSERT_TRUE(jd["body"]["predictions"][0]["classes"][0]["prob"].GetDouble()
> 0.4);
ASSERT_TRUE(fileops::file_exists(squeez_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs48"));
// ASSERT_TRUE(!fileops::remove_file(squeez_repo, "net_tensorRT.proto"));
// ASSERT_TRUE(!fileops::remove_file(squeez_repo, "TRTengine_bs48"));
// ASSERT_TRUE(!fileops::remove_file(squeez_repo, "TRTengine_archXX_bs48"));
jstr = "{\"clear\":\"lib\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);
ASSERT_TRUE(!fileops::file_exists(squeez_repo + "net_tensorRT.proto"));
ASSERT_TRUE(!fileops::file_exists(squeez_repo + "TRTengine_bs48"));
ASSERT_TRUE(!fileops::file_exists(squeez_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs48"));
}

TEST(tensorrtapi, service_predict_best)
Expand Down Expand Up @@ -114,13 +125,16 @@ TEST(tensorrtapi, service_predict_best)
std::string age
= jd["body"]["predictions"][0]["classes"][0]["cat"].GetString();
ASSERT_TRUE(age == "29");
ASSERT_TRUE(fileops::file_exists(age_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs1"));
/*ASSERT_TRUE(!fileops::remove_file(age_repo, "net_tensorRT.proto"));
ASSERT_TRUE(!fileops::remove_file(age_repo, "TRTengine_bs_bs1"));*/
jstr = "{\"clear\":\"lib\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);
ASSERT_TRUE(!fileops::file_exists(age_repo + "net_tensorRT.proto"));
ASSERT_TRUE(!fileops::file_exists(age_repo + "TRTengine_bs_bs1"));
ASSERT_TRUE(!fileops::file_exists(age_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs1"));
}

TEST(tensorrtapi, service_predict_refinedet)
Expand Down Expand Up @@ -162,11 +176,14 @@ TEST(tensorrtapi, service_predict_refinedet)
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(400, jd["status"]["code"]);

ASSERT_TRUE(fileops::file_exists(squeez_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs48"));
jstr = "{\"clear\":\"lib\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);
ASSERT_TRUE(!fileops::file_exists(squeez_repo + "net_tensorRT.proto"));
ASSERT_TRUE(!fileops::file_exists(squeez_repo + "TRTengine_bs48"));
ASSERT_TRUE(!fileops::file_exists(squeez_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs48"));
}

TEST(tensorrtapi, service_predict_onnx)
Expand Down Expand Up @@ -242,8 +259,12 @@ TEST(tensorrtapi, service_predict_gan_onnx)
ASSERT_TRUE(jd["body"]["predictions"][0]["vals"].IsArray());
ASSERT_EQ(jd["body"]["predictions"][0]["vals"].Size(), 360 * 360 * 3);

ASSERT_TRUE(fileops::file_exists(cyclegan_onnx_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs1"));

jstr = "{\"clear\":\"lib\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);
ASSERT_TRUE(!fileops::file_exists(cyclegan_onnx_repo + "TRTengine_bs1"));
ASSERT_TRUE(!fileops::file_exists(cyclegan_onnx_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs1"));
}

0 comments on commit 5c9ff89

Please sign in to comment.