From 780dd5532c5321310adc48106568427a4dddac84 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 4 Oct 2024 14:42:30 +0900 Subject: [PATCH 1/4] use torch.export --- docs/get_started/tutorials/ir_module.py | 15 +++++++++------ docs/how_to/tutorials/e2e_opt_model.py | 18 ++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/get_started/tutorials/ir_module.py b/docs/get_started/tutorials/ir_module.py index f813333bafc3..0a825c3da757 100644 --- a/docs/get_started/tutorials/ir_module.py +++ b/docs/get_started/tutorials/ir_module.py @@ -40,8 +40,9 @@ # below. import torch -from torch import fx, nn -from tvm.relax.frontend.torch import from_fx +from torch import nn +from torch.export import export +from tvm.relax.frontend.torch import from_exported_program ###################################################################### # Import from existing models @@ -67,13 +68,15 @@ def forward(self, x): return x -# Give the input shape and data type -input_info = [((1, 784), "float32")] +# Give an example argument to torch.export +example_args = (torch.randn(1, 784, dtype=torch.float32),) # Convert the model to IRModule with torch.no_grad(): - torch_fx_model = fx.symbolic_trace(TorchModel()) - mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True) + exported_program = export(TorchModel().eval(), example_args) + mod_from_torch = from_exported_program( + exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True + ) mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch) # Print the IRModule diff --git a/docs/how_to/tutorials/e2e_opt_model.py b/docs/how_to/tutorials/e2e_opt_model.py index 5c11439e1635..532fb89fd3bc 100644 --- a/docs/how_to/tutorials/e2e_opt_model.py +++ b/docs/how_to/tutorials/e2e_opt_model.py @@ -34,10 +34,10 @@ import os import numpy as np import torch -from torch import fx +from torch.export import export from torchvision.models.resnet import ResNet18_Weights, resnet18 -torch_model = resnet18(weights=ResNet18_Weights.DEFAULT) +torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval() ###################################################################### # Review Overall Flow @@ -63,21 +63,19 @@ # Convert the model to IRModule # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further -# optimization. Besides the model, we also need to provide the input shape and data type. +# optimization. import tvm from tvm import relax -from tvm.relax.frontend.torch import from_fx +from tvm.relax.frontend.torch import from_exported_program -torch_model = resnet18(weights=ResNet18_Weights.DEFAULT) - -# Give the input shape and data type -input_info = [((1, 3, 224, 224), "float32")] +# Give an example argument to torch.export +example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),) # Convert the model to IRModule with torch.no_grad(): - torch_fx_model = fx.symbolic_trace(torch_model) - mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True) + exported_program = export(torch_model, example_args) + mod = from_exported_program(exported_program, keep_params_as_input=True) mod, params = relax.frontend.detach_params(mod) mod.show() From 09953c077ddb89d2430d32eb8ca31b8332aae9ec Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 4 Oct 2024 14:46:45 +0900 Subject: [PATCH 2/4] in order to make interface consistent, user inputs should be placed first --- .../torch/exported_program_translator.py | 17 +++++++++-------- .../test_frontend_from_exported_program.py | 4 ++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 1401a0bcef3a..0fe227dc111c 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -36,10 +36,10 @@ class ExportedProgramImporter(BaseFXGraphImporter): def create_input_vars( self, exported_program: torch.export.ExportedProgram - ) -> Tuple[List[relax.Var], List[relax.Var]]: + ) -> Tuple[OrderedDict[str, relax.Var], OrderedDict[str, relax.Var]]: """Create relax input vars.""" - parameters_buffers_constants = [] - user_inputs = [] + parameters_buffers_constants = OrderedDict() + user_inputs = OrderedDict() for spec in exported_program.graph_signature.input_specs: name_hint = spec.arg.name if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: @@ -59,9 +59,9 @@ def create_input_vars( dtype = self._convert_data_type(torch_dtype) relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype)) if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - user_inputs.append(relax_var) + user_inputs[name_hint] = relax_var else: - parameters_buffers_constants.append(relax_var) + parameters_buffers_constants[name_hint] = relax_var return parameters_buffers_constants, user_inputs @@ -305,7 +305,8 @@ def from_exported_program( # Create input variables. parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program) - inputs_vars = parameter_buffer_constant_vars + user_input_vars + inputs_vars = user_input_vars.copy() + inputs_vars.update(parameter_buffer_constant_vars) # Initialize the block builder with a function and a dataflow block. self.block_builder = relax.BlockBuilder() @@ -314,7 +315,7 @@ def from_exported_program( nodes: List[fx.Node] = exported_program.graph.nodes with self.block_builder.function( - name=func_name, params=inputs_vars.copy(), attrs=func_attrs + name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs ): output = None with self.block_builder.dataflow(): @@ -325,7 +326,7 @@ def from_exported_program( # Ignore sym input continue - self.env[node] = inputs_vars.pop(0) + self.env[node] = inputs_vars[node.name] elif node.op == "output": args = self.retrieve_args(node) assert len(args) == 1 diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 65890ff6971b..0d8425fc7f30 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3550,9 +3550,9 @@ def forward(self, input): class expected1: @R.function def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"), conv_bias: R.Tensor((6,), dtype="float32"), - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): R.func_attr({"num_input": 1}) # block 0 @@ -3586,7 +3586,7 @@ def main( params = params["main"] assert len(params) == len(func.params) - 1 - for param_var, param_ndarray in zip(func.params[:-1], params): + for param_var, param_ndarray in zip(func.params[1:], params): assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape assert param_var.struct_info.dtype == param_ndarray.dtype From 9da659c3c0b473eebb4ea5388182e053b696090e Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 4 Oct 2024 14:46:49 +0900 Subject: [PATCH 3/4] chore --- .../torch/exported_program_translator.py | 64 ++++++++++--------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0fe227dc111c..91545ff2b919 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -34,37 +34,6 @@ class ExportedProgramImporter(BaseFXGraphImporter): from torch import fx - def create_input_vars( - self, exported_program: torch.export.ExportedProgram - ) -> Tuple[OrderedDict[str, relax.Var], OrderedDict[str, relax.Var]]: - """Create relax input vars.""" - parameters_buffers_constants = OrderedDict() - user_inputs = OrderedDict() - for spec in exported_program.graph_signature.input_specs: - name_hint = spec.arg.name - if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: - shape = exported_program.tensor_constants[spec.target].shape - torch_dtype = exported_program.tensor_constants[spec.target].dtype - elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target): - if node.name == name_hint: - shape = node.meta["tensor_meta"].shape - torch_dtype = node.meta["tensor_meta"].dtype - break - else: - # PARAMETER or BUFFER - shape = exported_program.state_dict[spec.target].shape - torch_dtype = exported_program.state_dict[spec.target].dtype - - dtype = self._convert_data_type(torch_dtype) - relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype)) - if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - user_inputs[name_hint] = relax_var - else: - parameters_buffers_constants[name_hint] = relax_var - - return parameters_buffers_constants, user_inputs - ########## Unary Ops ########## def _hardtanh(self, node: fx.Node) -> relax.Expr: @@ -178,6 +147,8 @@ def _slice(self, node: fx.Node) -> relax.Var: stride = [node.args[4] if len(node.args) > 4 else 1] return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + ########## Others ########## + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -293,6 +264,37 @@ def create_convert_map( "getitem": self._getitem, } + def create_input_vars( + self, exported_program: torch.export.ExportedProgram + ) -> Tuple[OrderedDict[str, relax.Var], OrderedDict[str, relax.Var]]: + """Create relax input vars.""" + parameters_buffers_constants = OrderedDict() + user_inputs = OrderedDict() + for spec in exported_program.graph_signature.input_specs: + name_hint = spec.arg.name + if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: + shape = exported_program.tensor_constants[spec.target].shape + torch_dtype = exported_program.tensor_constants[spec.target].dtype + elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target): + if node.name == name_hint: + shape = node.meta["tensor_meta"].shape + torch_dtype = node.meta["tensor_meta"].dtype + break + else: + # PARAMETER or BUFFER + shape = exported_program.state_dict[spec.target].shape + torch_dtype = exported_program.state_dict[spec.target].dtype + + dtype = self._convert_data_type(torch_dtype) + relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype)) + if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + user_inputs[name_hint] = relax_var + else: + parameters_buffers_constants[name_hint] = relax_var + + return parameters_buffers_constants, user_inputs + def from_exported_program( self, exported_program: torch.export.ExportedProgram, From 19046da21d24aa32bdd9156d255c3c0b0fc0ad36 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 4 Oct 2024 17:48:32 +0900 Subject: [PATCH 4/4] fix --- python/tvm/relax/frontend/torch/exported_program_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 91545ff2b919..7bcd20c462bd 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -266,7 +266,7 @@ def create_convert_map( def create_input_vars( self, exported_program: torch.export.ExportedProgram - ) -> Tuple[OrderedDict[str, relax.Var], OrderedDict[str, relax.Var]]: + ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]: """Create relax input vars.""" parameters_buffers_constants = OrderedDict() user_inputs = OrderedDict()