@@ -16,6 +16,34 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
16
16
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-5 ));
17
17
}
18
18
19
+ TEST_P (ModuleTests, ModuleToModuleIsClose) {
20
+ std::vector<at::Tensor> inputs;
21
+ std::vector<torch::jit::IValue> inputs_ivalues;
22
+ for (auto in_shape : input_shapes) {
23
+ inputs.push_back (at::randint (5 , in_shape, {at::kCUDA }));
24
+ inputs_ivalues.push_back (inputs[inputs.size () - 1 ].clone ());
25
+ }
26
+
27
+ torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward (mod, inputs_ivalues);
28
+ std::vector<at::Tensor> jit_results;
29
+ jit_results.push_back (jit_results_ivalues.toTensor ());
30
+
31
+ auto forward_graph = mod.get_method (" forward" );
32
+ std::vector<c10::ArrayRef<int64_t >> input_ranges;
33
+ for (auto in : inputs) {
34
+ input_ranges.push_back (in.sizes ());
35
+ }
36
+
37
+ auto engine = trtorch::ConvertGraphToTRTEngine (mod, " forward" , input_ranges);
38
+ auto trt_mod = trtorch::EmbedEngineInNewModule (engine);
39
+
40
+ torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward (mod, inputs_ivalues);
41
+ std::vector<at::Tensor> trt_results;
42
+ trt_results.push_back (trt_results_ivalues.toTensor ());
43
+
44
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-5 ));
45
+ }
46
+
19
47
INSTANTIATE_TEST_SUITE_P (
20
48
ModuleAsEngineForwardIsCloseSuite,
21
49
ModuleTests,
0 commit comments