-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference_trt.h
57 lines (48 loc) · 1.2 KB
/
inference_trt.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#pragma once
#include "inference_base.h"
using namespace std;
// TensorRT-backed implementation of InferenceBase.
// Owns the TensorRT runtime -> engine -> execution-context object chain and
// releases it in the destructor in reverse order of creation.
class InferenceTensorRT : public InferenceBase
{
private:
    // TensorRT object hierarchy, created during graph loading (presumably in
    // ReadGraph() — not visible here). Initialized to nullptr so the
    // destructor's null checks are meaningful even if loading never ran or
    // failed part-way through; without these initializers the pointers were
    // indeterminate and destroying an unloaded object was undefined behavior.
    IRuntime *runtime = nullptr;
    ICudaEngine *engine = nullptr;
    IExecutionContext *context = nullptr;
    bool isInt8;                          // run the engine with INT8 precision
    // batch size
    const int N = 1;
    const float visualizeThreshold = 0.5; // minimum score for drawing a detection
    vector<string> labelsVector;          // class id -> label text
    vector<int> numDetections;            // per-image detection counts (size N)
    vector<float> detections;             // raw output: N * keepTopK * 7 floats per detection
    string outFileRoot;                   // prefix for optional output files ("" = none)

protected:
    int ReadGraph() override;
    int ReadClassLabels() override;
    int doInference(cv::cuda::GpuMat &d_frame) override;
    void visualize(cv::cuda::GpuMat &, double) override;

public:
    // labelsFile / graphFile / thresholds / dbg are forwarded to InferenceBase.
    // isInt8 selects INT8 engine precision; outFile is an optional output path root.
    // Member initializers are listed in declaration order: members are always
    // constructed in declaration order regardless of list order, and matching
    // the two avoids -Wreorder warnings and reader confusion.
    InferenceTensorRT(const string &labelsFile, const string &graphFile, bool isInt8, double threshScore = 0.5, double threshIOU = 0.8, int dbg = 0, string outFile="")
        : InferenceBase(labelsFile, graphFile, threshScore, threshIOU, dbg)
        , isInt8(isInt8)
        , labelsVector()
        , numDetections(N)
        , detections(N * detectionOutputParam.keepTopK * 7)
        , outFileRoot(outFile)
    {
    }

    // Non-copyable: the class owns raw TensorRT handles, and a shallow copy
    // would lead to a double destroy().
    InferenceTensorRT(const InferenceTensorRT &) = delete;
    InferenceTensorRT &operator=(const InferenceTensorRT &) = delete;

    // Release TensorRT objects in reverse order of creation
    // (context -> engine -> runtime).
    virtual ~InferenceTensorRT()
    {
        if (context != nullptr)
        {
            context->destroy();
        }
        if (engine != nullptr)
        {
            engine->destroy();
        }
        if (runtime != nullptr)
        {
            runtime->destroy();
        }
    }
};