Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][PyTorch][Docs] Use torch.export instead of fx.symbolic_trace for tutorial #17436

Merged
merged 4 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions docs/get_started/tutorials/ir_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
# below.

import torch
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

######################################################################
# Import from existing models
Expand All @@ -67,13 +68,15 @@ def forward(self, x):
return x


# Give the input shape and data type
input_info = [((1, 784), "float32")]
# Give an example argument to torch.export
example_args = (torch.randn(1, 784, dtype=torch.float32),)

# Convert the model to IRModule
with torch.no_grad():
torch_fx_model = fx.symbolic_trace(TorchModel())
mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
exported_program = export(TorchModel().eval(), example_args)
mod_from_torch = from_exported_program(
exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
)

mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
Expand Down
18 changes: 8 additions & 10 deletions docs/how_to/tutorials/e2e_opt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
import os
import numpy as np
import torch
from torch import fx
from torch.export import export
from torchvision.models.resnet import ResNet18_Weights, resnet18

torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()

######################################################################
# Review Overall Flow
Expand All @@ -63,21 +63,19 @@
# Convert the model to IRModule
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further
# optimization. Besides the model, we also need to provide the input shape and data type.
# optimization.

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx
from tvm.relax.frontend.torch import from_exported_program

torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)

# Give the input shape and data type
input_info = [((1, 3, 224, 224), "float32")]
# Give an example argument to torch.export
example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)

# Convert the model to IRModule
with torch.no_grad():
torch_fx_model = fx.symbolic_trace(torch_model)
mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
exported_program = export(torch_model, example_args)
mod = from_exported_program(exported_program, keep_params_as_input=True)

mod, params = relax.frontend.detach_params(mod)
mod.show()
Expand Down
71 changes: 37 additions & 34 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,37 +34,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):

from torch import fx

def create_input_vars(
self, exported_program: torch.export.ExportedProgram
) -> Tuple[List[relax.Var], List[relax.Var]]:
"""Create relax input vars."""
parameters_buffers_constants = []
user_inputs = []
for spec in exported_program.graph_signature.input_specs:
name_hint = spec.arg.name
if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
shape = exported_program.tensor_constants[spec.target].shape
torch_dtype = exported_program.tensor_constants[spec.target].dtype
elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
if node.name == name_hint:
shape = node.meta["tensor_meta"].shape
torch_dtype = node.meta["tensor_meta"].dtype
break
else:
# PARAMETER or BUFFER
shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype

dtype = self._convert_data_type(torch_dtype)
relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
user_inputs.append(relax_var)
else:
parameters_buffers_constants.append(relax_var)

return parameters_buffers_constants, user_inputs

########## Unary Ops ##########

def _hardtanh(self, node: fx.Node) -> relax.Expr:
Expand Down Expand Up @@ -178,6 +147,8 @@ def _slice(self, node: fx.Node) -> relax.Var:
stride = [node.args[4] if len(node.args) > 4 else 1]
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))

########## Others ##########

def create_convert_map(
self,
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
Expand Down Expand Up @@ -293,6 +264,37 @@ def create_convert_map(
"getitem": self._getitem,
}

def create_input_vars(
self, exported_program: torch.export.ExportedProgram
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
for spec in exported_program.graph_signature.input_specs:
name_hint = spec.arg.name
if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
shape = exported_program.tensor_constants[spec.target].shape
torch_dtype = exported_program.tensor_constants[spec.target].dtype
elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
if node.name == name_hint:
shape = node.meta["tensor_meta"].shape
torch_dtype = node.meta["tensor_meta"].dtype
break
else:
# PARAMETER or BUFFER
shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype

dtype = self._convert_data_type(torch_dtype)
relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
user_inputs[name_hint] = relax_var
else:
parameters_buffers_constants[name_hint] = relax_var

return parameters_buffers_constants, user_inputs

def from_exported_program(
self,
exported_program: torch.export.ExportedProgram,
Expand All @@ -305,7 +307,8 @@ def from_exported_program(

# Create input variables.
parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
inputs_vars = parameter_buffer_constant_vars + user_input_vars
inputs_vars = user_input_vars.copy()
inputs_vars.update(parameter_buffer_constant_vars)

# Initialize the block builder with a function and a dataflow block.
self.block_builder = relax.BlockBuilder()
Expand All @@ -314,7 +317,7 @@ def from_exported_program(

nodes: List[fx.Node] = exported_program.graph.nodes
with self.block_builder.function(
name=func_name, params=inputs_vars.copy(), attrs=func_attrs
name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
):
output = None
with self.block_builder.dataflow():
Expand All @@ -325,7 +328,7 @@ def from_exported_program(
# Ignore sym input
continue

self.env[node] = inputs_vars.pop(0)
self.env[node] = inputs_vars[node.name]
elif node.op == "output":
args = self.retrieve_args(node)
assert len(args) == 1
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -3550,9 +3550,9 @@ def forward(self, input):
class expected1:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
conv_bias: R.Tensor((6,), dtype="float32"),
input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
R.func_attr({"num_input": 1})
# block 0
Expand Down Expand Up @@ -3586,7 +3586,7 @@ def main(
params = params["main"]

assert len(params) == len(func.params) - 1
for param_var, param_ndarray in zip(func.params[:-1], params):
for param_var, param_ndarray in zip(func.params[1:], params):
assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape
assert param_var.struct_info.dtype == param_ndarray.dtype

Expand Down
Loading