6
6
#include " core/conversion/conversion.h"
7
7
#include " cuda_runtime_api.h"
8
8
9
+ #include < vector>
10
+ #include < math.h>
11
+
9
12
namespace trtorch {
10
13
namespace tests {
11
14
namespace util {
@@ -18,6 +21,34 @@ std::vector<core::conversion::InputRange> toInputRanges(std::vector<at::Tensor>
18
21
return std::move (a);
19
22
}
20
23
24
+ std::vector<core::conversion::InputRange> toInputRangesDynamic (std::vector<at::Tensor> ten) {
25
+ std::vector<core::conversion::InputRange> a;
26
+
27
+ for (auto i : ten) {
28
+ auto opt = core::util::toVec (i.sizes ());
29
+
30
+ std::vector<int64_t > min_range (opt);
31
+ std::vector<int64_t > max_range (opt);
32
+
33
+ min_range[0 ] = ceil (opt[0 ]/2.0 );
34
+ max_range[0 ] = 2 *opt[0 ];
35
+
36
+ // for (int64_t each : min_range) {
37
+ // std::cout << each << std::endl;
38
+ // }
39
+ // for (int64_t each : opt) {
40
+ // std::cout << each << std::endl;
41
+ // }
42
+ // for (int64_t each : max_range) {
43
+ // std::cout << each << std::endl;
44
+ // }
45
+
46
+ a.push_back (core::conversion::InputRange (min_range, opt, max_range));
47
+ }
48
+
49
+ return std::move (a);
50
+ }
51
+
21
52
std::vector<at::Tensor> RunEngine (std::string& eng, std::vector<at::Tensor> inputs) {
22
53
auto rt = nvinfer1::createInferRuntime (core::util::logging::get_logger ());
23
54
auto engine = rt->deserializeCudaEngine (eng.c_str (), eng.size ());
@@ -71,6 +102,17 @@ std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g,
71
102
return RunEngine (eng, inputs);
72
103
}
73
104
105
+ std::vector<at::Tensor> RunGraphEngineDynamic (std::shared_ptr<torch::jit::Graph>& g,
106
+ core::conversion::GraphParams& named_params,
107
+ std::vector<at::Tensor> inputs) {
108
+ LOG_DEBUG (" Running TRT version" );
109
+ auto in = toInputRangesDynamic (inputs);
110
+ auto info = core::conversion::ConversionInfo (in);
111
+ info.engine_settings .workspace_size = 1 << 20 ;
112
+ std::string eng = core::conversion::ConvertBlockToEngine (g->block (), info, named_params);
113
+ return RunEngine (eng, inputs);
114
+ }
115
+
74
116
} // namespace util
75
117
} // namespace tests
76
118
} // namespace trtorch
0 commit comments