diff --git a/docs/deploy/arm_compute_lib.rst b/docs/deploy/arm_compute_lib.rst index 26b42ae4a9c3..e3399c57db26 100644 --- a/docs/deploy/arm_compute_lib.rst +++ b/docs/deploy/arm_compute_lib.rst @@ -188,31 +188,50 @@ An example configuration for `test_config.json`: Operator support ---------------- -+--------------+-------------------------------------------------------------------------+ -| Relay Node | Remarks | -+==============+=========================================================================+ -| nn.conv2d | fp32: | -| | Simple: nn.conv2d | -| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu? | -| | | -| | (only groups = 1 supported) | -+--------------+-------------------------------------------------------------------------+ -| qnn.conv2d | uint8: | -| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize | -| | | -| | (only groups = 1 supported) | -+--------------+-------------------------------------------------------------------------+ -| nn.dense | fp32: | -| | Simple: nn.dense | -| | Composite: nn.dense, nn.bias_add? | -+--------------+-------------------------------------------------------------------------+ -| qnn.dense | uint8: | -| | Composite: qnn.dense, nn.bias_add?, qnn.requantize | -+--------------+-------------------------------------------------------------------------+ -| nn.maxpool2d | fp32, uint8 | -+--------------+-------------------------------------------------------------------------+ -| reshape | fp32, uint8 | -+--------------+-------------------------------------------------------------------------+ ++----------------------+-------------------------------------------------------------------------+ +| Relay Node | Remarks | ++======================+=========================================================================+ +| nn.conv2d | fp32: | +| | Simple: nn.conv2d | +| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu? | +| | | +| | (only groups = 1 supported) | ++----------------------+-------------------------------------------------------------------------+ +| qnn.conv2d | uint8: | +| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize | +| | | +| | (only groups = 1 supported) | ++----------------------+-------------------------------------------------------------------------+ +| nn.dense | fp32: | +| | Simple: nn.dense | +| | Composite: nn.dense, nn.bias_add? | ++----------------------+-------------------------------------------------------------------------+ +| qnn.dense | uint8: | +| | Composite: qnn.dense, nn.bias_add?, qnn.requantize | ++----------------------+-------------------------------------------------------------------------+ +| nn.max_pool2d | fp32, uint8 | ++----------------------+-------------------------------------------------------------------------+ +| nn.global_max_pool2d | fp32, uint8 | ++----------------------+-------------------------------------------------------------------------+ +| nn.avg_pool2d | fp32: | +| | Simple: nn.avg_pool2d | +| | | +| | uint8: | +| | Composite: cast(int32), nn.avg_pool2d, cast(uint8) | ++----------------------+-------------------------------------------------------------------------+ +| nn.global_avg_pool2d | fp32: | +| | Simple: nn.global_avg_pool2d | +| | | +| | uint8: | +| | Composite: cast(int32), nn.avg_pool2d, cast(uint8) | ++----------------------+-------------------------------------------------------------------------+ +| power(of 2) + | A special case for L2 pooling. | +| nn.avg_pool2d + | | +| sqrt | fp32: | +| | Composite: power(of 2), nn.avg_pool2d, sqrt | ++----------------------+-------------------------------------------------------------------------+ +| reshape | fp32, uint8 | ++----------------------+-------------------------------------------------------------------------+ .. note:: A composite operator is a series of operators that map to a single Arm Compute Library operator. You can view this diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index e20f2d191d03..adeeeb1edebb 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -17,10 +17,11 @@ # pylint: disable=invalid-name, unused-argument """Arm Compute Library supported operators.""" import tvm +from tvm.relay.expr import const from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name -from ...dataflow_pattern import wildcard, is_op, is_constant +from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr from .register import register_pattern_table @@ -125,6 +126,33 @@ def qnn_dense_pattern(): pattern, wildcard(), wildcard(), is_constant(), is_constant()) return pattern + def avg_pool2d_pattern(): + """Creates a pattern that matches either quantized + avg_pool2d or quantized global_avg_pool2d. + + Returns + ------- + pattern : dataflow_pattern.AltPattern + Denotes the convolution pattern. + """ + pattern = is_op('cast')(wildcard()) + pattern = is_op('nn.avg_pool2d')(pattern) | is_op('nn.global_avg_pool2d')(pattern) + pattern = is_op('cast')(pattern) + return pattern + + def l2_pool2d_pattern(): + """Create an l2 pooling pattern from equivalent relay operators. + + Returns + ------- + pattern : dataflow_pattern.AltPattern + Denotes the convolution pattern. + """ + pattern = is_op('power')(wildcard(), is_expr(const(2.0))) + pattern = is_op('nn.avg_pool2d')(pattern) + pattern = is_op('sqrt')(pattern) + return pattern + def check_conv(extract): """Check conv pattern is supported by ACL.""" call = extract @@ -157,10 +185,27 @@ def check_qnn_dense(extract): call = call.args[0] return qnn_dense(call.attrs, call.args) + def check_avg_pool2d(extract): + """Check average pool2d pattern is supported by ACL.""" + if extract.attrs.dtype != "uint8": + return False + pool = extract.args[0] + if pool.args[0].attrs.dtype != "int32": + return False + return avg_pool2d(pool.attrs, pool.args, from_quantized_composite=True) + + def check_l2_pool2d(extract): + """Check l2 pool2d pattern is supported by ACL.""" + pool = extract.args[0] + return avg_pool2d(pool.attrs, pool.args) + return [('arm_compute_lib.conv2d', conv_pattern(), check_conv), ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv), ('arm_compute_lib.dense', dense_pattern(), check_dense), - ('arm_compute_lib.qnn_dense', qnn_dense_pattern(), check_qnn_dense)] + ('arm_compute_lib.qnn_dense', qnn_dense_pattern(), check_qnn_dense), + ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv), + ('arm_compute_lib.avg_pool2d', avg_pool2d_pattern(), check_avg_pool2d), + ('arm_compute_lib.l2_pool2d', l2_pool2d_pattern(), check_l2_pool2d)] def _register_external_op_helper(op_name, supported=True): @@ -245,3 +290,40 @@ def max_pool2d(attrs, args): if typ.dtype not in ["float32", "uint8"]: return False return True + + +@tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib") +def avg_pool2d(attrs, args, from_quantized_composite=False): + """Check if the external ACL codegen for avgpool2d should be used.""" + typ = args[0].checked_type + if from_quantized_composite: + if typ.dtype != "int32": + return False + else: + if typ.dtype not in ["float32"]: + return False + if attrs.layout != "NHWC": + return False + return True + + +@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib") +def global_max_pool2d(attrs, args): + """Check if the external ACL codegen for gloval_maxpool2d should be used.""" + typ = args[0].checked_type + if typ.dtype not in ["float32", "uint8"]: + return False + if attrs.layout != "NHWC": + return False + return True + + +@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib") +def global_avg_pool2d(attrs, args): + """Check if the external ACL codegen for global_avgpool2d should be used.""" + typ = args[0].checked_type + if typ.dtype not in ["float32"]: + return False + if attrs.layout != "NHWC": + return False + return True diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 1132b1c56cbc..087c895f4614 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -94,6 +94,10 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { json_node = CreateCompositeConvJSONNode(cn); } else if (name == "arm_compute_lib.dense" || name == "arm_compute_lib.qnn_dense") { json_node = CreateCompositeDenseJSONNode(cn); + } else if (name == "arm_compute_lib.avg_pool2d") { + json_node = CreateCompositeAvgPool2DJSONNode(cn); + } else if (name == "arm_compute_lib.l2_pool2d") { + json_node = CreateCompositeL2Pool2DJSONNode(cn); } else { LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name; } @@ -267,6 +271,62 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { SetCallNodeAttribute(json_node, nodes.dense); return json_node; } + + /*! + * \brief Create a JSON representation of a composite (global) average pooling operator. + * + * A composite function is only created when using the uint8 datatype for these operators. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeAvgPool2DJSONNode(const CallNode* cn) { + const auto* fn = cn->op.as(); + CHECK(fn); + const auto* cast = fn->body.as(); + CHECK(cast); + const auto* avg_pool = cast->args[0].as(); + CHECK(avg_pool); + const auto* avg_pool_op = avg_pool->op.as(); + CHECK(avg_pool_op); + const std::string name = avg_pool_op->name; + + std::vector inputs; + inputs.push_back(VisitExpr(cn->args[0])[0]); + auto json_node = std::make_shared(name, "kernel", inputs, 1); + SetCallNodeAttribute(json_node, avg_pool); + return json_node; + } + + /*! + * \brief Create a JSON representation of a composite L2 pooling operator. + * + * \note Relay does not have an operator for L2 pooling, instead we can create + * an equivalent from power(2) + nn.avg_pool2d + sqrt. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeL2Pool2DJSONNode(const CallNode* cn) { + const std::string name = "nn.l2_pool2d"; + const auto* fn = cn->op.as(); + CHECK(fn); + const auto* sqrt = fn->body.as(); + CHECK(sqrt); + const auto* avg_pool = sqrt->args[0].as(); + CHECK(avg_pool); + const auto* pow = avg_pool->args[0].as(); + CHECK(pow); + const auto* exponent = pow->args[1].as(); + CHECK(exponent); + CHECK_EQ(*static_cast(exponent->data->data), 2) << "Exponent must be 2 for L2 pooling"; + + std::vector inputs; + inputs.push_back(VisitExpr(cn->args[0])[0]); + auto json_node = std::make_shared(name, "kernel", inputs, 1); + SetCallNodeAttribute(json_node, avg_pool); + return json_node; + } }; /*! diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index f62420a3684f..f2d2fca64055 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -132,8 +132,11 @@ class ACLRuntime : public JSONRuntimeBase { } else if ("nn.dense" == op_name || "qnn.dense" == op_name) { CreateFullyConnectedLayer(&layer_, node, mm); num_pools++; - } else if ("nn.max_pool2d" == op_name) { + } else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name || + "nn.l2_pool2d" == op_name) { CreatePoolingLayer(&layer_, node); + } else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" == op_name) { + CreateGlobalPoolingLayer(&layer_, node); } else if ("reshape" == op_name) { CreateReshapeLayer(&layer_, node); } else { @@ -308,7 +311,7 @@ class ACLRuntime : public JSONRuntimeBase { /*! * \brief Create a pooling layer. * - * \note Currently only maxpool is supported. + * \note Currently max_pool2d, avg_pool2d and L2 pooling are supported. * * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function. * \param node The JSON representation of the operator. @@ -316,22 +319,65 @@ class ACLRuntime : public JSONRuntimeBase { void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node) { std::vector padding = node.GetAttr>("padding"); std::vector strides = node.GetAttr>("strides"); - arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides); + bool ceil_mode = std::stoi(node.GetAttr>("ceil_mode")[0]); + arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides, ceil_mode); auto attr_pool_size = node.GetAttr>("pool_size"); int pool_size_h = std::stoi(attr_pool_size[0]); int pool_size_w = std::stoi(attr_pool_size[1]); + // Only applies to average pool and l2 pool. + // ACL exclude pad option is inverse to Relays include pad option. + bool exclude_pad = false; + if (node.HasAttr("count_include_pad")) { + int count_include_pad = + std::stoi(node.GetAttr>("count_include_pad")[0]); + exclude_pad = !count_include_pad; + } + arm_compute::PoolingType pool_type; if (node.GetOpName() == "nn.max_pool2d") { pool_type = arm_compute::PoolingType::MAX; + } else if (node.GetOpName() == "nn.avg_pool2d") { + pool_type = arm_compute::PoolingType::AVG; + } else if (node.GetOpName() == "nn.l2_pool2d") { + pool_type = arm_compute::PoolingType::L2; } else { LOG(FATAL) << "Pooling type not supported"; } arm_compute::PoolingLayerInfo pool_info = arm_compute::PoolingLayerInfo(pool_type, arm_compute::Size2D(pool_size_h, pool_size_w), - arm_compute::DataLayout::NHWC, pad_stride_info); + arm_compute::DataLayout::NHWC, pad_stride_info, exclude_pad); + + layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0])); + layer->outputs.push_back(MakeACLTensorFromJSONNode(node)); + + auto function = std::make_shared(); + function->configure(&layer->inputs[0], &layer->outputs[0], pool_info); + layer->function = function; + } + + /*! + * \brief Create a global pooling layer. + * + * \note Currently global_max_pool2d and global_avg_pool2d are supported. + * + * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function. + * \param node The JSON representation of the operator. + */ + void CreateGlobalPoolingLayer(CachedLayer* layer, const JSONGraphNode& node) { + arm_compute::PoolingType pool_type; + if (node.GetOpName() == "nn.global_max_pool2d") { + pool_type = arm_compute::PoolingType::MAX; + } else if (node.GetOpName() == "nn.global_avg_pool2d") { + pool_type = arm_compute::PoolingType::AVG; + } else { + LOG(FATAL) << "Pooling type not supported"; + } + + arm_compute::PoolingLayerInfo pool_info = + arm_compute::PoolingLayerInfo(pool_type, arm_compute::DataLayout::NHWC); layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0])); layer->outputs.push_back(MakeACLTensorFromJSONNode(node)); diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.cc b/src/runtime/contrib/arm_compute_lib/acl_utils.cc index 98c9cda9fae7..59c941df5195 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.cc @@ -81,9 +81,11 @@ std::shared_ptr MakeACLMemoryManager() { } arm_compute::PadStrideInfo MakeACLPadStride(const std::vector& pad, - const std::vector& stride) { + const std::vector& stride, + bool ceil_mode) { int pad_0 = 0, pad_1 = 0, pad_2 = 0, pad_3 = 0; int stride_0 = std::stoi(stride[0]), stride_1 = std::stoi(stride[1]); + auto dimensions_rounding = arm_compute::DimensionRoundingType::FLOOR; size_t size = pad.size(); if (size == 1) { int pad_v = std::stoi(pad[0]); @@ -109,8 +111,12 @@ arm_compute::PadStrideInfo MakeACLPadStride(const std::vector& pad, LOG(FATAL) << "Unsupported padding dimensions"; } + if (ceil_mode) { + dimensions_rounding = arm_compute::DimensionRoundingType::CEIL; + } + return arm_compute::PadStrideInfo(stride_0, stride_1, pad_0, pad_1, pad_2, pad_3, - arm_compute::DimensionRoundingType::FLOOR); + dimensions_rounding); } arm_compute::DataType MakeACLDataType(const DLDataType& data_type) { diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.h b/src/runtime/contrib/arm_compute_lib/acl_utils.h index 80c6f0bcd958..576ed916ff60 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.h +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.h @@ -93,10 +93,12 @@ std::shared_ptr MakeACLMemoryManager(); * * \param pad The pad vector. * \param stride The stride vector. + * \param ceil_mode Dimensions rounding. * \return arm_compute::PadStrideInfo */ arm_compute::PadStrideInfo MakeACLPadStride(const std::vector& pad, - const std::vector& stride); + const std::vector& stride, + bool ceil_mode = false); /*! * \brief Convert DLDataType to arm_compute::DataType. diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index 4e930e2276ee..cc4818e96625 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -181,9 +181,20 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, no_runs=1, - tvm_ops=0, acl_partitions=1): + tvm_ops=0, acl_partitions=1, config=None): """Build and run the relay module.""" - lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions) + if config is None: + config = {} + + try: + lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions) + except Exception as e: + err_msg = "The module could not be built.\n" + if config: + err_msg += f"The test failed with the following parameters: {config}\n" + err_msg += str(e) + raise Exception(err_msg) + lib = update_lib(lib, device.device, device.cross_compile) gen_module = graph_runtime.GraphModule(lib['default'](device.device.cpu(0))) gen_module.set_input(**inputs) @@ -208,28 +219,28 @@ def update_lib(lib, device, cross_compile): return lib -def verify(answers, atol, rtol, verify_saturation=False, params=None): +def verify(answers, atol, rtol, verify_saturation=False, config=None): """Compare the array of answers. Each entry is a list of outputs.""" - if params is None: - params = {} + if config is None: + config = {} if len(answers) < 2: raise RuntimeError( f"No results to compare: expected at least two, found {len(answers)}") for answer in zip_longest(*answers): for outs in combinations(answer, 2): - if verify_saturation: - assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \ - "Output is saturated: {}".format(outs[0]) - assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \ - "Output is saturated: {}".format(outs[0]) try: + if verify_saturation: + assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \ + "Output is saturated: {}".format(outs[0]) + assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \ + "Output is saturated: {}".format(outs[0]) tvm.testing.assert_allclose( outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol) except AssertionError as e: err_msg = "Results not within the acceptable tolerance.\n" - if params: - err_msg += f"The test failed with the following parameters: {params}\n" + if config: + err_msg += f"The test failed with the following parameters: {config}\n" err_msg += str(e) raise AssertionError(err_msg) diff --git a/tests/python/contrib/test_arm_compute_lib/test_conv2d.py b/tests/python/contrib/test_arm_compute_lib/test_conv2d.py index 555cbe193408..37575cccf9eb 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_conv2d.py +++ b/tests/python/contrib/test_arm_compute_lib/test_conv2d.py @@ -276,7 +276,7 @@ def test_conv2d(): params, device, enable_acl=acl)[0]) - params = { + config = { "shape": shape, "groups": groups, "kernel size": (kernel_h, kernel_w), @@ -286,7 +286,7 @@ def test_conv2d(): "out channels": out_channels, "composite operators (pad, bias, activation)": composite } - verify(outputs, atol=0.002, rtol=0.01, params=params) + verify(outputs, atol=0.002, rtol=0.01, config=config) def test_codegen_conv2d(): @@ -380,7 +380,7 @@ def test_qnn_conv2d(): params, device, enable_acl=acl)[0]) - params = { + config = { "shape": shape, "groups": groups, "kernel size": (kernel_h, kernel_w), @@ -396,15 +396,13 @@ def test_qnn_conv2d(): "output scale": output_sc, "output zero point": output_zp } - verify(outputs, atol=1, rtol=0, params=params, verify_saturation=True) + verify(outputs, atol=1, rtol=0, config=config, verify_saturation=True) def test_codegen_qnn_conv2d(): if skip_codegen_test(): return - np.random.seed(0) - kernel_hs = [1, 2, 3, 5] kernel_ws = [1, 2, 3, 5] pad = [(1, 1), (2, 2), (2, 1)] diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py index 18cac3380315..e1bb83b52079 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_network.py +++ b/tests/python/contrib/test_arm_compute_lib/test_network.py @@ -116,7 +116,7 @@ def get_model(): return mod, params, inputs _build_and_run_network(*get_model(), device=device, - tvm_ops=74, acl_partitions=17, + tvm_ops=73, acl_partitions=18, atol=0.002, rtol=0.01) @@ -144,7 +144,7 @@ def get_model(): return mod, params, inputs _build_and_run_network(*get_model(), device=device, - tvm_ops=45, acl_partitions=16, + tvm_ops=42, acl_partitions=17, atol=8, rtol=0) diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py index 32176afd1346..c104a0659b7f 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_pooling.py +++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py @@ -26,26 +26,70 @@ from .infrastructure import Device -def _get_model(shape, dtype, typef, sizes, strides, padding, - ceil_mode, var_names): +def _calculate_output_shape(shape, sizes, padding, strides): + """Calculate pooling output shape.""" + output_height = ((shape[1] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1 + output_width = ((shape[2] - sizes[1] + padding[1] + padding[3]) / strides[1]) + 1 + return 1, int(output_height), int(output_width), shape[3] + + +def _get_pooling_model(shape, dtype, typef, sizes, strides, padding, + ceil_mode, count_include_pad, var_names): + """Return a model and any parameters it may have.""" + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + out = relay.var(next(var_names), shape=shape, dtype=dtype) + + if typef == "nn.max_pool2d": + out = relay.nn.max_pool2d(out, pool_size=sizes, strides=strides, padding=padding, + ceil_mode=ceil_mode, layout="NHWC") + elif typef == "nn.avg_pool2d": + if dtype == "uint8": + out = relay.cast(out, 'int32') + out = relay.nn.avg_pool2d(out, pool_size=sizes, strides=strides, padding=padding, + ceil_mode=ceil_mode, count_include_pad=count_include_pad, + layout="NHWC") + if dtype == "uint8": + out = relay.cast(out, 'uint8') + elif typef == "nn.l2_pool2d": + out = relay.power(out, relay.const(2.0)) + out = relay.nn.avg_pool2d(out, pool_size=sizes, strides=strides, padding=padding, + ceil_mode=ceil_mode, count_include_pad=count_include_pad, + layout="NHWC") + out = relay.sqrt(out) + else: + raise ValueError("Function not supported") + + return out + + +def _get_global_pooling_model(shape, dtype, typef, var_names): """Return a model and any parameters it may have.""" - var = relay.var(next(var_names), shape=shape, dtype=dtype) - pool = typef(var, pool_size=sizes, strides=strides, padding=padding, - ceil_mode=ceil_mode, layout="NHWC") - return pool + out = relay.var(next(var_names), shape=shape, dtype=dtype) + + if typef == "nn.global_max_pool2d": + out = relay.nn.global_max_pool2d(out, layout="NHWC") + elif typef == "nn.global_avg_pool2d": + if dtype == "uint8": + out = relay.cast(out, 'int32') + out = relay.nn.global_avg_pool2d(out, layout="NHWC") + if dtype == "uint8": + out = relay.cast(out, 'uint8') + else: + raise ValueError("Function not supported") + + return out -def _get_expected_codegen(shape, dtype, typef, sizes, strides, - padding, ceil_mode): +def _get_expected_pooling_codegen(shape, dtype, typef, sizes, strides, + padding, ceil_mode, count_include_pad): if len(padding) == 2: - padding = (padding[1], padding[1], padding[0], padding[0]) - output_height = ((shape[1] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1 - output_width = ((shape[2] - sizes[1] + padding[1] + padding[3]) / strides[1]) + 1 - output_shape = (1, int(output_height), int(output_width), shape[3]) + padding = (padding[0], padding[1], padding[0], padding[1]) + output_shape = _calculate_output_shape(shape, sizes, padding, strides) node = { "op": "kernel", - "name": "nn.max_pool2d", + "name": typef, "inputs": [[0, 0, 0]], "attrs": { "num_inputs": "1", @@ -60,6 +104,30 @@ def _get_expected_codegen(shape, dtype, typef, sizes, strides, }, } + if typef == "nn.avg_pool2d" or typef == "nn.l2_pool2d": + node["attrs"]["count_include_pad"] = [["1" if count_include_pad else "0"]] + + input = { + "op": "input", + "name": "", + "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}} + return [input, node] + + +def _get_expected_global_pooling_codegen(shape, dtype, typef): + node = { + "op": "kernel", + "name": typef, + "inputs": [[0, 0, 0]], + "attrs": { + "num_inputs": "1", + "num_outputs": "1", + "layout": [["NHWC"]], + "shape": [[[1, 1, 1, shape[3]]]], + "dtype": [[dtype]] + } + } + input = { "op": "input", "name": "", @@ -76,53 +144,160 @@ def test_pooling(): device = Device() np.random.seed(0) - for dtype, low, high, atol, rtol in [("float32", -127, 128, 0.001, 0.001), ("uint8", 0, 255, 0, 0)]: - for size in [(2, 2), (3, 3)]: - for stride in [(2, 2)]: - shape = (1, size[0] + stride[0] * 5, - size[1] + stride[1] * 5, 16) - pad = (0, 0) - - inputs = { - "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)), - } - - outputs = [] - func = _get_model(shape, dtype, relay.nn.max_pool2d, size, - stride, pad, True, iter(inputs)) - for acl in [False, True]: - outputs.append(build_and_run(func, inputs, 1, None, device, - enable_acl=acl)[0]) - - params = { - "size": size, - "stride": stride, - "shape": shape, - "pooling type": "max", - "dtype": dtype, - "padding": pad - } - verify(outputs, atol=atol, rtol=rtol, params=params, verify_saturation=True) + fp32_dtype = ("float32", -127, 128, 0.001, 0.001) + uint8_dtype = ("uint8", 0, 255, 1, 0) + + trials = [["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)], + ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], + ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)], + ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)], + ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (16, 16, 16)], + ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)]] + + for typef, (dtype, low, high, atol, rtol), size, stride, pad, ceil_mode, count_include_pad, \ + input_shape in trials: + shape = (1, *input_shape) + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)), + } + + func = _get_pooling_model(shape, dtype, typef, size, + stride, pad, ceil_mode, count_include_pad, iter(inputs)) + + config = { + "size": size, + "stride": stride, + "shape": shape, + "pooling type": typef, + "dtype": dtype, + "padding": pad, + "ceil_mode": ceil_mode, + "count_include_pad": count_include_pad + } + verify_saturation = True if dtype == "uint8" else False + + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, None, device, + enable_acl=acl, config=config)[0]) + + verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation) + + +def test_global_pooling(): + Device.load("test_config.json") + + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + fp32_dtype = ("float32", -127, 128, 0.001, 0.001) + uint8_dtype = ("uint8", 0, 255, 1, 0) + + trials = [["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)], + ["nn.global_max_pool2d", fp32_dtype, (9, 9, 16)], + ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)], + ["nn.global_max_pool2d", uint8_dtype, (8, 8, 16)], + ["nn.global_max_pool2d", uint8_dtype, (9, 9, 16)], + ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)], + ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)], + ["nn.global_avg_pool2d", fp32_dtype, (9, 9, 16)], + ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)], + ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)]] + + for typef, (dtype, low, high, atol, rtol), input_shape in trials: + shape = (1, *input_shape) + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)), + } + + func = _get_global_pooling_model(shape, dtype, typef, iter(inputs)) + + config = { + "shape": shape, + "pooling type": typef, + "dtype": dtype, + } + verify_saturation = True if dtype == "uint8" else False + + for acl in [False, True]: + outputs.append(build_and_run(func, inputs, 1, None, device, + enable_acl=acl, config=config)[0]) + + verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation) def test_codegen_pooling(): if skip_codegen_test(): return - inputs = {"a"} + fp32_dtype = ("float32", -127, 128) + uint8_dtype = ("uint8", 0, 255) + + trials = [["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)], + ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], + ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)], + ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)], + ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (15, 15, 16)], + ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)]] + + for typef, (dtype, low, high), size, stride, pad, ceil_mode, count_include_pad, \ + input_shape in trials: + shape = (1, *input_shape) + inputs = {"a"} + args = (shape, dtype, typef, size, + stride, pad, False, False) + func = _get_pooling_model(*args, iter(inputs)) + exp_codegen = _get_expected_pooling_codegen(*args) + verify_codegen(func, exp_codegen, 1) + + +def test_codegen_global_pooling(): + if skip_codegen_test(): + return + + fp32_dtype = ("float32", -127, 128) + uint8_dtype = ("uint8", 0, 255) + + trials = [["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)], + ["nn.global_max_pool2d", fp32_dtype, (9, 9, 16)], + ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)], + ["nn.global_max_pool2d", uint8_dtype, (8, 8, 16)], + ["nn.global_max_pool2d", uint8_dtype, (9, 9, 16)], + ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)], + ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)], + ["nn.global_avg_pool2d", fp32_dtype, (9, 9, 16)], + ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)], + ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)]] - for dtype in ["float32", "uint8"]: - for size in [(2, 2), (3, 3)]: - for stride in [(2, 2)]: - shape = (1, size[0] + stride[0] * 5, - size[1] + stride[1] * 5, 16) - args = (shape, dtype, relay.nn.max_pool2d, size, - stride, (0, 0), True) - func = _get_model(*args, iter(inputs)) - exp_codegen = _get_expected_codegen(*args) - verify_codegen(func, exp_codegen, 1) + for typef, (dtype, low, high), input_shape in trials: + shape = (1, *input_shape) + inputs = {"a"} + args = (shape, dtype, typef) + func = _get_global_pooling_model(*args, iter(inputs)) + exp_codegen = _get_expected_global_pooling_codegen(*args) + verify_codegen(func, exp_codegen, 1) if __name__ == "__main__": test_pooling() + test_global_pooling() test_codegen_pooling() + test_codegen_global_pooling() diff --git a/tests/python/contrib/test_arm_compute_lib/test_reshape.py b/tests/python/contrib/test_arm_compute_lib/test_reshape.py index 38694e8ccaaa..b6a87542062a 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_reshape.py +++ b/tests/python/contrib/test_arm_compute_lib/test_reshape.py @@ -78,12 +78,12 @@ def test_reshape(): outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl)[0]) - params = { + config = { "new shape": inputs["a"].shape, "shape": new_shape, "dtype": dtype, } - verify(outputs, atol=1e-7, rtol=1e-7, params=params) + verify(outputs, atol=1e-7, rtol=1e-7, config=config) def test_codegen_reshape():