diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3018b0db771d..6d880ab90dc2 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -21,7 +21,7 @@
 import abc
 from functools import reduce
 import math
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union, List
 
 from tvm import relax
 
@@ -103,6 +103,16 @@ def _retrieve_args(self, node):
         else:
             return node
 
+    def _check_unsupported_func_type(self, nodes: List[fx.Node]):
+        missing_func_types = list(
+            {
+                node.target.__name__
+                for node in nodes
+                if node.op == "call_function" and node.target.__name__ not in self.convert_map
+            }
+        )
+        assert not missing_func_types, f"Unsupported function types {missing_func_types}"
+
     ########## Unary Ops ##########
 
     def _unary_op(self, op: Callable) -> Callable:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7b9587b67561..8f6418891bb1 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -518,23 +518,16 @@ def from_exported_program(
 
         func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
 
         nodes: List[fx.Node] = exported_program.graph.nodes
+
+        # Find all the missing function types
+        self._check_unsupported_func_type(nodes)
+
         with self.block_builder.function(
             name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
         ):
             output = None
             with self.block_builder.dataflow():
-                # Find all the missing function types
-                missing_func_types = list(
-                    {
-                        node.target.__name__
-                        for node in nodes
-                        if node.op == "call_function"
-                        and node.target.__name__ not in self.convert_map
-                    }
-                )
-                assert not missing_func_types, f"Unsupported function types {missing_func_types}"
-
                 # Translate the model.
                 for node in nodes:
                     if node.op == "placeholder":
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index d24d67105e46..594344fef89f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -848,21 +848,13 @@ def from_fx(
         else:
             func_attrs = None
 
+        # Find all the missing function types
+        self._check_unsupported_func_type(graph.nodes)
+
         with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs):
             output = None
             with self.block_builder.dataflow():
-                # Find all the missing function types
-                missing_func_types = list(
-                    {
-                        node.target.__name__
-                        for node in graph.nodes
-                        if node.op == "call_function"
-                        and node.target.__name__ not in self.convert_map
-                    }
-                )
-                assert not missing_func_types, f"Unsupported function types {missing_func_types}"
-
                 # Translate model parameters.
                 for _, param in model.named_parameters():
                     shape = param.data.shape
 