2424namespace trtorch {
2525namespace core {
2626
27- c10::FunctionSchema GenerateGraphSchema (torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
27+ c10::FunctionSchema GenerateGraphSchema (torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
2828
2929 std::vector<c10::Argument> args;
3030 for (auto in : g->inputs ()) {
3131 args.push_back (c10::Argument (in->debugName (), in->type ()));
3232 }
33-
33+
3434 std::vector<c10::Argument> returns;
3535 for (auto out : g->outputs ()) {
3636 returns.push_back (c10::Argument (out->debugName (), out->type ()));
3737 }
38-
38+
3939 return c10::FunctionSchema (method_name, method_name, args, returns);
4040}
4141
4242
4343void AddEngineToGraph (torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
44- execution::EngineID uid = execution::RegisterEngineFromSerializedEngine (serialized_engine);
44+ execution::EngineID uid = execution::RegisterEngineFromSerializedEngine (serialized_engine);
4545 auto schema = execution::GetEngineFunctionSchema (uid);
4646 auto num_io = execution::GetEngineIO (uid);
4747
@@ -53,58 +53,42 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
5353 in_val->setType (c10::TensorType::get ());
5454 graph_inputs.push_back (in_val);
5555 }
56-
56+
5757 auto engine_node = g->create (c10::Symbol::fromQualString (schema.name ()), torch::jit::ArrayRef<torch::jit::Value*>(graph_inputs), num_io.second );
5858 g->block ()->appendNode (engine_node);
5959
6060 for (auto o : engine_node->outputs ()) {
6161 g->registerOutput (o);
6262 }
63-
63+
6464 return ;
6565}
6666
6767bool CheckMethodOperatorSupport (const torch::jit::script::Module& mod,
6868 std::string method_name) {
69- auto g = mod.get_method (method_name).graph ();
70- // Go through PyTorch Lowering to simplify graph and extract weight parameters
71- auto graph_and_parameters = torch::jit::LowerGraph (*g, mod._ivalue ());
72-
73- g = graph_and_parameters.first ;
74-
75- // Go through TRTorch Lowering to reformat graph to be conversion friendly
76- // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
77- lowering::LowerGraph (g);
78-
69+ // Go through Lowering to simplify graph and extract weight parameters
70+ auto graph_and_parameters = lowering::Lower (mod, method_name);
71+
72+ auto g = graph_and_parameters.first ;
7973 auto params = graph_and_parameters.second ;
8074 auto named_params = conversion::get_named_params (g->inputs (), params);
8175 LOG_DEBUG (*g << " (CheckMethodOperatorSupport)\n " );
82-
83- // Is this necessary?
84- lowering::LowerBlock (g->block ());
85-
76+
8677 return conversion::VerifyConverterSupportForBlock (g->block ());
8778}
8879
8980std::string ConvertGraphToTRTEngine (const torch::jit::script::Module& mod,
9081 std::string method_name,
9182 conversion::ExtraInfo cfg) {
92- auto g = mod.get_method (method_name).graph ();
93- // Go through PyTorch Lowering to simplify graph and extract weight parameters
94- auto graph_and_parameters = torch::jit::LowerGraph (*g, mod._ivalue ());
95-
96- g = graph_and_parameters.first ;
97-
98- // Go through TRTorch Lowering to reformat graph to be conversion friendly
99- // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
100- lowering::LowerGraph (g);
101-
83+ // Go through Lowering to simplify graph and extract weight parameters
84+ auto graph_and_parameters = lowering::Lower (mod, method_name);
85+
86+ auto g = graph_and_parameters.first ;
10287 auto params = graph_and_parameters.second ;
10388 auto named_params = conversion::get_named_params (g->inputs (), params);
89+
10490 LOG_INFO (*g << " (CompileGraph)\n " );
105-
106- // Is this necessary?
107- lowering::LowerBlock (g->block ());
91+
10892 auto engine = ConvertBlockToEngine (g->block (), cfg, named_params);
10993 return std::move (engine);
11094}
@@ -128,7 +112,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
128112
129113 return new_mod;
130114}
131-
115+
132116} // namespace core
133117} // namespace trtorch
134118
0 commit comments