@@ -56,54 +56,59 @@ int main(int argc, const char* argv[]) {
56
56
}
57
57
58
58
auto compile_spec = trtorch::CompileSpec (dims);
59
+ // compile_spec.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
59
60
compile_spec.workspace_size = 1 << 24 ;
60
-
61
- std::cout << " Checking operator support" << std::endl;
62
- if (!trtorch::CheckMethodOperatorSupport (mod, " forward" )) {
63
- std::cerr << " Method is not currently supported by TRTorch" << std::endl;
64
- return -1 ;
65
- }
66
-
67
- std::cout << " Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
61
+ compile_spec.op_precision = torch::kChar ;
62
+ // compile_spec.input_dtypes = {torch::kInt32, torch::kInt32};
63
+ // std::cout << "===Compile Spec: " << compile_spec << std::endl;
64
+ // compile_spec.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
65
+ // compile_spec.torch_fallback.min_block_size = 1;
66
+ // std::cout << "Checking operator support" << std::endl;
67
+ // if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
68
+ // std::cerr << "Method is not currently supported by TRTorch" << std::endl;
69
+ // return -1;
70
+ // }
71
+ //
72
+ // std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
68
73
auto engine = trtorch::ConvertGraphToTRTEngine (mod, " forward" , compile_spec);
69
74
std::ofstream out (" /tmp/engine_converted_from_jit.trt" );
70
75
out << engine;
71
76
out.close ();
72
77
73
- std::vector<torch::jit::IValue> jit_inputs_ivalues;
74
- std::vector<torch::jit::IValue> trt_inputs_ivalues;
75
- auto in = at::randint (5 , dims[0 ], {at::kCUDA });
76
- jit_inputs_ivalues.push_back (in.clone ());
77
- trt_inputs_ivalues.push_back (in.clone ());
78
-
79
- torch::jit::IValue jit_results_ivalues = mod.forward (jit_inputs_ivalues);
80
- std::vector<at::Tensor> jit_results;
81
- if (jit_results_ivalues.isTensor ()) {
82
- jit_results.push_back (jit_results_ivalues.toTensor ());
83
- } else {
84
- auto results = jit_results_ivalues.toTuple ()->elements ();
85
- for (auto r : results) {
86
- jit_results.push_back (r.toTensor ());
87
- }
88
- }
78
+ // std::vector<torch::jit::IValue> jit_inputs_ivalues;
79
+ // std::vector<torch::jit::IValue> trt_inputs_ivalues;
80
+ // auto in = at::randint(5, dims[0], {at::kCUDA});
81
+ // jit_inputs_ivalues.push_back(in.clone());
82
+ // trt_inputs_ivalues.push_back(in.clone());
83
+ // //
84
+ // torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
85
+ // std::vector<at::Tensor> jit_results;
86
+ // if (jit_results_ivalues.isTensor()) {
87
+ // jit_results.push_back(jit_results_ivalues.toTensor());
88
+ // } else {
89
+ // auto results = jit_results_ivalues.toTuple()->elements();
90
+ // for (auto r : results) {
91
+ // jit_results.push_back(r.toTensor());
92
+ // }
93
+ // }
89
94
90
95
std::cout << " Compiling graph as module" << std::endl;
91
96
auto trt_mod = trtorch::CompileGraph (mod, compile_spec);
92
- std::cout << " Running TRT module" << std::endl;
93
- torch::jit::IValue trt_results_ivalues = trt_mod.forward (trt_inputs_ivalues);
94
- std::vector<at::Tensor> trt_results;
95
- if (trt_results_ivalues.isTensor ()) {
96
- trt_results.push_back (trt_results_ivalues.toTensor ());
97
- } else {
98
- auto results = trt_results_ivalues.toTuple ()->elements ();
99
- for (auto r : results) {
100
- trt_results.push_back (r.toTensor ());
101
- }
102
- }
103
-
104
- for (size_t i = 0 ; i < trt_results.size (); i++) {
105
- almostEqual (jit_results[i], trt_results[i].reshape_as (jit_results[i]));
106
- }
97
+ // std::cout << "Running TRT module" << std::endl;
98
+ // torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
99
+ // std::vector<at::Tensor> trt_results;
100
+ // if (trt_results_ivalues.isTensor()) {
101
+ // trt_results.push_back(trt_results_ivalues.toTensor());
102
+ // } else {
103
+ // auto results = trt_results_ivalues.toTuple()->elements();
104
+ // for (auto r : results) {
105
+ // trt_results.push_back(r.toTensor());
106
+ // }
107
+ // }
108
+ //
109
+ // for (size_t i = 0; i < trt_results.size(); i++) {
110
+ // almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]));
111
+ // }
107
112
108
113
std::cout << " Converted Engine saved to /tmp/engine_converted_from_jit.trt" << std::endl;
109
114
0 commit comments