Skip to content

Commit 9f006d5

Browse files
committed
fix: Restrict TRTorch to compile only forward methods
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 930d582 commit 9f006d5

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

core/compiler.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
183183
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
184184
for (const torch::jit::script::Method& method : mod.get_methods()) {
185185
// Don't convert hidden methods
186-
if (method.name().rfind("_", 0)) {
186+
if (method.name().compare("forward")==0) {
187187
auto new_g = std::make_shared<torch::jit::Graph>();
188188
auto graph_and_parameters = lowering::Lower(mod, method.name());
189189

@@ -257,7 +257,8 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
257257
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
258258
for (const torch::jit::script::Method& method : mod.get_methods()) {
259259
// Don't convert hidden methods
260-
if (method.name().rfind("_", 0)) {
260+
//
261+
if (method.name().compare("forward")==0) {
261262
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
262263
auto new_g = std::make_shared<torch::jit::Graph>();
263264
AddEngineToGraph(new_mod, new_g, engine);

0 commit comments

Comments
 (0)