From 6562eee0f0880c7539c85d06cd0bcbd8785556d3 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 17 Aug 2020 17:29:28 -0700 Subject: [PATCH 1/4] fix lint --- python/tvm/relay/op/dyn/image/_image.py | 37 ++++++++++++------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index fa528e9a202d..1f7c0ea69f5a 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -40,24 +40,14 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("dyn.image.resize") - -@script -def _NCHW_resize_shape_func(dshape, size, ndim): - out = output_tensor((ndim, ), "int64") - for i in const_range(ndim): - out[i] = int64(dshape[i]) - out[2] = int64(size[0]) - out[3] = int64(size[1]) - return out - - @script -def _NHWC_resize_shape_func(dshape, size, ndim): +def _resize_shape_func(dshape, size, ndim, height_axis, width_axis, channel_axis): out = output_tensor((ndim, ), "int64") for i in const_range(ndim): out[i] = int64(dshape[i]) - out[1] = int64(size[0]) - out[2] = int64(size[1]) + out[height_axis] = int64(size[0]) + out[width_axis] = int64(size[1]) + out[channel_axis] = int64(dshape[channel_axis]) return out @@ -67,10 +57,19 @@ def resize_shape_func(attrs, inputs, _): Shape function for dyn.image.resize op. """ layout = attrs.layout - if layout == 'NHWC': - out = [_NHWC_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))] - elif (layout == 'NCHW') or nchw_pack_layout(layout) or nchw_xc_layout(layout): - out = [_NCHW_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))] + if nchw_pack_layout(layout) or nchw_xc_layout(layout): + out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), + convert(2), convert(3), convert(1))] else: - raise ValueError("Resize Unsupported Layout", layout) + height_axis = width_axis = channel_axis = 1 + for i, letter in enumerate(layout): + if letter == "H": + height_axis = i + if letter == "W": + width_axis = i + if letter == "C": + channel_axis = i + out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), + convert(height_axis), convert(widht_axis), + convert(channel_axis))] return out From be45b3adebdb5b009b0c8f4cdce8632b43385264 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 17 Aug 2020 17:42:51 -0700 Subject: [PATCH 2/4] fix typo --- python/tvm/relay/op/dyn/image/_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index 1f7c0ea69f5a..8ebd48b7dca6 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -70,6 +70,6 @@ def resize_shape_func(attrs, inputs, _): if letter == "C": channel_axis = i out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), - convert(height_axis), convert(widht_axis), + convert(height_axis), convert(width_axis), convert(channel_axis))] return out From 273ed90cdd475fd7baa8ac74a6250f76b917c153 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 24 Aug 2020 18:05:31 -0700 Subject: [PATCH 3/4] remove channel_axis from resize shape func --- python/tvm/relay/op/dyn/image/_image.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index 8ebd48b7dca6..73aafc5abd62 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -41,16 +41,14 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("dyn.image.resize") @script -def _resize_shape_func(dshape, size, ndim, height_axis, width_axis, channel_axis): +def _resize_shape_func(dshape, size, ndim, height_axis, width_axis): out = output_tensor((ndim, ), "int64") for i in const_range(ndim): out[i] = int64(dshape[i]) out[height_axis] = int64(size[0]) out[width_axis] = int64(size[1]) - out[channel_axis] = int64(dshape[channel_axis]) return out - @reg.register_shape_func("dyn.image.resize", True) def resize_shape_func(attrs, inputs, _): """ @@ -59,7 +57,7 @@ def resize_shape_func(attrs, inputs, _): layout = attrs.layout if nchw_pack_layout(layout) or nchw_xc_layout(layout): out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), - convert(2), convert(3), convert(1))] + convert(2), convert(3))] else: height_axis = width_axis = channel_axis = 1 for i, letter in enumerate(layout): @@ -67,9 +65,6 @@ def resize_shape_func(attrs, inputs, _): height_axis = i if letter == "W": width_axis = i - if letter == "C": - channel_axis = i out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), - convert(height_axis), convert(width_axis), - convert(channel_axis))] + convert(height_axis), convert(width_axis))] return out From a8d4cf013d2a58bc85911eee7694529fc153fc7d Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 25 Aug 2020 09:49:37 -0700 Subject: [PATCH 4/4] fix lint --- python/tvm/relay/op/dyn/image/_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index 73aafc5abd62..2d36708af04f 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -59,7 +59,7 @@ def resize_shape_func(attrs, inputs, _): out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), convert(2), convert(3))] else: - height_axis = width_axis = channel_axis = 1 + height_axis = width_axis = 1 for i, letter in enumerate(layout): if letter == "H": height_axis = i