Skip to content

Commit

Permalink
[TOPI, Relay] Add half_pixel option to Resize op (apache#4610)
Browse files Browse the repository at this point in the history
* add onnx resize converter

* update frontends

* updating topi

* adding onnx resize tests

* fixed NHWC test by casting size dtype to int32

* fix tests

* fix lint

* update existing test cases

* fix tensorflow frontend

* fix lint

* remove NHWC stuff

* update topi resize test for half_pixel

* update doc

* fix doc

* remove onnx resize bits
  • Loading branch information
masahi authored and alexwong committed Feb 28, 2020
1 parent 03c1e45 commit 8bb4314
Show file tree
Hide file tree
Showing 15 changed files with 94 additions and 74 deletions.
9 changes: 6 additions & 3 deletions include/tvm/relay/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
Array<IndexExpr> size;
std::string layout;
std::string method;
bool align_corners;
std::string coordinate_transformation_mode;
DataType out_dtype;

TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") {
Expand All @@ -52,8 +52,11 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
"nearest_neighbor - Nearest Neighbor"
"bilinear - Bilinear Interpolation"
"bicubic - Bicubic Interpolation");
TVM_ATTR_FIELD(align_corners).set_default(true)
.describe("Should be true to preserve the values at the corner pixels");
TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel")
.describe("Describes how to transform the coordinate in the resized tensor "
"to the coordinate in the original tensor. "
"Refer to the ONNX Resize operator specification for details. "
"Available options are half_pixel, align_corners and asymmetric.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type.");
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def _mx_resize(inputs, attrs):
if scale_width is not None:
width = (scale_width * shape[3]).astype("int32")
size = (height, width)
return _op.image.resize(inputs[0], size, align_corners=True)
return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners")

def _mx_roi_pooling(inputs, attrs):
new_attrs = {}
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,7 @@ class Or(Elemwise):
def _impl_v7(cls, inputs, attr, params):
return _op.logical_or(inputs[0], inputs[1])


class Expand(OnnxOpConverter):
""" Operator converter for Expand.
"""
Expand Down Expand Up @@ -1138,6 +1139,7 @@ def expand_shape(in_shape, shape):
shape = expand_shape(in_shape, shape)
return _op.broadcast_to(inputs[0], shape=tuple(shape))


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -1263,7 +1265,7 @@ def _get_convert_map(opset):
'Tile': Tile.get_converter(opset),
'Erf': Erf.get_converter(opset),
'Where': Where.get_converter(opset),
'Or': Or.get_converter(opset)
'Or': Or.get_converter(opset),
}


Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def _impl(inputs, attr, params):
raise tvm.error.OpAttributeUnImplemented(
'Attribute method=nearest is not supported')
else:
attrs['align_corners'] = True
attrs['coordinate_transformation_mode'] = 'align_corners'
attrs['method'] = 'bilinear'

out = None
Expand Down Expand Up @@ -632,6 +632,10 @@ def _impl(inputs, attr, params):
inputs.pop(1)
# NHWC
attr['layout'] = 'NHWC'
if attr.pop('align_corners') is True:
attr['coordinate_transformation_mode'] = 'align_corners'
else:
attr['coordinate_transformation_mode'] = 'asymmetric'

# Ignore the new attributes from TF2.0, for now.
return AttrCvt(op_name='resize',
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,9 @@ def _convert_resize(self, method, op):
align_corners = resize_options.AlignCorners()

# Use layout NHWC
out = _op.image.resize(in_expr, target_size, "NHWC", method, align_corners)
coord_trans = "align_corners" if align_corners else "asymmetric"
out = _op.image.resize(in_expr, target_size, "NHWC", method,
coordinate_transformation_mode=coord_trans)
return out

def convert_resize_bilinear(self, op):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ def compute_resize(attrs, inputs, out_type, target):
size = attrs.size
layout = attrs.layout
method = attrs.method
align_corners = attrs.align_corners
coord_trans = attrs.coordinate_transformation_mode
out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, align_corners, out_dtype)]
return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
11 changes: 7 additions & 4 deletions python/tvm/relay/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def resize(data,
size,
layout="NCHW",
method="bilinear",
align_corners=True,
coordinate_transformation_mode="half_pixel",
out_dtype=None):
"""Image resize operator.
Expand All @@ -48,8 +48,11 @@ def resize(data,
method : str, optional
Scale method to use [nearest_neighbor, bilinear, bicubic].
align_corners : int, optional
Should be true to preserve the values at the corner pixels
coordinate_transformation_mode : string, optional
Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
[half_pixel, align_corners, asymmetric]
out_dtype : str, optional
Type to return. If left None returns the same type as input.
Expand All @@ -59,4 +62,4 @@ def resize(data,
result: relay.Expr
The resized result.
"""
return _make.resize(data, size, layout, method, align_corners, out_dtype)
return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)
4 changes: 2 additions & 2 deletions src/relay/op/image/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ Expr MakeResize(Expr data,
Array<IndexExpr> size,
std::string layout,
std::string method,
bool align_corners,
std::string coordinate_transformation_mode,
DataType out_dtype) {
auto attrs = make_object<ResizeAttrs>();
attrs->size = std::move(size);
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->align_corners = align_corners;
attrs->coordinate_transformation_mode = coordinate_transformation_mode;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("image.resize");
return CallNode::make(op, {data}, Attrs(attrs), {});
Expand Down
17 changes: 0 additions & 17 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,6 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)


def verify_super_resolution_example():
verify_onnx_forward_impl(
super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))


def verify_squeezenet1_1():
verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000))


def verify_lenet():
verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))


def verify_resnet18():
verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))


def test_reshape():
in_shape = (4, 3, 3, 4)
ref_shape = (6, 2, 4, 3)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_resize_infer_type():
assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")

x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", True)
z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners")
assert "size=" in z.astext()
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
Expand All @@ -57,7 +57,7 @@ def verify_resize(dshape, scale, method, layout):
else:
ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout)
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.image.resize(x, size, layout, method, True)
z = relay.image.resize(x, size, layout, method, "align_corners")
assert "size=" in z.astext()
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
Expand Down
35 changes: 25 additions & 10 deletions topi/python/topi/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from .. import tag


def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None):
def resize(data, size, layout="NCHW", method="bilinear",
coordinate_transformation_mode="half_pixel", out_dtype=None):
"""Perform resize operation on the data.
Parameters
Expand All @@ -37,8 +38,11 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out
layout: string, optional
"NCHW", "NHWC", or "NCHWc".
align_corners: Boolean, optional
To preserve the values at the corner pixels.
coordinate_transformation_mode: string, optional
Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
Available options are "half_pixel", "align_corners" and "asymmetric".
method: {"bilinear", "nearest_neighbor", "bicubic"}
Method to be used for resizing.
Expand Down Expand Up @@ -66,12 +70,15 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out
in_n, in_c, in_h, in_w, in_cc = data.shape
output_shape = [in_n, in_c, size[0], size[1], in_cc]

if align_corners:
if coordinate_transformation_mode == "align_corners":
y_ratio = (in_h - 1).astype('float') / (size[0] - 1)
x_ratio = (in_w - 1).astype('float') / (size[1] - 1)
else:
elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
y_ratio = (in_h).astype('float') / (size[0])
x_ratio = (in_w).astype('float') / (size[1])
else:
raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
coordinate_transformation_mode))

def _get_pixel(n, c, y, x, cc):
y = tvm.max(tvm.min(y, in_h - 1), 0)
Expand Down Expand Up @@ -109,7 +116,7 @@ def _nearest_neighbor(*indices):
in_y = y_ratio * y
in_x = x_ratio * x

if align_corners:
if coordinate_transformation_mode == "align_corners":
yint = tvm.round(in_y).astype('int32')
xint = tvm.round(in_x).astype('int32')
else:
Expand All @@ -127,8 +134,12 @@ def _lerp(A, B, t):
def _bilinear(*indices):
n, c, y, x, cc = _get_indices(*indices)

in_y = y_ratio * y
in_x = x_ratio * x
if coordinate_transformation_mode == "half_pixel":
in_y = y_ratio * (y + 0.5) - 0.5
in_x = x_ratio * (x + 0.5) - 0.5
else:
in_y = y_ratio * y
in_x = x_ratio * x

xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x)
Expand Down Expand Up @@ -158,8 +169,12 @@ def _cubic_kernel(A, B, C, D, t):
def _bicubic(*indices):
n, c, y, x, cc = _get_indices(*indices)

in_y = y_ratio * y
in_x = x_ratio * x
if coordinate_transformation_mode == "half_pixel":
in_y = y_ratio * (y + 0.5) - 0.5
in_x = x_ratio * (x + 0.5) - 0.5
else:
in_y = y_ratio * y
in_x = x_ratio * x

xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x)
Expand Down
3 changes: 2 additions & 1 deletion topi/python/topi/nn/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',

else:
raise ValueError("not support this layout {} yet".format(layout))
coord_trans = "align_corners" if align_corners else "asymmetric"
return topi.image.resize(data, out_shape, layout=layout,
method=method, align_corners=align_corners)
method=method, coordinate_transformation_mode=coord_trans)


def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
Expand Down
43 changes: 24 additions & 19 deletions topi/python/topi/testing/bilinear_resize_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import math
import numpy as np

def bilinear_resize_python(image, out_size, layout, align_corners=True):
def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"):
""" Bilinear scaling using python"""
(new_h, new_w) = out_size

Expand All @@ -30,32 +30,37 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True):
(batch, channel, h, w) = image.shape
scaled_image = np.ones((batch, channel, new_h, new_w))

if align_corners:
if coordinate_transformation_mode == "align_corners":
height_scale = np.float32(h-1) / np.float32(out_size[0]-1)
width_scale = np.float32(w-1) / np.float32(out_size[1]-1)
else:
height_scale = np.float32(h) / np.float32(out_size[0])
width_scale = np.float32(w) / np.float32(out_size[1])

def _lerp(A, B, t):
return A * (1.0 - t) + B * t

for b in range(batch):
for i in range(channel):
for j in range(new_h):
for k in range(new_w):
in_y = j * height_scale
y0 = math.floor(in_y)
y1 = min(math.ceil(in_y), h - 1)
y_lerp = in_y - y0

y0 = int(y0)
y1 = int(y1)

in_x = k * width_scale
x0 = math.floor(in_x)
x1 = min(math.ceil(in_x), w - 1)
x_lerp = in_x - x0
if coordinate_transformation_mode == "half_pixel":
in_y = (j + 0.5) * height_scale - 0.5
else:
in_y = j * height_scale
y0 = int(math.floor(in_y))
y1 = max(min(y0 + 1, h - 1), 0)
y0 = max(y0, 0)
y_lerp = in_y - math.floor(in_y)

x0 = int(x0)
x1 = int(x1)
if coordinate_transformation_mode == "half_pixel":
in_x = (k + 0.5) * width_scale - 0.5
else:
in_x = k * width_scale
x0 = int(math.floor(in_x))
x1 = max(min(x0 + 1, w - 1), 0)
x0 = max(x0, 0)
x_lerp = in_x - math.floor(in_x)

if layout == 'NHWC':
A = image[b][y0][x0][i]
Expand All @@ -68,10 +73,10 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True):
C = image[b][i][y1][x0]
D = image[b][i][y1][x1]

top = A + (B - A) * x_lerp
bottom = C + (D - C) * x_lerp
top = _lerp(A, B, x_lerp)
bottom = _lerp(C, D, x_lerp)

pixel = np.float32(top + (bottom - top) * y_lerp)
pixel = np.float32(_lerp(top, bottom, y_lerp))

if layout == 'NHWC':
scaled_image[b][j][k][i] = pixel
Expand Down
20 changes: 11 additions & 9 deletions topi/tests/python/test_topi_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

from common import get_all_backend

def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=True, method="bilinear"):
def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
layout='NCHW', coord_trans="align_corners", method="bilinear"):
if layout == 'NCHW':
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32')
dtype = A.dtype
Expand All @@ -37,11 +38,9 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
else:
raise NotImplementedError(
'Layout not supported {} '.format(layout))

B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method)

B = topi.image.resize(A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method)
if method == "bilinear":
b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners)
b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, coord_trans)
else:
scale_h = out_height / in_height
scale_w = out_width / in_width
Expand Down Expand Up @@ -70,14 +69,17 @@ def test_resize():
# Scale NCHW
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW')
# Scale NCHW + Align Corners
verify_resize(6, 32, 64, 64, 20, 20, 'NCHW', True)
verify_resize(6, 32, 64, 64, 20, 20, 'NCHW')
# Scale NHWC
verify_resize(4, 16, 32, 32, 50, 50, "NHWC")
# Scale NHWC + Align Corners
verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True)
verify_resize(6, 32, 64, 64, 20, 20, "NHWC")
# Nearest + Fractional
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False)
verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False)
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', "asymmetric", method="nearest_neighbor")
verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', "asymmetric", method="nearest_neighbor")
# half_pixel
verify_resize(4, 16, 16, 16, 32, 32, 'NCHW', "half_pixel", method="bilinear")
verify_resize(4, 16, 16, 16, 32, 32, 'NHWC', "half_pixel", method="bilinear")


def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, out_height, out_width,
Expand Down
Loading

0 comments on commit 8bb4314

Please sign in to comment.