diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 9f3c1cdec0f70..9abd320b2956f 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument, dangerous-default-value
 """Arm Compute Library supported operators."""
 import tvm
 from tvm import relay
@@ -23,7 +23,7 @@
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.expr import const

-from ...dataflow_pattern import is_constant, is_expr, is_op, wildcard
+from ...dataflow_pattern import is_constant, is_expr, is_op, is_tuple, wildcard
 from ..strategy.generic import is_depthwise_conv2d
 from .register import register_pattern_table

@@ -42,7 +42,7 @@ def is_arm_compute_runtime_enabled():
     return False


-def partition_for_arm_compute_lib(mod, params=None, **opts):
+def partition_for_arm_compute_lib(mod, params=None, disabled_ops=["concatenate"], **opts):
     """Partition the graph greedily offloading supported
     operators to Arm Compute Library.

@@ -52,6 +52,8 @@ def partition_for_arm_compute_lib(mod, params=None, **opts):
         The module to run passes on.
     params : Optional[Dict[str, NDArray]]
         Constant input parameters.
+    disabled_ops : Optional[list]
+        Ops which we do not want to offload to ACL.

     Returns
     -------
@@ -63,7 +65,7 @@ def partition_for_arm_compute_lib(mod, params=None, **opts):
     seq = tvm.transform.Sequential(
         [
             transform.InferType(),
-            transform.MergeComposite(arm_compute_lib_pattern_table()),
+            transform.MergeComposite(arm_compute_lib_pattern_table(disabled_ops)),
             transform.AnnotateTarget("arm_compute_lib", False),
             transform.PartitionGraph(),
         ]
@@ -128,7 +130,7 @@ def convert_conv(attrs, inputs, tinfos, desired_layouts):


 @register_pattern_table("arm_compute_lib")
-def arm_compute_lib_pattern_table():
+def arm_compute_lib_pattern_table(disabled_ops=["concatenate"]):
     """Get the ACL pattern table."""

     def conv_pattern():
@@ -220,6 +222,17 @@ def l2_pool2d_pattern():
         pattern = is_op("sqrt")(pattern)
         return pattern

+    def concatenate_pattern():
+        """Create a concatenate pattern from equivalent relay operators.
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the concatenate pattern.
+        """
+        pattern = is_op("concatenate")(is_tuple(None))
+        return pattern
+
     def check_conv(extract):
         """Check conv pattern is supported by ACL."""
         call = extract
@@ -266,6 +279,19 @@ def check_l2_pool2d(extract):
         pool = extract.args[0]
         return avg_pool2d(pool)

+    def check_concatenate(expr):
+        """Check concatenate pattern is supported by ACL."""
+        if "concatenate" in disabled_ops:
+            return False
+        attrs, type_args = expr.attrs, expr.type_args
+        for idx in range(len(type_args[0].fields)):
+            if type_args[0].fields[idx].dtype not in ["float32", "uint8"]:
+                return False
+        # ACL concatenate only supports input tensors with at most 4 dimensions
+        if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]:
+            return False
+        return True
+
     return [
         ("arm_compute_lib.conv2d", conv_pattern(), check_conv),
         ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv),
@@ -274,6 +300,7 @@ def check_l2_pool2d(extract):
         ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv),
         ("arm_compute_lib.avg_pool2d", avg_pool2d_pattern(), check_avg_pool2d),
         ("arm_compute_lib.l2_pool2d", l2_pool2d_pattern(), check_l2_pool2d),
+        ("arm_compute_lib.concatenate", concatenate_pattern(), check_concatenate),
     ]
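The new `disabled_ops` argument makes concatenate offloading opt-in: it is banned by default and enabled by passing a list that omits it. A minimal usage sketch (the graph and shapes here are illustrative assumptions, not part of this patch):

```python
# Hypothetical example: partition a small graph for ACL with
# concatenate offloading enabled (it is disabled by default).
import tvm
from tvm import relay
from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib

a = relay.var("a", shape=(1, 8, 8, 4), dtype="float32")
b = relay.var("b", shape=(1, 8, 8, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.concatenate([a, b], axis=-1))

# An empty disabled_ops list lifts the default ban on concatenate.
mod = partition_for_arm_compute_lib(mod, disabled_ops=[])
```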
+ """ + pattern = is_op("concatenate")(is_tuple(None)) + return pattern + def check_conv(extract): """Check conv pattern is supported by ACL.""" call = extract @@ -266,6 +279,19 @@ def check_l2_pool2d(extract): pool = extract.args[0] return avg_pool2d(pool) + def check_concatenate(expr): + """Check concatenate pattern is supported by ACL.""" + if "concatenate" in disabled_ops: + return False + attrs, type_args = expr.attrs, expr.type_args + for idx in range(len(type_args[0].fields)): + if type_args[0].fields[idx].dtype not in ["float32", "uint8"]: + return False + # ACL concatenate only supports maximum 4 dimensions input tensor + if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]: + return False + return True + return [ ("arm_compute_lib.conv2d", conv_pattern(), check_conv), ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv), @@ -274,6 +300,7 @@ def check_l2_pool2d(extract): ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv), ("arm_compute_lib.avg_pool2d", avg_pool2d_pattern(), check_avg_pool2d), ("arm_compute_lib.l2_pool2d", l2_pool2d_pattern(), check_l2_pool2d), + ("arm_compute_lib.concatenate", concatenate_pattern(), check_concatenate), ] diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 8098c8d512741..842ede3bf20b8 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -99,6 +99,8 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { json_node = CreateCompositeAvgPool2DJSONNode(cn); } else if (name == "arm_compute_lib.l2_pool2d") { json_node = CreateCompositeL2Pool2DJSONNode(cn); + } else if (name == "arm_compute_lib.concatenate") { + return AddCommonSingleJSONNode(cn, "concatenate"); } else { LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name; } @@ -342,6 +344,30 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { SetCallNodeAttribute(json_node, avg_pool); return json_node; } + + /*! + * \brief Create a JSON representation of a single operator. + * \param cn The call to be represented. + * \param name The name of the operator. + * \return A list of graph entry nodes. + */ + std::vector AddCommonSingleJSONNode(const CallNode* cn, std::string name) { + std::vector inputs; + for (const auto& arg : cn->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + + const auto* fn = cn->op.as(); + ICHECK(fn); + const auto* callNode = fn->body.as(); + ICHECK(callNode); + SetCallNodeAttribute(node, callNode); + return AddNode(node, GetRef(cn)); + } }; /*! diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index a336cf494f4b8..5687e687cfb64 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -31,6 +31,7 @@ #ifdef TVM_GRAPH_EXECUTOR_ARM_COMPUTE_LIB #include #include +#include #include #include #include @@ -91,12 +92,21 @@ class ACLRuntime : public JSONRuntimeBase { * \return Status of inference. 
diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc
index a336cf494f4b8..5687e687cfb64 100644
--- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc
+++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc
@@ -31,6 +31,7 @@
 #ifdef TVM_GRAPH_EXECUTOR_ARM_COMPUTE_LIB
 #include
 #include
+#include
 #include
 #include
 #include
@@ -91,12 +92,21 @@ class ACLRuntime : public JSONRuntimeBase {
   * \return Status of inference.
   */
  void Run() override {
-    for (size_t i = 0; i < input_nodes_.size(); ++i) {
-      auto nid = input_nodes_[i];
-      uint32_t eid = EntryID(nid, 0);
+    for (size_t nid_idx = 0; nid_idx < input_nodes_.size(); ++nid_idx) {
+      auto nid = input_nodes_[nid_idx];
       if (nodes_[nid].GetOpType() == "input") {
-        void* data = data_entry_[eid]->data;
-        CheckACLError(layer_.inputs[i].allocator()->import_memory(data));
+        for (uint32_t eid_idx = 0; eid_idx < nodes_[nid].GetNumOutput(); eid_idx++) {
+          uint32_t eid = EntryID(nid, eid_idx);
+          void* data = data_entry_[eid]->data;
+          auto key = std::pair<uint32_t, uint32_t>(nid, eid_idx);
+          if (layer_.json_inputid_to_layer_inputid.count(key) > 0) {
+            CheckACLError(
+                layer_.inputs[layer_.json_inputid_to_layer_inputid[key]].allocator()->import_memory(
+                    data));
+          } else {
+            CheckACLError(layer_.inputs[nid_idx].allocator()->import_memory(data));
+          }
+        }
       }
     }

@@ -149,6 +159,8 @@ class ACLRuntime : public JSONRuntimeBase {
      CreateMaximumLayer(&layer_, node);
    } else if ("add" == op_name || "qnn.add" == op_name) {
      CreateAddLayer(&layer_, node);
+    } else if ("concatenate" == op_name) {
+      CreateConcatenateLayer(&layer_, node);
    } else {
      LOG(FATAL) << "Unsupported op: " << op_name;
    }
@@ -166,6 +178,9 @@ class ACLRuntime : public JSONRuntimeBase {
    std::shared_ptr<arm_compute::IFunction> function;
    std::vector<arm_compute::Tensor> inputs;
    std::vector<arm_compute::Tensor> outputs;
+    // maps the input index of the JSON node to the index of the ACL layer's inputs
+    // this is optional (i.e. only used when an operator needs the eid index)
+    std::map<std::pair<uint32_t, uint32_t>, uint32_t> json_inputid_to_layer_inputid;
  };

  /*!
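`json_inputid_to_layer_inputid` exists because a concatenate layer fixes the order of its ACL input tensors at construction time, which may differ from the positional order `Run` would otherwise assume. A Python stand-in for the lookup above (names mirror the C++; this is a sketch, not runtime code):

```python
def layer_input_index(json_inputid_to_layer_inputid, nid, eid_idx, nid_idx):
    """Prefer the mapping recorded for (node id, entry index); otherwise
    fall back to the positional input index, as in ACLRuntime::Run."""
    return json_inputid_to_layer_inputid.get((nid, eid_idx), nid_idx)

# Mapping as CreateConcatenateLayer would record it for three inputs:
mapping = {(0, 0): 0, (1, 0): 1, (2, 0): 2}
assert layer_input_index(mapping, 1, 0, 99) == 1  # mapped entry wins
assert layer_input_index(mapping, 5, 0, 3) == 3   # unmapped: positional fallback
```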
@@ -175,17 +190,25 @@ class ACLRuntime : public JSONRuntimeBase {
   * \param tensor The tensor to represent.
   * \param scale (optional) The scale of the tensor as an input.
   * \param offset (optional) The offset of the tensor as an input.
+   * \param apply_dim_correction (Optional) Flag to state whether to apply dimension correction
+   * after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but
+   * _num_dimensions should be 3 rather than 1.
+   * \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of
+   * dimensions of the shape.
   * \return ACL Tensor.
   */
  arm_compute::Tensor MakeACLTensorFromJSONEntry(const JSONGraphNodeEntry& tensor,
                                                 JSONGraphNodeEntry* scale = nullptr,
-                                                 JSONGraphNodeEntry* offset = nullptr) {
+                                                 JSONGraphNodeEntry* offset = nullptr,
+                                                 bool apply_dim_correction = true,
+                                                 bool increase_dim_unit = true) {
    JSONGraphNode node = nodes_[tensor.id_];
    void* node_data = nullptr;
    if (node.GetOpType() == "const") {
      node_data = data_entry_[EntryID(tensor)]->data;
    }
-    return MakeACLTensorFromJSONNode(node, scale, offset, node_data);
+    return MakeACLTensorFromJSONNode(node, scale, offset, node_data, apply_dim_correction,
+                                     increase_dim_unit, tensor.index_);
  }

  /*!
@@ -196,19 +219,26 @@ class ACLRuntime : public JSONRuntimeBase {
   * \param scale (optional) The scale of the tensor as an input.
   * \param offset (optional) The offset of the tensor as an input.
   * \param data (optional) Constant data of input node.
+   * \param apply_dim_correction (Optional) Flag to state whether to apply dimension correction
+   * after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but
+   * _num_dimensions should be 3 rather than 1.
+   * \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of
+   * dimensions of the shape.
+   * \param entry_index The entry index.
   * \return ACL Tensor.
   */
-  arm_compute::Tensor MakeACLTensorFromJSONNode(const JSONGraphNode& node,
-                                                JSONGraphNodeEntry* scale = nullptr,
-                                                JSONGraphNodeEntry* offset = nullptr,
-                                                void* data = nullptr) {
+  arm_compute::Tensor MakeACLTensorFromJSONNode(
+      const JSONGraphNode& node, JSONGraphNodeEntry* scale = nullptr,
+      JSONGraphNodeEntry* offset = nullptr, void* data = nullptr, bool apply_dim_correction = true,
+      bool increase_dim_unit = true, uint32_t entry_index = 0) {
    const DLTensor* scale_data = nullptr;
    const DLTensor* offset_data = nullptr;
    if (scale && offset) {
      scale_data = data_entry_[EntryID(*scale)];
      offset_data = data_entry_[EntryID(*offset)];
    }
-    return MakeACLTensor(node, data, scale_data, offset_data);
+    return MakeACLTensor(node, data, scale_data, offset_data, apply_dim_correction,
+                         increase_dim_unit, entry_index);
  }

  /*!
@@ -510,6 +540,34 @@ class ACLRuntime : public JSONRuntimeBase {
    layer->function = f;
  }

+  /*!
+   * \brief Create a Concatenate layer.
+   *
+   * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
+   * \param node The JSON representation of the operator.
+   */
+  void CreateConcatenateLayer(CachedLayer* layer, const JSONGraphNode& node) {
+    std::vector<std::string> axis = node.GetAttr<std::vector<std::string>>("axis");
+    std::vector<const arm_compute::ITensor*> inputs;
+    for (auto input : node.GetInputs()) {
+      layer->inputs.push_back(MakeACLTensorFromJSONEntry(input, nullptr, nullptr, false));
+      layer->json_inputid_to_layer_inputid[std::pair<uint32_t, uint32_t>(input.id_, input.index_)] =
+          layer->inputs.size() - 1;
+    }
+    for (size_t i = 0; i < layer->inputs.size(); i++) {
+      inputs.push_back(&layer->inputs[i]);
+    }
+    layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
+    int dimNum = layer->inputs[0].info()->num_dimensions();
+    auto function = std::make_shared<arm_compute::NEConcatenateLayer>();
+    // the shape of an input tensor is reversed when it is passed to ACL:
+    // for example, a tensor with shape [1, 2, 3, 4] becomes [4, 3, 2, 1]
+    // on the ACL side, so the axis has to be remapped accordingly.
+    auto a = std::stoi(axis[0]);
+    function->configure(inputs, &layer->outputs[0], a < 0 ? -a - 1 : dimNum - a - 1);
+    layer->function = function;
+  }
+
  /*! \brief Allow ACL functions to request auxiliary memory from TVM. */
  ACLAllocator allocator_;
  /*!
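The axis arithmetic in `CreateConcatenateLayer` is easy to get wrong, so here is a Python stand-in for the expression `a < 0 ? -a - 1 : dimNum - a - 1` (a sketch for illustration only):

```python
def acl_concat_axis(axis: int, num_dims: int) -> int:
    """Remap a relay concatenate axis to ACL's reversed dimension order."""
    return -axis - 1 if axis < 0 else num_dims - axis - 1

# For 4-D NHWC tensors: relay axis 0 (N, outermost) is ACL dimension 3,
# while relay axis -1 (C, innermost) is ACL dimension 0.
assert acl_concat_axis(0, 4) == 3
assert acl_concat_axis(-1, 4) == 0
assert acl_concat_axis(1, 4) == 2
```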
diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.cc b/src/runtime/contrib/arm_compute_lib/acl_utils.cc
index 3b2620987ab0a..238b7355de263 100644
--- a/src/runtime/contrib/arm_compute_lib/acl_utils.cc
+++ b/src/runtime/contrib/arm_compute_lib/acl_utils.cc
@@ -40,11 +40,14 @@ void CheckACLError(const arm_compute::Status& status) {
 }

 arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data,
-                                  const DLTensor* scale, const DLTensor* offset) {
+                                  const DLTensor* scale, const DLTensor* offset,
+                                  bool apply_dim_correction, bool increase_dim_unit,
+                                  uint32_t entry_index) {
   arm_compute::Tensor tensor;
-  std::vector<int64_t> shape = tensor_rep.GetOpShape()[0];
-  DLDataType dtype = tensor_rep.GetOpDataType()[0];
-  arm_compute::TensorInfo info = MakeACLTensorInfo(shape, dtype, scale, offset);
+  std::vector<int64_t> shape = tensor_rep.GetOpShape()[entry_index];
+  DLDataType dtype = tensor_rep.GetOpDataType()[entry_index];
+  arm_compute::TensorInfo info =
+      MakeACLTensorInfo(shape, dtype, scale, offset, apply_dim_correction, increase_dim_unit);
   info.set_is_resizable(false);
   tensor.allocator()->init(info);
   if (data != nullptr) {
@@ -55,10 +58,11 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data,

 arm_compute::TensorInfo MakeACLTensorInfo(const std::vector<int64_t>& shape,
                                           const DLDataType& dtype, const DLTensor* scale,
-                                          const DLTensor* offset) {
+                                          const DLTensor* offset, bool apply_dim_correction,
+                                          bool increase_dim_unit) {
   arm_compute::TensorShape acl_shape;
   for (unsigned int i = shape.size(); i > 0; --i) {
-    acl_shape.set(shape.size() - i, shape[i - 1]);
+    acl_shape.set(shape.size() - i, shape[i - 1], apply_dim_correction, increase_dim_unit);
   }
   arm_compute::DataType acl_dtype = MakeACLDataType(dtype);
   arm_compute::TensorInfo info(acl_shape, 1, acl_dtype, arm_compute::DataLayout::NHWC);
diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.h b/src/runtime/contrib/arm_compute_lib/acl_utils.h
index dbb006fbb3478..a553839240e4f 100644
--- a/src/runtime/contrib/arm_compute_lib/acl_utils.h
+++ b/src/runtime/contrib/arm_compute_lib/acl_utils.h
@@ -63,8 +63,9 @@ void CheckACLError(const arm_compute::Status& status);
  * \return arm_compute::Tensor.
  */
 arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data = nullptr,
-                                  const DLTensor* scale = nullptr,
-                                  const DLTensor* offset = nullptr);
+                                  const DLTensor* scale = nullptr, const DLTensor* offset = nullptr,
+                                  bool apply_dim_correction = true, bool increase_dim_unit = true,
+                                  uint32_t entry_index = 0);

 /*!
  * \brief Make an acl tensor info object from JSON tensor
@@ -78,7 +79,9 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data =
  */
 arm_compute::TensorInfo MakeACLTensorInfo(const std::vector<int64_t>& shape,
                                           const DLDataType& dtype, const DLTensor* scale = nullptr,
-                                          const DLTensor* offset = nullptr);
+                                          const DLTensor* offset = nullptr,
+                                          bool apply_dim_correction = true,
+                                          bool increase_dim_unit = true);

 /*!
  * \brief Create a memory manager for use with a layer that
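`MakeACLTensorInfo` writes the JSON shape into the ACL `TensorShape` in reverse, and passing `apply_dim_correction=false` keeps ACL from collapsing trailing unit dimensions while those elements are written (the NCHW -> NHWC note in the doc comments above). A small Python sketch of the reversal loop, with ACL's trimming behaviour paraphrased in the comments:

```python
def to_acl_shape(shape):
    """Mirror of the loop in MakeACLTensorInfo: dimension order is
    reversed when a shape is handed to ACL."""
    acl_shape = [0] * len(shape)
    for i in range(len(shape), 0, -1):
        acl_shape[len(shape) - i] = shape[i - 1]
    return acl_shape

# A TVM-side [1, 2, 3, 4] tensor is seen as [4, 3, 2, 1] by ACL.
assert to_acl_shape([1, 2, 3, 4]) == [4, 3, 2, 1]
# [1, 1, 2] reverses to [2, 1, 1]; with dimension correction enabled ACL
# would treat that as 1-D, so concatenate inputs disable the correction
# to keep the rank at 3.
assert to_acl_shape([1, 1, 2]) == [2, 1, 1]
```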
diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h
index 1735d85692150..0c6d0f6d71363 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -186,6 +186,7 @@ class JSONRuntimeBase : public ModuleNode {
         for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
           input_var_eid_.push_back(EntryID(nid, j));
         }
+        nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size());
       } else {
         ICHECK_EQ(nodes_[nid].op_type_, "const");
         auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py
index e582874d1de27..314da972c0498 100644
--- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py
+++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py
@@ -163,13 +163,23 @@ def skip_codegen_test():
         return True


-def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_partitions=1):
+def build_module(
+    mod,
+    target,
+    params=None,
+    enable_acl=True,
+    tvm_ops=0,
+    acl_partitions=1,
+    disabled_ops=["concatenate"],
+):
     """Build module with option to build for ACL."""
     if isinstance(mod, tvm.relay.expr.Call):
         mod = tvm.IRModule.from_expr(mod)
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
         if enable_acl:
-            mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params)
+            mod = arm_compute_lib.partition_for_arm_compute_lib(
+                mod, params, disabled_ops=disabled_ops
+            )
             tvm_op_count = get_cpu_op_count(mod)
             assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
                 tvm_op_count, tvm_ops
@@ -199,13 +209,16 @@ def build_and_run(
     tvm_ops=0,
     acl_partitions=1,
     config=None,
+    disabled_ops=["concatenate"],
 ):
     """Build and run the relay module."""
     if config is None:
         config = {}

     try:
-        lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions)
+        lib = build_module(
+            mod, device.target, params, enable_acl, tvm_ops, acl_partitions, disabled_ops
+        )
     except Exception as e:
         err_msg = "The module could not be built.\n"
         if config:
@@ -276,9 +289,16 @@ def verify_codegen(
     num_acl_modules=1,
     tvm_ops=0,
     target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon",
+    disabled_ops=["concatenate"],
 ):
     """Check acl codegen against a known good output."""
-    module = build_module(module, target, tvm_ops=tvm_ops, acl_partitions=num_acl_modules)
+    module = build_module(
+        module,
+        target,
+        tvm_ops=tvm_ops,
+        acl_partitions=num_acl_modules,
+        disabled_ops=disabled_ops,
+    )
     acl_modules = extract_acl_modules(module)

     assert len(acl_modules) == num_acl_modules, (
diff --git a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py
new file mode 100644
index 0000000000000..deba26a0db560
--- /dev/null
+++ b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Arm Compute Library integration concatenate tests."""
+
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm import testing
+
+from test_arm_compute_lib.infrastructure import (
+    skip_runtime_test,
+    skip_codegen_test,
+    build_and_run,
+    verify,
+    verify_codegen,
+)
+from test_arm_compute_lib.infrastructure import Device
+
+
+def _get_model(input_shape_a, input_shape_b, input_shape_c, axis, dtype, var_names):
+    """Return a model and any parameters it may have."""
+    a = relay.var(next(var_names), shape=input_shape_a, dtype=dtype)
+    b = relay.var(next(var_names), shape=input_shape_b, dtype=dtype)
+    c = relay.var(next(var_names), shape=input_shape_c, dtype=dtype)
+    out = relay.concatenate([a, b, c], axis)
+    return out
+
+
+def _get_expected_codegen(input_shape_a, input_shape_b, input_shape_c, axis, dtype):
+    node = {
+        "op": "kernel",
+        "name": "concatenate",
+        "inputs": [
+            [0, 0, 0],
+            [1, 0, 0],
+            [2, 0, 0],
+        ],
+        "attrs": {
+            "num_outputs": "1",
+            "num_inputs": "3",
+            "dtype": [[dtype]],
+            "axis": [[str(axis)]],
+            "shape": [[[6, 234, 234, 256]]],
+        },
+    }
+
+    input_a = {
+        "op": "input",
+        "name": "",
+        "attrs": {
+            "shape": [[input_shape_a]],
+            "dtype": [[dtype]],
+        },
+    }
+
+    input_b = {
+        "op": "input",
+        "name": "",
+        "attrs": {
+            "shape": [[input_shape_b]],
+            "dtype": [[dtype]],
+        },
+    }
+
+    input_c = {
+        "op": "input",
+        "name": "",
+        "attrs": {
+            "shape": [[input_shape_c]],
+            "dtype": [[dtype]],
+        },
+    }
+    return [input_a, input_b, input_c, node]
+
+
+def test_concatenate():
+    Device.load("test_config.json")
+
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    np.random.seed(0)
+
+    for input_shape_a, input_shape_b, input_shape_c, axis, dtype in [
+        ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], 0, "float32"),
+        ([1, 1, 234, 256], [1, 2, 234, 256], [1, 3, 234, 256], 1, "float32"),
+        ([1, 234, 234, 1], [1, 234, 234, 2], [1, 234, 234, 3], -1, "float32"),
+        ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], -4, "float32"),
+        ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], 0, "uint8"),
+        ([1, 1, 234, 256], [1, 2, 234, 256], [1, 3, 234, 256], 1, "uint8"),
+        ([1, 234, 234, 1], [1, 234, 234, 2], [1, 234, 234, 3], -1, "uint8"),
+        ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], -4, "uint8"),
+    ]:
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.randn(*input_shape_a).astype(dtype)),
+            "b": tvm.nd.array(np.random.randn(*input_shape_b).astype(dtype)),
+            "c": tvm.nd.array(np.random.randn(*input_shape_c).astype(dtype)),
+        }
+        func = _get_model(
+            inputs["a"].shape, inputs["b"].shape, inputs["c"].shape, axis, dtype, iter(inputs)
+        )
+        for acl in [False, True]:
+            outputs.append(
+                build_and_run(func, inputs, 1, None, device, enable_acl=acl, disabled_ops=[])[0]
+            )
+
+        config = {
+            "input_shape_a": input_shape_a,
+            "input_shape_b": input_shape_b,
+            "input_shape_c": input_shape_c,
+            "axis": axis,
+            "dtype": dtype,
+        }
+        verify(outputs, atol=1e-7, rtol=1e-7, config=config)
+
+
+def test_codegen_concatenate():
+    if skip_codegen_test():
+        return
+    shape_a = [1, 234, 234, 256]
+    shape_b = [2, 234, 234, 256]
+    shape_c = [3, 234, 234, 256]
+    axis = 0
+    inputs = {"a", "b", "c"}
+    for dtype in ["float32"]:
+        args = (shape_a, shape_b, shape_c, axis, dtype)
+        func = _get_model(*args, iter(inputs))
+        exp_codegen = _get_expected_codegen(*args)
+        verify_codegen(func, exp_codegen, 1, disabled_ops=[])
+
+
+if __name__ == "__main__":
+    test_concatenate()
+    test_codegen_concatenate()
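As a quick sanity reference for the fixed `"shape": [[[6, 234, 234, 256]]]` in `_get_expected_codegen`, concatenating the three axis-0 test inputs with NumPy gives the same output shape (a standalone sketch, not part of the patch):

```python
import numpy as np

a = np.zeros((1, 234, 234, 256), dtype="float32")
b = np.zeros((2, 234, 234, 256), dtype="float32")
c = np.zeros((3, 234, 234, 256), dtype="float32")

# Batch sizes 1 + 2 + 3 combine to 6 along axis 0, matching the
# expected codegen shape above.
out = np.concatenate([a, b, c], axis=0)
assert out.shape == (6, 234, 234, 256)
```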