Skip to content

Commit

Permalink
Register Shape Func for Some Operators to Handle Dynamic Shapes (#5955)
Browse files Browse the repository at this point in the history
* Register Shape Func for Floor Operator

Register the shape function for `floor` operator. Otherwise, a bug will happen when input of floor is any.

* Register shape func for log

* add shape function for crop_and_size

* change import location

* add mirror_pad shape function

* add test cases for crop_and_resize and mirror_pad shape funcs

* support different layout

* fix pylint error

* fix pylint error

* add test for nchw layout

* block nchw test

* test for nchw

* use tvm.testing.assert_allclose instead

Co-authored-by: lisiyuan <lisiyuan@nucflow>
  • Loading branch information
lsy643 and lisiyuan authored Jul 23, 2020
1 parent 9d34eaa commit fe76196
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,5 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("fast_exp", False, elemwise_shape_func)
register_shape_func("fast_tanh", False, elemwise_shape_func)
register_shape_func("fast_erf", False, elemwise_shape_func)
register_shape_func("floor", False, elemwise_shape_func)
register_shape_func("log", False, elemwise_shape_func)
31 changes: 31 additions & 0 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
"""Backend compiler related feature registration"""
from __future__ import absolute_import

from tvm.te.hybrid import script
from tvm.runtime import convert

import topi
from topi.util import get_const_tuple
from .. import op as reg
Expand Down Expand Up @@ -64,6 +67,34 @@ def compute_crop_and_resize(attrs, inputs, out_type):

reg.register_injective_schedule("image.crop_and_resize")

@script
def _crop_and_resize_func(image_shape, boxes_shape, crop_size,
height_axis, width_axis, channel_axis):
out = output_tensor((4,), "int64")
out[0] = boxes_shape[0]
out[height_axis] = int64(crop_size[0])
out[width_axis] = int64(crop_size[1])
out[channel_axis] = image_shape[channel_axis]
return out

@reg.register_shape_func("image.crop_and_resize", False)
def crop_and_resize_func(attrs, inputs, _):
"""
Shape function for crop_and_resize op.
"""
layout = attrs.layout
height_axis = width_axis = channel_axis = 1
for i, letter in enumerate(layout):
if letter == "H":
height_axis = i
if letter == "W":
width_axis = i
if letter == "C":
channel_axis = i
crop_size = get_const_tuple(attrs.crop_size)
return [_crop_and_resize_func(inputs[0], inputs[1], convert(crop_size),
convert(height_axis), convert(width_axis), convert(channel_axis))]


# dilation2d
reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy)
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,19 @@ def compute_mirror_pad(attrs, inputs, out_dtype):
reg.register_broadcast_schedule("nn.mirror_pad")


@script
def _mirror_pad_func(data_shape, pad_width):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(data_shape.shape[0]):
out[i] = data_shape[i] + int64(pad_width[i][0]) + int64(pad_width[i][1])
return out

@reg.register_shape_func("nn.mirror_pad", False)
def mirror_pad_func(attrs, inputs, _):
pad_width_tuple = [get_const_tuple(p) for p in attrs.pad_width]
return [_mirror_pad_func(inputs[0], convert(pad_width_tuple))]


# conv2d_winograd related operators
reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
strategy.conv2d_winograd_without_weight_transfrom_strategy)
Expand Down
61 changes: 61 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,64 @@ def test_mixed_input_type():
assert result.asnumpy().shape == ref_out_shape, \
"Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))

def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
layout, static_boxes, static_box_indices_shape, ref_out_shape):
mod = tvm.IRModule()
dtype = "float32"
indices_dtype = "int32"
data = relay.var('data', shape=data_shape, dtype=dtype)
boxes = relay.var('boxes', shape=boxes_shape, dtype=dtype)
box_indices = relay.var('box_indices', shape=box_indices_shape, dtype=indices_dtype)
y = relay.image.crop_and_resize(data, boxes, box_indices, crop_size, layout)
mod["main"] = relay.Function([data, boxes, box_indices], y)
data_np = np.random.uniform(size=data_shape).astype(dtype)
boxes_np = np.random.uniform(size=static_boxes).astype(dtype)
box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np, boxes_np, box_indices_np)
tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape)

def test_any_crop_and_resize():
verify_any_crop_and_resize(
data_shape=(1, 234, 234, 256),
boxes_shape=(relay.Any(), 4),
box_indices_shape=(relay.Any(),),
crop_size=(14, 14),
layout='NHWC',
static_boxes=(128, 4),
static_box_indices_shape=(128,),
ref_out_shape=(128, 14, 14, 256))
verify_any_crop_and_resize(
data_shape=(1, 256, 234, 234),
boxes_shape=(relay.Any(), 4),
box_indices_shape=(relay.Any(),),
crop_size=(14, 14),
layout='NCHW',
static_boxes=(128, 4),
static_box_indices_shape=(128,),
ref_out_shape=(128, 256, 14, 14)
)

def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shape):
mod = tvm.IRModule()
dtype = "float32"
data = relay.var('data', shape=data_shape, dtype=dtype)
y = relay.nn.mirror_pad(data, pad_width)
mod["main"] = relay.Function([data], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np)
tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape)

def test_any_mirror_pad():
verify_any_mirror_pad(
data_shape=(1, 256, 232, 232),
pad_width=((0, 0), (0, 0), (1, 1), (1, 1)),
static_data_shape=(1, 256, 232, 232),
ref_out_shape=(1, 256, 234, 234))

if __name__ == "__main__":
test_any_full()
test_any_full_like()
Expand Down Expand Up @@ -850,3 +908,6 @@ def test_mixed_input_type():
test_recursive_concat_with_wrong_annotation()
test_tuple_get_item()
test_mixed_input_type()
test_any_crop_and_resize()
test_any_mirror_pad()

0 comments on commit fe76196

Please sign in to comment.