diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index dd3a0aa0cc65..87ad82d0293f 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -36,7 +36,7 @@ struct ResizeAttrs : public tvm::AttrsNode { Array size; std::string layout; std::string method; - bool align_corners; + std::string coordinate_transformation_mode; DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { @@ -52,8 +52,11 @@ struct ResizeAttrs : public tvm::AttrsNode { "nearest_neighbor - Nearest Neighbor" "bilinear - Bilinear Interpolation" "bicubic - Bicubic Interpolation"); - TVM_ATTR_FIELD(align_corners).set_default(true) - .describe("Should be true to preserve the values at the corner pixels"); + TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") + .describe("Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type."); diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index a1a357883a83..1f85277712aa 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -676,7 +676,7 @@ def _mx_resize(inputs, attrs): if scale_width is not None: width = (scale_width * shape[3]).astype("int32") size = (height, width) - return _op.image.resize(inputs[0], size, align_corners=True) + return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners") def _mx_roi_pooling(inputs, attrs): new_attrs = {} diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c7764db729ee..4809100f3c2c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1091,6 +1091,7 @@ class Or(Elemwise): def _impl_v7(cls, inputs, attr, params): return _op.logical_or(inputs[0], inputs[1]) + class Expand(OnnxOpConverter): """ Operator converter for Expand. """ @@ -1138,6 +1139,44 @@ def expand_shape(in_shape, shape): shape = expand_shape(in_shape, shape) return _op.broadcast_to(inputs[0], shape=tuple(shape)) + +class Resize(OnnxOpConverter): + """Operator converter for Resize + """ + @classmethod + def _impl_v11(cls, inputs, attr, params): + mode = attr.get('mode') + if mode == b'nearest': + method = "nearest_neighbor" + elif mode == b'linear': + method = "bilinear" + else: + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)) + + in_size = np.array(infer_shape(inputs[0])) + scale = infer_value_simulated(inputs[2], params).asnumpy() + if len(inputs) == 4: + assert len(scale) == 0, "One of scale or size should be passed, not both." + size = infer_value_simulated(inputs[3], params).asnumpy().astype(np.int32) + else: + assert len(scale) != 0, "One of scale or size should be passed." + size = (in_size * scale).astype(np.int32) + + coord_trans = attr.get('coordinate_transformation_mode') + if coord_trans in [b'pytorch_half_pixel', b'half_pixel']: + coord_trans = "half_pixel" + elif coord_trans == b'align_corners': + coord_trans = "align_corners" + elif coord_trans == b'asymmetric' or method == "nearest_neighbor": + coord_trans = "asymmetric" + else: + raise tvm.error.OpAttributeInvalid( + 'Unsupported coordinate_transformation_mode: {}'.format(coord_trans)) + layout = "NCHW" # ONNX assumes NCHW layout + out_size = (size[2], size[3]) + return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1263,7 +1302,8 @@ def _get_convert_map(opset): 'Tile': Tile.get_converter(opset), 'Erf': Erf.get_converter(opset), 'Where': Where.get_converter(opset), - 'Or': Or.get_converter(opset) + 'Or': Or.get_converter(opset), + 'Resize': Resize.get_converter(opset), } diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index db037e49bded..8a6e5b778283 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -582,7 +582,7 @@ def _impl(inputs, attr, params): raise tvm.error.OpAttributeUnImplemented( 'Attribute method=nearest is not supported') else: - attrs['align_corners'] = True + attrs['coordinate_transformation_mode'] = 'align_corners' attrs['method'] = 'bilinear' out = None @@ -632,6 +632,10 @@ def _impl(inputs, attr, params): inputs.pop(1) # NHWC attr['layout'] = 'NHWC' + if attr.pop('align_corners') is True: + attr['coordinate_transformation_mode'] = 'align_corners' + else: + attr['coordinate_transformation_mode'] = 'asymmetric' # Ignore the new attributes from TF2.0, for now. return AttrCvt(op_name='resize', diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index be18bf622196..e2e01e545340 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -330,7 +330,9 @@ def _convert_resize(self, method, op): align_corners = resize_options.AlignCorners() # Use layout NHWC - out = _op.image.resize(in_expr, target_size, "NHWC", method, align_corners) + coord_trans = "align_corners" if align_corners else "asymmetric" + out = _op.image.resize(in_expr, target_size, "NHWC", method, + coordinate_transformation_mode=coord_trans) return out def convert_resize_bilinear(self, op): diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index fcebfd8c9613..776435ada497 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -31,6 +31,6 @@ def compute_resize(attrs, inputs, out_type, target): size = attrs.size layout = attrs.layout method = attrs.method - align_corners = attrs.align_corners + coord_trans = attrs.coordinate_transformation_mode out_dtype = attrs.out_dtype - return [topi.image.resize(inputs[0], size, layout, method, align_corners, out_dtype)] + return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)] diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index c54e438dce51..e0475a06025a 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -22,7 +22,7 @@ def resize(data, size, layout="NCHW", method="bilinear", - align_corners=True, + coordinate_transformation_mode="half_pixel", out_dtype=None): """Image resize operator. @@ -48,8 +48,11 @@ def resize(data, method : str, optional Scale method to used [nearest_neighbor, bilinear, bicubic]. - align_corners : int, optional - Should be true to preserve the values at the corner pixels + coordinate_transformation_mode : string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + [half_pixel, align_corners, asymmetric] out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -59,4 +62,4 @@ def resize(data, result: relay.Expr The resized result. """ - return _make.resize(data, size, layout, method, align_corners, out_dtype) + return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype) diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index f6329f7af709..baab0ead692f 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -71,13 +71,13 @@ Expr MakeResize(Expr data, Array size, std::string layout, std::string method, - bool align_corners, + std::string coordinate_transformation_mode, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); - attrs->align_corners = align_corners; + attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a35ebd23ae0a..d915acb439c6 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -98,23 +98,6 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) -def verify_super_resolution_example(): - verify_onnx_forward_impl( - super_resolution, (1, 1, 224, 224), (1, 1, 672, 672)) - - -def verify_squeezenet1_1(): - verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000)) - - -def verify_lenet(): - verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10)) - - -def verify_resnet18(): - verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000)) - - def test_reshape(): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) @@ -1844,6 +1827,62 @@ def forward(self, input): relay.frontend.from_onnx(onnx_model, {'0': input_size}) +def test_resize(): + def make_constant_node(name, data_type, dims, vals): + return helper.make_node('Constant', + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, + data_type=data_type, + dims=dims, + vals=vals)) + def verify(ishape, oshape, scales, mode, coord_trans): + nodes = [ + make_constant_node('roi', onnx.TensorProto.FLOAT, (0,), []), + make_constant_node('scales', onnx.TensorProto.FLOAT, (len(scales),), scales) + ] + input_names = ['X', 'roi', 'scales'] + if oshape != []: + nodes.append(make_constant_node('sizes', onnx.TensorProto.INT64, (len(oshape),), oshape)) + input_names.append('sizes') + nodes.append(helper.make_node( + 'Resize', + inputs=input_names, + outputs=['Y'], + mode=mode, + coordinate_transformation_mode=coord_trans + )) + + if oshape == []: + oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)] + + graph = helper.make_graph(nodes, + "resize_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)]) + + model = helper.make_model(graph, producer_name='resize_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=ishape).astype('float32') + onnx_out = get_onnxruntime_output(model, x, 'float32') + tvm_out = get_tvm_output(model, x, target, ctx, oshape, 'float32', opset=11) + + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) + + # # NCHW and upsampling + verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") + verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "align_corners") + verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "half_pixel") + # NCHW and downsampling + verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "asymmetric") + verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "align_corners") + verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "half_pixel") + # scales are specified instead of sizes + verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric") + verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel") + + if __name__ == '__main__': test_flatten() test_reshape() @@ -1901,3 +1940,4 @@ def forward(self, input): test_conv() test_convtranspose() test_unsqueeze_constant() + test_resize() diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 84e9f55d67e7..2f2e8523161c 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -39,7 +39,7 @@ def test_resize_infer_type(): assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) - z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", True) + z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") @@ -57,7 +57,7 @@ def verify_resize(dshape, scale, method, layout): else: ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout) x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.image.resize(x, size, layout, method, True) + z = relay.image.resize(x, size, layout, method, "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") diff --git a/topi/python/topi/image/resize.py b/topi/python/topi/image/resize.py index 27bea9434348..004e04a604e5 100644 --- a/topi/python/topi/image/resize.py +++ b/topi/python/topi/image/resize.py @@ -21,7 +21,8 @@ from .. import tag -def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None): +def resize(data, size, layout="NCHW", method="bilinear", + coordinate_transformation_mode="half_pixel", out_dtype=None): """Perform resize operation on the data. Parameters @@ -37,8 +38,11 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out layout: string, optional "NCHW", "NHWC", or "NCHWc". - align_corners: Boolean, optional - To preserve the values at the corner pixels. + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". method: {"bilinear", "nearest_neighbor", "bicubic"} Method to be used for resizing. @@ -66,12 +70,15 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out in_n, in_c, in_h, in_w, in_cc = data.shape output_shape = [in_n, in_c, size[0], size[1], in_cc] - if align_corners: + if coordinate_transformation_mode == "align_corners": y_ratio = (in_h - 1).astype('float') / (size[0] - 1) x_ratio = (in_w - 1).astype('float') / (size[1] - 1) - else: + elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: y_ratio = (in_h).astype('float') / (size[0]) x_ratio = (in_w).astype('float') / (size[1]) + else: + raise ValueError("Unsupported coordinate_transformation_mode: {}".format( + coordinate_transformation_mode)) def _get_pixel(n, c, y, x, cc): y = tvm.max(tvm.min(y, in_h - 1), 0) @@ -109,7 +116,7 @@ def _nearest_neighbor(*indices): in_y = y_ratio * y in_x = x_ratio * x - if align_corners: + if coordinate_transformation_mode == "align_corners": yint = tvm.round(in_y).astype('int32') xint = tvm.round(in_x).astype('int32') else: @@ -127,8 +134,12 @@ def _lerp(A, B, t): def _bilinear(*indices): n, c, y, x, cc = _get_indices(*indices) - in_y = y_ratio * y - in_x = x_ratio * x + if coordinate_transformation_mode == "half_pixel": + in_y = y_ratio * (y + 0.5) - 0.5 + in_x = x_ratio * (x + 0.5) - 0.5 + else: + in_y = y_ratio * y + in_x = x_ratio * x xint = tvm.floor(in_x).astype('int32') xfract = in_x - tvm.floor(in_x) @@ -158,8 +169,12 @@ def _cubic_kernel(A, B, C, D, t): def _bicubic(*indices): n, c, y, x, cc = _get_indices(*indices) - in_y = y_ratio * y - in_x = x_ratio * x + if coordinate_transformation_mode == "half_pixel": + in_y = y_ratio * (y + 0.5) - 0.5 + in_x = x_ratio * (x + 0.5) - 0.5 + else: + in_y = y_ratio * y + in_x = x_ratio * x xint = tvm.floor(in_x).astype('int32') xfract = in_x - tvm.floor(in_x) diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index fe63e474f2bf..c816bbb3c04e 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -61,8 +61,9 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', else: raise ValueError("not support this layout {} yet".format(layout)) + coord_trans = "align_corners" if align_corners else "asymmetric" return topi.image.resize(data, out_shape, layout=layout, - method=method, align_corners=align_corners) + method=method, coordinate_transformation_mode=coord_trans) def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor', diff --git a/topi/python/topi/testing/bilinear_resize_python.py b/topi/python/topi/testing/bilinear_resize_python.py index 86dd450a88e2..d324e2900c4f 100644 --- a/topi/python/topi/testing/bilinear_resize_python.py +++ b/topi/python/topi/testing/bilinear_resize_python.py @@ -19,7 +19,7 @@ import math import numpy as np -def bilinear_resize_python(image, out_size, layout, align_corners=True): +def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"): """ Bilinear scaling using python""" (new_h, new_w) = out_size @@ -30,32 +30,37 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True): (batch, channel, h, w) = image.shape scaled_image = np.ones((batch, channel, new_h, new_w)) - if align_corners: + if coordinate_transformation_mode == "align_corners": height_scale = np.float32(h-1) / np.float32(out_size[0]-1) width_scale = np.float32(w-1) / np.float32(out_size[1]-1) else: height_scale = np.float32(h) / np.float32(out_size[0]) width_scale = np.float32(w) / np.float32(out_size[1]) + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + for b in range(batch): for i in range(channel): for j in range(new_h): for k in range(new_w): - in_y = j * height_scale - y0 = math.floor(in_y) - y1 = min(math.ceil(in_y), h - 1) - y_lerp = in_y - y0 - - y0 = int(y0) - y1 = int(y1) - - in_x = k * width_scale - x0 = math.floor(in_x) - x1 = min(math.ceil(in_x), w - 1) - x_lerp = in_x - x0 + if coordinate_transformation_mode == "half_pixel": + in_y = (j + 0.5) * height_scale - 0.5 + else: + in_y = j * height_scale + y0 = int(math.floor(in_y)) + y1 = max(min(y0 + 1, h - 1), 0) + y0 = max(y0, 0) + y_lerp = in_y - math.floor(in_y) - x0 = int(x0) - x1 = int(x1) + if coordinate_transformation_mode == "half_pixel": + in_x = (k + 0.5) * width_scale - 0.5 + else: + in_x = k * width_scale + x0 = int(math.floor(in_x)) + x1 = max(min(x0 + 1, w - 1), 0) + x0 = max(x0, 0) + x_lerp = in_x - math.floor(in_x) if layout == 'NHWC': A = image[b][y0][x0][i] @@ -68,10 +73,10 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True): C = image[b][i][y1][x0] D = image[b][i][y1][x1] - top = A + (B - A) * x_lerp - bottom = C + (D - C) * x_lerp + top = _lerp(A, B, x_lerp) + bottom = _lerp(C, D, x_lerp) - pixel = np.float32(top + (bottom - top) * y_lerp) + pixel = np.float32(_lerp(top, bottom, y_lerp)) if layout == 'NHWC': scaled_image[b][j][k][i] = pixel diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_resize.py index 10678a0c2600..206903ff1dc1 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_resize.py @@ -23,7 +23,8 @@ from common import get_all_backend -def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=True, method="bilinear"): +def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, + layout='NCHW', coord_trans="align_corners", method="bilinear"): if layout == 'NCHW': A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32') dtype = A.dtype @@ -37,11 +38,9 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, else: raise NotImplementedError( 'Layout not supported {} '.format(layout)) - - B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method) - + B = topi.image.resize(A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method) if method == "bilinear": - b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) + b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, coord_trans) else: scale_h = out_height / in_height scale_w = out_width / in_width @@ -70,14 +69,17 @@ def test_resize(): # Scale NCHW verify_resize(4, 16, 32, 32, 50, 50, 'NCHW') # Scale NCHW + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, 'NCHW', True) + verify_resize(6, 32, 64, 64, 20, 20, 'NCHW') # Scale NHWC verify_resize(4, 16, 32, 32, 50, 50, "NHWC") # Scale NHWC + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True) + verify_resize(6, 32, 64, 64, 20, 20, "NHWC") # Nearest + Fractional - verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False) - verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False) + verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', "asymmetric", method="nearest_neighbor") + verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', "asymmetric", method="nearest_neighbor") + # half_pixel + verify_resize(4, 16, 16, 16, 32, 32, 'NCHW', "half_pixel", method="bilinear") + verify_resize(4, 16, 16, 16, 32, 32, 'NHWC', "half_pixel", method="bilinear") def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, out_height, out_width, diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index f5b77b1190a6..3aa67a5f78a4 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -43,7 +43,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, if method == "bilinear": out_size = (int(round(in_height*scale_h)), int(round(in_width*scale_w))) - b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False) + b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, "asymmetric") else: b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout)