feat: add yolox onnx export and trt support
Bycob authored and mergify[bot] committed Nov 26, 2021
1 parent 1e3ff16 commit 80b7e6a
Showing 8 changed files with 287 additions and 41 deletions.
67 changes: 52 additions & 15 deletions src/backends/tensorrt/tensorrtlib.cc
@@ -22,7 +22,6 @@
#include "tensorrtlib.h"
#include "utils/apitools.h"
#include "tensorrtinputconns.h"
#include "utils/apitools.h"
#include "NvInferPlugin.h"
#include "../parsers/onnx/NvOnnxParser.h"
#include "protoUtils.h"
@@ -32,6 +31,7 @@
#ifdef USE_CUDA_CV
#include <opencv2/core/cuda_stream_accessor.hpp>
#endif
#include "utils/bbox.hpp"

namespace dd
{
@@ -99,12 +99,15 @@ namespace dd
_engineFileName = tl._engineFileName;
_readEngine = tl._readEngine;
_writeEngine = tl._writeEngine;
_arch = tl._arch;
_gpuid = tl._gpuid;
_TRTContextReady = tl._TRTContextReady;
_buffers = tl._buffers;
_bbox = tl._bbox;
_ctc = tl._ctc;
_timeserie = tl._timeserie;
_regression = tl._regression;
_need_nms = tl._need_nms;
_inputIndex = tl._inputIndex;
_outputIndex0 = tl._outputIndex0;
_outputIndex1 = tl._outputIndex1;
@@ -196,6 +199,24 @@
+ this->_mlmodel._repo);
}

        // XXX(louis): this default value should be moved out of the trt lib
        // once init_mllib is changed to DTOs
_top_k = ad.has("topk") ? ad.get("topk").get<int>() : 200;

if (ad.has("template"))
{
std::string tmplate = ad.get("template").get<std::string>();
this->_logger->info("Model template is {}", tmplate);

if (tmplate == "yolox")
{
this->_mltype = "detection";
_need_nms = true;
}
else
throw MLLibBadParamException("Unknown template " + tmplate);
}

_builder = std::shared_ptr<nvinfer1::IBuilder>(
nvinfer1::createInferBuilder(trtLogger));
_builderc = std::shared_ptr<nvinfer1::IBuilderConfig>(
@@ -408,6 +429,10 @@ namespace dd
nvinfer1::IHostMemory *n
= _builder->buildSerializedNetwork(*network, *_builderc);

if (n == nullptr)
throw MLLibInternalException("Could not build model: "
+ this->_mlmodel._model);

return _runtime->deserializeCudaEngine(n->data(), n->size());
}

@@ -521,7 +546,7 @@ namespace dd
this->_logger->info("found {} classes", _nclasses);
}

if (_bbox)
if (_bbox && !this->_mlmodel._def.empty())
_top_k = findTopK(this->_mlmodel._def);

if (_nclasses <= 0)
@@ -862,9 +887,10 @@ namespace dd
}
bool leave = false;
int curi = -1;

            while (k < results_height)
{
if (output_params->best_bbox > 0
if (!_need_nms && output_params->best_bbox > 0
&& bboxes.size() >= static_cast<size_t>(
output_params->best_bbox))
break;
@@ -884,28 +910,39 @@ namespace dd
break; // this belongs to next image
++k;
outr += det_size;

if (detection[2] < output_params->confidence_threshold)
continue;

                // Clamp bbox coordinates to the image and scale to pixels
detection[3] = std::max(((float)detection[3]), 0.0f);
detection[4] = std::max(((float)detection[4]), 0.0f);
detection[5] = std::min(((float)detection[5]), 1.0f);
detection[6] = std::min(((float)detection[6]), 1.0f);
detection[3]
= std::max(((float)detection[3]), 0.0f) * (cols - 1);
detection[4]
= std::max(((float)detection[4]), 0.0f) * (rows - 1);
detection[5]
= std::min(((float)detection[5]), 1.0f) * (cols - 1);
detection[6]
= std::min(((float)detection[6]), 1.0f) * (rows - 1);

probs.push_back(detection[2]);
cats.push_back(this->_mlmodel.get_hcorresp(detection[1]));
APIData ad_bbox;
ad_bbox.add("xmin", static_cast<double>(detection[3]
* (cols - 1)));
ad_bbox.add("ymin", static_cast<double>(detection[4]
* (rows - 1)));
ad_bbox.add("xmax", static_cast<double>(detection[5]
* (cols - 1)));
ad_bbox.add("ymax", static_cast<double>(detection[6]
* (rows - 1)));
ad_bbox.add("xmin", static_cast<double>(detection[3]));
ad_bbox.add("ymin", static_cast<double>(detection[4]));
ad_bbox.add("xmax", static_cast<double>(detection[5]));
ad_bbox.add("ymax", static_cast<double>(detection[6]));
bboxes.push_back(ad_bbox);
}

if (_need_nms)
{
                // We assume that bboxes are already sorted by decreasing confidence in the model output
bbox_utils::nms_sorted_bboxes(
bboxes, probs, cats,
(double)output_params->nms_threshold,
(int)output_params->best_bbox);
}

if (leave)
continue;
rad.add("uri", uri);
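As a reading aid (not part of the commit), here is a minimal sketch of the clamp-and-scale step that the reworked post-processing above now performs before storing each box: the normalized coordinates are clipped to [0, 1] and converted to pixel coordinates, and NMS is only applied afterwards when the yolox template sets _need_nms. The helper name to_pixel_bbox and the standalone layout are illustrative assumptions.

// Illustrative only; mirrors the clamp-and-scale logic added in tensorrtlib.cc.
// `det` follows the per-detection layout suggested by the indices used above:
// index 1 = label, 2 = confidence, 3..6 = normalized xmin, ymin, xmax, ymax
// (index 0 is assumed to be the image id).
#include <algorithm>
#include <vector>

std::vector<double> to_pixel_bbox(const float *det, int rows, int cols)
{
  double xmin = std::max(det[3], 0.0f) * (cols - 1);
  double ymin = std::max(det[4], 0.0f) * (rows - 1);
  double xmax = std::min(det[5], 1.0f) * (cols - 1);
  double ymax = std::min(det[6], 1.0f) * (rows - 1);
  return { xmin, ymin, xmax, ymax };
}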
3 changes: 3 additions & 0 deletions src/backends/tensorrt/tensorrtlib.h
@@ -135,6 +135,9 @@ namespace dd
bool _regression = false;
bool _timeserie = false;

// detection
bool _need_nms = false;

std::vector<void *> _buffers;

bool _TRTContextReady = false;
1 change: 1 addition & 0 deletions src/dto/output_connector.hpp
@@ -45,6 +45,7 @@ namespace dd
DTO_FIELD(Float32, confidence_threshold) = 0.0;
DTO_FIELD(Int32, best);
DTO_FIELD(Int32, best_bbox) = -1;
DTO_FIELD(Float32, nms_threshold) = 0.45;
DTO_FIELD(Vector<String>, confidences);

DTO_FIELD_INFO(image)
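For context, the new nms_threshold output parameter (default 0.45) can be overridden per predict call. A hypothetical request in the style of the existing TensorRT tests is sketched below; the service name and image path are placeholders, not values from the commit.

// Illustrative only: overriding the new NMS threshold at predict time,
// written in the style of the request strings in tests/ut-tensorrtapi.cc.
#include <string>

std::string sname = "onnx"; // placeholder service name
std::string jpredictstr
    = "{\"service\":\"" + sname
      + "\",\"parameters\":{\"input\":{},\"output\":{\"bbox\":true,"
        "\"confidence_threshold\":0.3,\"nms_threshold\":0.5}},"
        "\"data\":[\"my_image.jpg\"]}"; // placeholder image path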
2 changes: 0 additions & 2 deletions src/imginputfileconn.h
@@ -600,8 +600,6 @@ namespace dd
_height = params->crop_height;
}

// XXX(louis) We cannot set these parameters to false if they are already
// true
if (params->bw != nullptr)
_bw = params->bw;
if (params->rgb != nullptr)
92 changes: 82 additions & 10 deletions src/utils/bbox.hpp
@@ -28,15 +28,16 @@ namespace dd
{
namespace bbox_utils
{
double area(const std::vector<double> &bbox)
template <typename T> inline T area(const std::vector<T> &bbox)
{
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]);
}

std::vector<double> intersect(const std::vector<double> &bbox1,
const std::vector<double> &bbox2)
template <typename T>
inline std::vector<T> intersect(const std::vector<T> &bbox1,
const std::vector<T> &bbox2)
{
std::vector<double> inter{
std::vector<T> inter{
std::max(bbox1[0], bbox2[0]),
std::max(bbox1[1], bbox2[1]),
std::min(bbox1[2], bbox2[2]),
@@ -45,21 +46,92 @@
// if xmin > xmax or ymin > ymax, intersection is empty
if (inter[0] >= inter[2] || inter[1] >= inter[3])
{
return { 0., 0., 0., 0. };
return { T(0), T(0), T(0), T(0) };
}
else
return inter;
}

double iou(const std::vector<double> &bbox1,
const std::vector<double> &bbox2)
template <typename T>
inline T iou(const std::vector<T> &bbox1, const std::vector<T> &bbox2)
{
double a1 = area(bbox1);
double a2 = area(bbox2);
auto a1 = area(bbox1);
auto a2 = area(bbox2);
auto inter = intersect(bbox1, bbox2);
double ainter = area(inter);
auto ainter = area(inter);
return ainter / (a1 + a2 - ainter);
}

    /** bboxes: list of bboxes in the format { xmin, ymin, xmax, ymax },
     *  sorted by decreasing confidence.
     *
     *  picked: output vector receiving the indices of the bboxes kept by NMS.
*/
template <typename T>
inline void nms_sorted_bboxes(const std::vector<std::vector<T>> &bboxes,
std::vector<size_t> &picked, T nms_threshold)
{
picked.clear();
const size_t n = bboxes.size();

for (size_t i = 0; i < n; i++)
{
const std::vector<T> &bbox_a = bboxes[i];

bool keep = true;
for (size_t j = 0; j < picked.size(); j++)
{
const std::vector<T> &bbox_b = bboxes[picked[j]];

// intersection over union
auto iou = bbox_utils::iou(bbox_a, bbox_b);
if (iou > nms_threshold)
keep = false;
}

if (keep)
picked.push_back(i);
}
}

inline void nms_sorted_bboxes(std::vector<APIData> &bboxes,
std::vector<double> &probs,
std::vector<std::string> &cats,
double nms_threshold, int best_bbox)
{
std::vector<std::vector<double>> sorted_boxes;
std::vector<size_t> picked;

for (size_t l = 0; l < bboxes.size(); ++l)
{
std::vector<double> bbox_vec{ bboxes[l].get("xmin").get<double>(),
bboxes[l].get("ymin").get<double>(),
bboxes[l].get("xmax").get<double>(),
bboxes[l].get("ymax").get<double>() };
sorted_boxes.push_back(bbox_vec);
}
      // We assume that bboxes are already sorted by decreasing confidence in the model output

bbox_utils::nms_sorted_bboxes(sorted_boxes, picked, nms_threshold);
std::vector<APIData> nbboxes;
std::vector<double> nprobs;
std::vector<std::string> ncats;

for (size_t pick : picked)
{
nbboxes.push_back(bboxes.at(pick));
nprobs.push_back(probs.at(pick));
ncats.push_back(cats.at(pick));

if (best_bbox > 0
&& nbboxes.size() >= static_cast<size_t>(best_bbox))
break;
}

bboxes = nbboxes;
probs = nprobs;
cats = ncats;
}
}
}

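A minimal usage sketch of the templated helper above (not part of the commit; it assumes src/ is on the include path so that utils/bbox.hpp resolves as in tensorrtlib.cc, and the sample boxes are made up): the input boxes must already be sorted by decreasing confidence, and picked receives the indices of the boxes that survive NMS.

// Illustrative only: exercising bbox_utils::nms_sorted_bboxes on plain vectors.
#include <cstdio>
#include <vector>
#include "utils/bbox.hpp" // same include path as used in tensorrtlib.cc above

int main()
{
  // Two heavily overlapping boxes plus one separate box, best-scoring first.
  std::vector<std::vector<double>> boxes = { { 10., 10., 100., 100. },
                                             { 12., 12., 102., 102. },
                                             { 200., 200., 260., 260. } };
  std::vector<size_t> picked;
  dd::bbox_utils::nms_sorted_bboxes(boxes, picked, 0.45);
  for (size_t i : picked)
    std::printf("kept box %zu\n", i); // expected: boxes 0 and 2
  return 0;
}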
7 changes: 7 additions & 0 deletions tests/CMakeLists.txt
@@ -350,6 +350,13 @@ if (USE_TENSORRT)
"resnet_onnx_trt.tar.gz"
"resnet_onnx_trt"
)
DOWNLOAD_DATASET(
"ONNX yolox model"
"https://deepdetect.com/models/init/desktop/images/detection/yolox_onnx_trt.tar.gz"
"examples/trt"
"yolox_onnx_trt.tar.gz"
"yolox_onnx_trt"
)
DOWNLOAD_DATASET(
"ONNX CycleGAN model"
"https://deepdetect.com/dd/examples/tensorrt/cyclegan_resnet_attn_onnx_trt.tar.gz"
52 changes: 52 additions & 0 deletions tests/ut-tensorrtapi.cc
@@ -40,6 +40,7 @@ static std::string squeez_repo = "../examples/trt/squeezenet_ssd_trt/";
static std::string refinedet_repo = "../examples/trt/faces_512/";
static std::string squeezv1_repo = "../examples/trt/squeezenet_v1/";
static std::string resnet_onnx_repo = "../examples/trt/resnet_onnx_trt/";
static std::string yolox_onnx_repo = "../examples/trt/yolox_onnx_trt/";
static std::string cyclegan_onnx_repo
= "../examples/trt/cyclegan_resnet_attn_onnx_trt/";

@@ -244,6 +245,57 @@ TEST(tensorrtapi, service_predict_onnx)
> 0.3);
}

TEST(tensorrtapi, service_predict_bbox_onnx)
{
// create service
JsonAPI japi;
std::string sname = "onnx";
std::string jstr
= "{\"mllib\":\"tensorrt\",\"description\":\"Test onnx "
"import\",\"type\":\"supervised\",\"model\":{\"repository\":\""
+ yolox_onnx_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"640,\"width\":640,\"rgb\":true},\"mllib\":{\"template\":\"yolox\","
"\"maxBatchSize\":1,\"maxWorkspaceSize\":256,\"gpuid\":0,"
"\"nclasses\":80}}}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// predict
std::string jpredictstr
= "{\"service\":\"" + sname
+ "\",\"parameters\":{\"input\":{},\"output\":{\"bbox\":true,"
"\"confidence_threshold\":0.8}},\"data\":[\""
+ resnet_onnx_repo + "cat.jpg\"]}";
joutstr = japi.jrender(japi.service_predict(jpredictstr));
JDoc jd;
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());

auto &preds = jd["body"]["predictions"][0]["classes"];
ASSERT_EQ(preds.Size(), 1);
std::string cl1 = preds[0]["cat"].GetString();
ASSERT_TRUE(cl1 == "14");
ASSERT_TRUE(preds[0]["prob"].GetDouble() > 0.85);
auto &bbox = preds[0]["bbox"];
ASSERT_TRUE(bbox["xmin"].GetDouble() < 50 && bbox["xmax"].GetDouble() > 200
&& bbox["ymin"].GetDouble() < 50
&& bbox["ymax"].GetDouble() > 200);
// Check confidence threshold
ASSERT_TRUE(preds[preds.Size() - 1]["prob"].GetDouble() >= 0.8);

ASSERT_TRUE(fileops::file_exists(yolox_onnx_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs1"));
jstr = "{\"clear\":\"lib\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);
ASSERT_TRUE(!fileops::file_exists(yolox_onnx_repo + "TRTengine_arch"
+ get_trt_archi() + "_bs1"));
}

TEST(tensorrtapi, service_predict_gan_onnx)
{
// create service