@@ -150,13 +150,16 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
150150 torch::jit::script::Module new_mod (mod._ivalue ()->name () + " _trt" );
151151 std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
152152 for (const torch::jit::script::Method& method : mod.get_methods ()) {
153- auto engine = ConvertGraphToTRTEngine (mod, method.name (), cfg);
154- auto new_g = std::make_shared<torch::jit::Graph>();
155- AddEngineToGraph (new_mod, new_g, engine);
156- auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
157- auto schema = GenerateGraphSchema (new_mod, new_method->name (), new_g);
158- new_mod.type ()->addMethod (new_method);
159- new_method->setSchema (schema);
153+ // Don't convert hidden methods
154+ if (method.name ().rfind (" _" , 0 )) {
155+ auto engine = ConvertGraphToTRTEngine (mod, method.name (), cfg);
156+ auto new_g = std::make_shared<torch::jit::Graph>();
157+ AddEngineToGraph (new_mod, new_g, engine);
158+ auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
159+ auto schema = GenerateGraphSchema (new_mod, new_method->name (), new_g);
160+ new_mod.type ()->addMethod (new_method);
161+ new_method->setSchema (schema);
162+ }
160163 }
161164
162165 return new_mod;
0 commit comments