diff --git a/docs/get_started/tutorials/ir_module.py b/docs/get_started/tutorials/ir_module.py
index f813333bafc3..0a825c3da757 100644
--- a/docs/get_started/tutorials/ir_module.py
+++ b/docs/get_started/tutorials/ir_module.py
@@ -40,8 +40,9 @@
 # below.
 
 import torch
-from torch import fx, nn
-from tvm.relax.frontend.torch import from_fx
+from torch import nn
+from torch.export import export
+from tvm.relax.frontend.torch import from_exported_program
 
 ######################################################################
 # Import from existing models
@@ -67,13 +68,15 @@ def forward(self, x):
         return x
 
 
-# Give the input shape and data type
-input_info = [((1, 784), "float32")]
+# Give an example argument to torch.export
+example_args = (torch.randn(1, 784, dtype=torch.float32),)
 
 # Convert the model to IRModule
 with torch.no_grad():
-    torch_fx_model = fx.symbolic_trace(TorchModel())
-    mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
+    exported_program = export(TorchModel().eval(), example_args)
+    mod_from_torch = from_exported_program(
+        exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
+    )
 
 mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
 # Print the IRModule
diff --git a/docs/how_to/tutorials/e2e_opt_model.py b/docs/how_to/tutorials/e2e_opt_model.py
index 5c11439e1635..532fb89fd3bc 100644
--- a/docs/how_to/tutorials/e2e_opt_model.py
+++ b/docs/how_to/tutorials/e2e_opt_model.py
@@ -34,10 +34,10 @@
 import os
 import numpy as np
 import torch
-from torch import fx
+from torch.export import export
 from torchvision.models.resnet import ResNet18_Weights, resnet18
 
-torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
+torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
 
 ######################################################################
 # Review Overall Flow
@@ -63,21 +63,19 @@
 # Convert the model to IRModule
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further
-# optimization. Besides the model, we also need to provide the input shape and data type.
+# optimization.
 
 import tvm
 from tvm import relax
-from tvm.relax.frontend.torch import from_fx
+from tvm.relax.frontend.torch import from_exported_program
 
-torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
-
-# Give the input shape and data type
-input_info = [((1, 3, 224, 224), "float32")]
+# Give an example argument to torch.export
+example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
 
 # Convert the model to IRModule
 with torch.no_grad():
-    torch_fx_model = fx.symbolic_trace(torch_model)
-    mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
+    exported_program = export(torch_model, example_args)
+    mod = from_exported_program(exported_program, keep_params_as_input=True)
 
 mod, params = relax.frontend.detach_params(mod)
 mod.show()
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1401a0bcef3a..7bcd20c462bd 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -34,37 +34,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     from torch import fx
 
-    def create_input_vars(
-        self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[List[relax.Var], List[relax.Var]]:
-        """Create relax input vars."""
-        parameters_buffers_constants = []
-        user_inputs = []
-        for spec in exported_program.graph_signature.input_specs:
-            name_hint = spec.arg.name
-            if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
-                shape = exported_program.tensor_constants[spec.target].shape
-                torch_dtype = exported_program.tensor_constants[spec.target].dtype
-            elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
-                for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
-                    if node.name == name_hint:
-                        shape = node.meta["tensor_meta"].shape
-                        torch_dtype = node.meta["tensor_meta"].dtype
-                        break
-            else:
-                # PARAMETER or BUFFER
-                shape = exported_program.state_dict[spec.target].shape
-                torch_dtype = exported_program.state_dict[spec.target].dtype
-
-            dtype = self._convert_data_type(torch_dtype)
-            relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
-            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
-                user_inputs.append(relax_var)
-            else:
-                parameters_buffers_constants.append(relax_var)
-
-        return parameters_buffers_constants, user_inputs
-
     ########## Unary Ops ##########
 
     def _hardtanh(self, node: fx.Node) -> relax.Expr:
@@ -178,6 +147,8 @@ def _slice(self, node: fx.Node) -> relax.Var:
         stride = [node.args[4] if len(node.args) > 4 else 1]
         return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
 
+    ########## Others ##########
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -293,6 +264,37 @@ def create_convert_map(
             "getitem": self._getitem,
         }
 
+    def create_input_vars(
+        self, exported_program: torch.export.ExportedProgram
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+        """Create relax input vars."""
+        parameters_buffers_constants = OrderedDict()
+        user_inputs = OrderedDict()
+        for spec in exported_program.graph_signature.input_specs:
+            name_hint = spec.arg.name
+            if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
+                shape = exported_program.tensor_constants[spec.target].shape
+                torch_dtype = exported_program.tensor_constants[spec.target].dtype
+            elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
+                for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
+                    if node.name == name_hint:
+                        shape = node.meta["tensor_meta"].shape
+                        torch_dtype = node.meta["tensor_meta"].dtype
+                        break
+            else:
+                # PARAMETER or BUFFER
+                shape = exported_program.state_dict[spec.target].shape
+                torch_dtype = exported_program.state_dict[spec.target].dtype
+
+            dtype = self._convert_data_type(torch_dtype)
+            relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
+            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
+                user_inputs[name_hint] = relax_var
+            else:
+                parameters_buffers_constants[name_hint] = relax_var
+
+        return parameters_buffers_constants, user_inputs
+
     def from_exported_program(
         self,
         exported_program: torch.export.ExportedProgram,
@@ -305,7 +307,8 @@ def from_exported_program(
 
         # Create input variables.
         parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
-        inputs_vars = parameter_buffer_constant_vars + user_input_vars
+        inputs_vars = user_input_vars.copy()
+        inputs_vars.update(parameter_buffer_constant_vars)
 
         # Initialize the block builder with a function and a dataflow block.
         self.block_builder = relax.BlockBuilder()
@@ -314,7 +317,7 @@
         nodes: List[fx.Node] = exported_program.graph.nodes
 
         with self.block_builder.function(
-            name=func_name, params=inputs_vars.copy(), attrs=func_attrs
+            name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
         ):
             output = None
             with self.block_builder.dataflow():
@@ -325,7 +328,7 @@
                         # Ignore sym input
                         continue
 
-                    self.env[node] = inputs_vars.pop(0)
+                    self.env[node] = inputs_vars[node.name]
                 elif node.op == "output":
                     args = self.retrieve_args(node)
                     assert len(args) == 1
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 65890ff6971b..0d8425fc7f30 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3550,9 +3550,9 @@ def forward(self, input):
     class expected1:
         @R.function
        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
             conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
             conv_bias: R.Tensor((6,), dtype="float32"),
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
             R.func_attr({"num_input": 1})
             # block 0
@@ -3586,7 +3586,7 @@ def main(
     params = params["main"]
 
     assert len(params) == len(func.params) - 1
-    for param_var, param_ndarray in zip(func.params[:-1], params):
+    for param_var, param_ndarray in zip(func.params[1:], params):
         assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape
         assert param_var.struct_info.dtype == param_ndarray.dtype
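For reference, a minimal sketch of the export-based flow the tutorials switch to in this patch. The `TorchModel` body below is illustrative (a single `nn.Linear`; the tutorial's full model definition is not part of this diff), but the conversion calls match the updated docs:

```python
import torch
from torch import nn
from torch.export import export

from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class TorchModel(nn.Module):
    # Illustrative stand-in for the tutorial's model; any nn.Module
    # taking a (1, 784) float32 tensor is converted the same way.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        return self.fc(x)


# torch.export traces from example tensors, so we pass a sample input
# instead of the old (shape, dtype) input_info list used by from_fx.
example_args = (torch.randn(1, 784, dtype=torch.float32),)

with torch.no_grad():
    exported_program = export(TorchModel().eval(), example_args)
    mod = from_exported_program(
        exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
    )

mod, params = relax.frontend.detach_params(mod)
# With the reordering in this patch, the user input is the first entry of
# mod["main"].params and the detached parameters follow it, which is why
# the test now zips against func.params[1:].
mod.show()
```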