Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Upsampling op support and enable registering a user defined op conversion map #4961

Merged
merged 10 commits into from
Mar 1, 2020
80 changes: 72 additions & 8 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
import itertools
from packaging import version

import numpy as np

Expand All @@ -31,6 +30,7 @@
from .. import op as _op
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value

__all__ = ["from_pytorch"]

Expand Down Expand Up @@ -614,6 +614,61 @@ def _impl(inputs, input_types):
return _op.tensor.sqrt(data)
return _impl

def _floor():
def _impl(inputs, input_types):
data = inputs[0]
return _op.floor(data)
return _impl

def _to():
def _impl(inputs, input_types):
data = inputs[0]
if inputs[3] in ["cpu", "cuda"]:
return data
# special handling for aten::to(data, 6, _, _, _) case
# 6 means dtype = float
# this happens when converting upsampling with scale factor
cast_func = {
6: float,
3: int,
}
cast_func_expr = {
6: lambda x: _op.cast(x, "float32"),
3: lambda x: _op.cast(x, "int32"),
}
if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
masahi marked this conversation as resolved.
Show resolved Hide resolved
return cast_func[inputs[1]](data)
elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
return cast_func_expr[inputs[1]](data)
return data

return _impl

def _upsample(method):
def _impl(inputs, input_types):
if isinstance(inputs[1], _expr.Var):
out_size = _infer_shape(inputs[1])
elif isinstance(inputs[1], list):
infer_res = [_infer_value(size, {}) for size in inputs[1]]
out_size = [np.asscalar(res.asnumpy().astype(np.int))
for res in infer_res]

data = inputs[0]

if len(inputs) > 2:
align_corners = inputs[2]
else:
align_corners = False

if align_corners:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"

return _op.image.resize(data, out_size, "NCHW", method, coord_trans)

return _impl

# Helper functions for operator implementation

def _convert_data_type(input_type):
Expand Down Expand Up @@ -686,7 +741,7 @@ def _convert_elemwise_input(data, input_type):
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::to" : _identity(),
"aten::to" : _to(),
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
Expand Down Expand Up @@ -729,15 +784,18 @@ def _convert_elemwise_input(data, input_type):
"aten::permute" : _transpose(),
"aten::sum" : _reduce("sum"),
"aten::prod" : _reduce("prod"),
"aten::sqrt" : _sqrt()
"aten::sqrt" : _sqrt(),
'aten::floor' : _floor(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
}


def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch
if version.parse(torch.__version__) >= version.parse("1.4.0"):
torch._C._jit_pass_inline(graph)
torch._C._jit_pass_inline(graph)


def _is_int_seq(seq):
Expand Down Expand Up @@ -985,8 +1043,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name):

def get_all_op_names(graph):
    """ Return all operator names in the input graph """
    # The diff residue kept both the old two-line body and the new
    # one-liner (the second return was unreachable); keep a single
    # set comprehension. node.kind() is the op name, e.g. "aten::relu".
    return {node.kind() for node in graph.nodes()}


def get_graph_input_names(script_module):
Expand All @@ -997,7 +1054,7 @@ def get_graph_input_names(script_module):
return ir_inputs[1:] # remove self at the 0th arg


def from_pytorch(script_module, input_shapes):
def from_pytorch(script_module, input_shapes, custom_convert_map=None):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.

Expand All @@ -1011,6 +1068,9 @@ def from_pytorch(script_module, input_shapes):
Graph level input shape dictionary
The keys should be the same one returned by get_graph_input_names(...) above

custom_convert_map: Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above

Returns
-------
mod : tvm.relay.Module
Expand All @@ -1021,6 +1081,10 @@ def from_pytorch(script_module, input_shapes):
"""
graph = script_module.graph.copy()
_run_jit_passes(graph)

if custom_convert_map:
_convert_map.update(custom_convert_map)

op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)

Expand Down
Loading