Skip to content

Commit

Permalink
feat(trt): add return cv::Mat instead of vector for GAN output
Browse files Browse the repository at this point in the history
This feature is only available when linking with dede and does improve performances a lot especially with chains
  • Loading branch information
Bycob authored and beniz committed Aug 9, 2021
1 parent 3093439 commit 4990e7b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 12 deletions.
67 changes: 58 additions & 9 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,6 @@ namespace dd
_gpuid = predict_dto->parameters->mllib->gpuid->_ids[0];
cudaSetDevice(_gpuid);

// detect architecture
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, _gpuid);
std::string arch = std::to_string(prop.major) + std::to_string(prop.minor);
if (_first_predict)
this->_logger->info("GPU {} architecture = compute_{}", _gpuid, arch);

auto output_params = predict_dto->parameters->output;

std::string out_blob = "prob";
Expand All @@ -443,6 +436,13 @@ namespace dd
= predict_dto->parameters->mllib->extract_layer->std_str();
}

// detect architecture
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, _gpuid);
std::string arch = std::to_string(prop.major) + std::to_string(prop.minor);
if (_first_predict)
this->_logger->info("GPU {} architecture = compute_{}", _gpuid, arch);

TInputConnectorStrategy inputc(this->_inputc);

if (!_TRTContextReady)
Expand Down Expand Up @@ -907,8 +907,57 @@ namespace dd
else
rad.add("uri", std::to_string(idoffset + j));
rad.add("loss", 0.0);
std::vector<double> vals(_floatOut.begin(), _floatOut.end());
rad.add("vals", vals);
if (output_params->image)
{
size_t img_chan = size_t(_dims.d[1]);
size_t img_width = size_t(_dims.d[2]),
img_height = size_t(_dims.d[3]);
auto cv_type = img_chan == 3 ? CV_8UC3 : CV_8UC1;
cv::Mat vals_mat(img_width, img_height, cv_type);

size_t chan_offset = img_width * img_height;

for (size_t y = 0; y < img_height; ++y)
{
for (size_t x = 0; x < img_width; ++x)
{
if (cv_type == CV_8UC3)
{
vals_mat.at<cv::Vec3b>(y, x) = cv::Vec3b(
static_cast<int8_t>(
(_floatOut[2 * chan_offset
+ y * img_width + x]
+ 1)
* 255.0 / 2.0),
static_cast<int8_t>(
(_floatOut[1 * chan_offset
+ y * img_width + x]
+ 1)
* 255.0 / 2.0),
static_cast<int8_t>(
(_floatOut[0 * chan_offset
+ y * img_width + x]
+ 1)
* 255.0 / 2.0));
}
else
{
vals_mat.at<int8_t>(y, x)
= static_cast<int8_t>(
(_floatOut[y * img_width + x] + 1)
* 255.0 / 2.0);
}
}
}

rad.add("vals", std::vector<cv::Mat>{ vals_mat });
}
else
{
std::vector<double> vals(_floatOut.begin(),
_floatOut.end());
rad.add("vals", vals);
}
vrad.push_back(rad);
}
}
Expand Down
7 changes: 6 additions & 1 deletion src/chain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,12 @@ namespace dd
{
APIData vout;
APIData vals;
vals.add("vals", p.get("vals").get<std::vector<double>>());
if (p.get("vals").is<std::vector<cv::Mat>>())
vals.add("vals",
p.get("vals").get<std::vector<cv::Mat>>());
else
vals.add("vals",
p.get("vals").get<std::vector<double>>());
if (p.has("nns"))
vals.add("nns", p.getv("nns"));
vout.add(model_name, vals);
Expand Down
7 changes: 7 additions & 0 deletions src/dto/output_connector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ namespace dd
DTO_FIELD(Int32, best);
DTO_FIELD(Int32, best_bbox) = -1;

DTO_FIELD_INFO(image)
{
info->description = "wether to convert result to a cv::Mat (e.g. for "
"GANs or segmentation)";
};
DTO_FIELD(Boolean, image) = false;

/* ncnn */
DTO_FIELD(Int32, blank_label) = -1;
DTO_FIELD(Boolean, index) = false;
Expand Down
17 changes: 15 additions & 2 deletions src/unsupervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ namespace dd
std::vector<double> _vals;
std::vector<bool> _bvals;
std::string _str;
std::vector<cv::Mat> _images;
#ifdef USE_SIMSEARCH
bool _indexed = false;
std::multimap<double, std::string> _nns; /**< nearest neigbors. */
Expand Down Expand Up @@ -120,7 +121,12 @@ namespace dd
"unsupervised output needs mllib.extract_layer param");
return;
}
std::vector<double> vals = ad.get("vals").get<std::vector<double>>();

std::vector<double> vals;
if (ad.get("vals").is<std::vector<double>>())
{
vals = ad.get("vals").get<std::vector<double>>();
}
if ((hit = _vres.find(uri)) == _vres.end())
{
_vres.insert(std::pair<std::string, int>(uri, _vvres.size()));
Expand All @@ -135,6 +141,11 @@ namespace dd
else if (ad.has("meta_uri"))
meta_uri = ad.get("meta_uri").get<std::string>();
_vvres.push_back(unsup_result(uri, vals, extra, meta_uri));
if (ad.get("vals").is<std::vector<cv::Mat>>())
{
_vvres.back()._images
= ad.get("vals").get<std::vector<cv::Mat>>();
}
}
}
}
Expand Down Expand Up @@ -266,7 +277,9 @@ namespace dd
{
APIData adpred;
adpred.add("uri", _vvres.at(i)._uri);
if (_bool_binarized)
if (_vvres.at(i)._images.size() != 0)
adpred.add("vals", _vvres.at(i)._images);
else if (_bool_binarized)
adpred.add("vals", _vvres.at(i)._bvals);
else if (_string_binarized)
adpred.add("vals", _vvres.at(i)._str);
Expand Down

0 comments on commit 4990e7b

Please sign in to comment.