Skip to content

Commit 9458f21

Browse files
committed
feat(tests/util): added RunGraphEngineDynamic to handle dynamic input sized tensors
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>
1 parent 98c797d commit 9458f21

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

tests/util/run_graph_engine.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include "core/conversion/conversion.h"
77
#include "cuda_runtime_api.h"
88

9+
#include <vector>
10+
#include <math.h>
11+
912
namespace trtorch {
1013
namespace tests {
1114
namespace util {
@@ -18,6 +21,34 @@ std::vector<core::conversion::InputRange> toInputRanges(std::vector<at::Tensor>
1821
return std::move(a);
1922
}
2023

24+
// Builds an InputRange per input tensor for dynamic-shape engine tests.
// For each tensor the opt shape is its actual shape; the batch dimension
// (dim 0) is allowed to vary in [ceil(N/2), 2N] while all other dims stay
// fixed. Returns one InputRange per tensor, in input order.
std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::Tensor> ten) {
  std::vector<core::conversion::InputRange> a;
  a.reserve(ten.size());

  // const ref: avoid refcount churn from copying each at::Tensor handle
  for (const auto& i : ten) {
    auto opt = core::util::toVec(i.sizes());

    std::vector<int64_t> min_range(opt);
    std::vector<int64_t> max_range(opt);

    // Only vary the batch dimension; a 0-dim (scalar) tensor has no dim 0,
    // so leave its range degenerate instead of indexing out of bounds.
    if (!opt.empty()) {
      min_range[0] = static_cast<int64_t>(ceil(opt[0] / 2.0));
      max_range[0] = 2 * opt[0];
    }

    a.push_back(core::conversion::InputRange(min_range, opt, max_range));
  }

  // Plain return enables NRVO; returning std::move(a) would inhibit copy elision.
  return a;
}
51+
2152
std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs) {
2253
auto rt = nvinfer1::createInferRuntime(core::util::logging::get_logger());
2354
auto engine = rt->deserializeCudaEngine(eng.c_str(), eng.size());
@@ -71,6 +102,17 @@ std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g,
71102
return RunEngine(eng, inputs);
72103
}
73104

105+
// Converts the given JIT graph to a TensorRT engine with dynamic input
// ranges derived from `inputs`, runs inference, and returns the outputs.
std::vector<at::Tensor> RunGraphEngineDynamic(std::shared_ptr<torch::jit::Graph>& g,
                                              core::conversion::GraphParams& named_params,
                                              std::vector<at::Tensor> inputs) {
  LOG_DEBUG("Running TRT version");
  // Derive [min, opt, max] shape ranges from the concrete input tensors.
  auto input_ranges = toInputRangesDynamic(inputs);
  core::conversion::ConversionInfo conversion_info(input_ranges);
  conversion_info.engine_settings.workspace_size = 1 << 20;
  std::string serialized_engine =
      core::conversion::ConvertBlockToEngine(g->block(), conversion_info, named_params);
  return RunEngine(serialized_engine, inputs);
}
115+
74116
} // namespace util
75117
} // namespace tests
76118
} // namespace trtorch

tests/util/util.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g,
2828
core::conversion::GraphParams& named_params,
2929
std::vector<at::Tensor> inputs);
3030

31+
// Runs an arbitrary JIT graph with dynamic input sizes by converting it to a TensorRT
32+
// engine, running inference, and returning the results
33+
std::vector<at::Tensor> RunGraphEngineDynamic(std::shared_ptr<torch::jit::Graph>& g,
34+
core::conversion::GraphParams& named_params,
35+
std::vector<at::Tensor> inputs);
36+
3137
// Run the forward method of a module and return results
3238
torch::jit::IValue RunModuleForward(torch::jit::Module& mod,
3339
std::vector<torch::jit::IValue> inputs);

0 commit comments

Comments
 (0)