[Torch, CI] Update to PyTorch 1.7 #6811

Closed · wants to merge 11 commits
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -45,7 +45,7 @@

// NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
ci_lint = "tlcpack/ci-lint:v0.62"
ci_gpu = "tlcpack/ci-gpu:v0.71"
ci_gpu = "tlcpack/ci-gpu:v0.72"
ci_cpu = "tlcpack/ci-cpu:v0.71"
ci_wasm = "tlcpack/ci-wasm:v0.70"
ci_i386 = "tlcpack/ci-i386:v0.71"
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_onnx.sh
@@ -28,4 +28,4 @@ pip3 install onnxruntime==1.0.0
# not expose that in the wheel!!!
pip3 install future

pip3 install torch==1.4.0 torchvision==0.5.0
pip3 install torch==1.7.0 torchvision==0.8.1
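For context, a minimal check (illustrative only, not part of the diff) that the pinned versions resolve as expected once installed in the image:

# Sanity check for the pinned torch / torchvision versions.
import torch
import torchvision

assert torch.__version__.startswith("1.7"), torch.__version__
assert torchvision.__version__.startswith("0.8"), torchvision.__version__
print(torch.__version__, torchvision.__version__)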
60 changes: 35 additions & 25 deletions python/tvm/relay/frontend/pytorch.py
@@ -21,6 +21,7 @@
import itertools
import logging
import sys
import math

import numpy as np

@@ -168,7 +169,6 @@ def _min():

def _unary(name):
def _impl(inputs, input_types):
input_type = input_types[0]
# this is just to ensure tensor input
(data,) = _pytorch_promote_types(inputs[:1], input_types[:1])
return get_relay_op(name)(data)
@@ -1552,7 +1552,7 @@ def _impl(inputs, input_types):
axis = None
keepdims = False
if len(inputs) > 2:
axis = inputs[1]
axis = inputs[1] if len(inputs[1]) > 0 else None
keepdims = bool(inputs[2])

return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims))
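For context on the fix above: with newer torch traces, `aten::norm` called without a `dim` argument appears to arrive here with an empty list in `inputs[1]`, which should mean reducing over all axes, hence the `axis = None` fallback. A rough numpy illustration of the two cases (values are illustrative only):

import numpy as np

x = np.arange(6, dtype="float32").reshape(2, 3)
# inputs[1] == []  -> axis=None, i.e. the full Frobenius norm
print(np.sqrt((x * x).sum(axis=None)))
# inputs[1] == [1] -> per-row norms
print(np.sqrt((x * x).sum(axis=1)))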
@@ -1847,18 +1847,33 @@ def _impl(inputs, input_types):
return _impl


def _upsample(method, prelude):
def _impl(inputs, input_types):
out_size = []
def _get_upsample_out_size(inputs, method):
# This assumes a static shape
out_size = []
if inputs[1] is not None:
for size in inputs[1]:
if not isinstance(size, int):
out_size.append(int(_infer_value(size, {}).asnumpy()))
else:
out_size.append(size)
else:
scale_index = 3 if method in ["bilinear", "trilinear"] else 2
scales = inputs[scale_index]
assert scales is not None, "neither out size nor scale provided"
assert isinstance(scales, list)
ishape = _infer_shape(inputs[0])
for i, scale in enumerate(scales):
out_size.append(int(math.floor(float(ishape[2 + i]) * scale)))

return out_size
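A worked example of the scale-factor branch above, assuming a static NCHW input of shape (1, 3, 8, 10) upsampled with bilinear scales [2.0, 1.5] (the concrete numbers are illustrative):

import math

ishape = (1, 3, 8, 10)
scales = [2.0, 1.5]          # read from inputs[3] for bilinear/trilinear
out_size = [int(math.floor(ishape[2 + i] * s)) for i, s in enumerate(scales)]
print(out_size)              # [16, 15]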


def _upsample(method, prelude):
def _impl(inputs, input_types):
data = inputs[0]
out_size = _get_upsample_out_size(inputs, method)

if len(inputs) > 2:
if len(inputs) > 2 and method == "bilinear":
align_corners = inputs[2]
else:
align_corners = False
@@ -1874,35 +1889,24 @@ def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

if _is_quantized_tensor(data, prelude):
# Torch version > 1.4 changed upsampling API
if is_version_greater_than("1.4.0"):
num_inputs = 7
else:
num_inputs = 5

assert len(inputs) == num_inputs, "Input quant param not found in op inputs"

# input qparams are manually appended by us
assert isinstance(inputs[-2], float)
assert isinstance(inputs[-1], int)
input_scale = _expr.const(inputs[-2])
input_zero_point = _expr.const(inputs[-1])
return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func)

return func(data)

return _impl


def _upsample3d(method):
def _impl(inputs, input_types):
if isinstance(inputs[1], _expr.Var):
out_size = _infer_shape(inputs[1])
elif _is_int_seq(inputs[1]):
out_size = inputs[1]
elif isinstance(inputs[1], list):
infer_res = [_infer_value(size, {}) for size in inputs[1]]
out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]

data = inputs[0]
out_size = _get_upsample_out_size(inputs, method)

if len(inputs) > 2:
if len(inputs) > 2 and method == "trilinear":
align_corners = inputs[2]
else:
align_corners = False
@@ -1983,8 +1987,7 @@ def _impl(inputs, input_types):

def _logical_not():
def _impl(inputs, input_types):
data = inputs[0]

data = _wrap_const(inputs[0])
return _op.logical_not(_op.cast(data, "bool"))

return _impl
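For context on the new `aten::__not__` mapping below: scripting a plain `not` on a bool lowers to that op, and its input can show up as a bare Python constant rather than a Relay expression, which is presumably why the input is now passed through `_wrap_const`. A small hedged example of a graph containing the op:

import torch

@torch.jit.script
def flip(flag: bool) -> bool:
    return not flag

print(flip.graph)  # the lowered graph contains aten::__not__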
@@ -2732,6 +2735,7 @@ def _get_convert_map(prelude, default_dtype):
"aten::empty": _empty(),
"aten::bincount": _bincount(),
"aten::scatter_add": _scatter_add(),
"aten::__not__": _logical_not(),
}
return convert_map

@@ -2798,6 +2802,7 @@ def _report_missing_conversion(op_names, convert_map):
"prim::ListUnpack",
"prim::TupleConstruct",
"prim::TupleUnpack",
"prim::RaiseException",
"prim::If",
"prim::Loop",
]
@@ -2903,6 +2908,8 @@ def _get_operator_nodes(nodes):
ops = []
# Traverse nodes and add to graph
for node in nodes:
if node.outputsSize() == 0:
continue
if node.outputsSize() > 1:
node_name = "_".join(_get_output_names(node))
else:
@@ -3286,6 +3293,9 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, defau
else:
unpacked = _unpack_tuple(inputs[0])
outputs.update(zip(_get_output_names(op_node), unpacked))
elif operator == "prim::prim::RaiseException":
logging.warning("raising exceptions is ignored")
outputs[node_name] = None
elif operator == "prim::If":
if_out = convert_if(op_node, outputs, convert_map, prelude, default_dtype=default_dtype)
outputs[node_name] = if_out
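For context on the `prim::RaiseException` handling and the new `outputsSize() == 0` skip above: a scripted assert lowers to a `prim::If` whose failure branch is a `prim::RaiseException` node with no outputs, so the converter can only skip it and warn. A hedged example of a module that produces such nodes (names are illustrative):

import torch

class Guarded(torch.nn.Module):
    def forward(self, x):
        assert x.dim() == 2, "expected a 2D tensor"
        return torch.relu(x)

scripted = torch.jit.script(Guarded())
print(scripted.graph)  # contains prim::If / prim::RaiseException for the assert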
3 changes: 2 additions & 1 deletion tests/python/frontend/pytorch/qnn_test.py
@@ -367,7 +367,8 @@ def get_imagenet_input():
# disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
# See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
# ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
("googlenet", qgooglenet(pretrained=True), per_channel),
# tracing quantized googlenet broken as of v1.6
# ("googlenet", qgooglenet(pretrained=True), per_channel),
]

results = []
4 changes: 2 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
@@ -2535,7 +2535,7 @@ def test_forward_linspace():

class Linspace1(Module):
def forward(self, *args):
return torch.linspace(5, 10)
return torch.linspace(5, 10, steps=100)

class Linspace2(Module):
def forward(self, *args):
@@ -2559,7 +2559,7 @@ def forward(self, *args):

class Linspace7(Module):
def forward(self, *args):
return torch.linspace(1, 4, dtype=torch.float32)
return torch.linspace(1, 4, steps=100, dtype=torch.float32)

class Linspace8(Module):
def forward(self, *args):
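For context on the linspace test changes above: newer torch releases deprecate calling `torch.linspace` without `steps` (which previously defaulted to 100), so the tests now pass it explicitly. The explicit call is assumed to match the old default behavior:

import torch

a = torch.linspace(5, 10, steps=100)
print(a.shape)          # torch.Size([100])
print(a[0], a[-1])      # tensor(5.), tensor(10.)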
6 changes: 3 additions & 3 deletions tutorials/frontend/deploy_object_detection_pytorch.py
@@ -27,16 +27,16 @@

.. code-block:: bash

pip install torch==1.4.0
pip install torchvision==0.5.0
pip install torch==1.7.0
pip install torchvision==0.8.1

or refer to the official site
https://pytorch.org/get-started/locally/

PyTorch versions should be backward compatible, but each release should be
used with its matching TorchVision version.

Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
be unstable.
"""

6 changes: 3 additions & 3 deletions tutorials/frontend/from_pytorch.py
@@ -28,16 +28,16 @@

.. code-block:: bash

pip install torch==1.4.0
pip install torchvision==0.5.0
pip install torch==1.7.0
pip install torchvision==0.8.1

or refer to the official site
https://pytorch.org/get-started/locally/

PyTorch versions should be backward compatible, but each release should be
used with its matching TorchVision version.

Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
be unstable.
"""
