Skip to content

Commit

Permalink
[BYOC][ACL] Fix list is not supported as an input node (apache#10801)
Browse files Browse the repository at this point in the history
* [BYOC][ACL] Fix list is not supported as an input node

* fix clang lint error

* fix compile warnning

* fix python module import error

* rename concatenate test file

* fix always MakeACLTensor with same eid 0

* do not offload concat default

* fix concattnate test failure

* fix test failure

* fix lint error

* fix lint

* remove global var offload_concat

* support concatenate with pattern table mechanism

* disable pylint dangerous-default-value warning

Co-authored-by: XuZhi <xuzhi.xu@alibaba-inc.com>
  • Loading branch information
2 people authored and Lucien0 committed Apr 19, 2022
1 parent 71b595d commit 69d0da4
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 30 deletions.
37 changes: 32 additions & 5 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
# pylint: disable=invalid-name, unused-argument, dangerous-default-value
"""Arm Compute Library supported operators."""
import tvm
from tvm import relay
Expand All @@ -23,7 +23,7 @@
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.expr import const

from ...dataflow_pattern import is_constant, is_expr, is_op, wildcard
from ...dataflow_pattern import is_constant, is_expr, is_op, is_tuple, wildcard
from ..strategy.generic import is_depthwise_conv2d
from .register import register_pattern_table

Expand All @@ -42,7 +42,7 @@ def is_arm_compute_runtime_enabled():
return False


def partition_for_arm_compute_lib(mod, params=None, **opts):
def partition_for_arm_compute_lib(mod, params=None, disabled_ops=["concatenate"], **opts):
"""Partition the graph greedily offloading supported
operators to Arm Compute Library.
Expand All @@ -52,6 +52,8 @@ def partition_for_arm_compute_lib(mod, params=None, **opts):
The module to run passes on.
params : Optional[Dict[str, NDArray]]
Constant input parameters.
disabled_ops : Optional[list]
Ops do not want to offload to ACL.
Returns
-------
Expand All @@ -63,7 +65,7 @@ def partition_for_arm_compute_lib(mod, params=None, **opts):
seq = tvm.transform.Sequential(
[
transform.InferType(),
transform.MergeComposite(arm_compute_lib_pattern_table()),
transform.MergeComposite(arm_compute_lib_pattern_table(disabled_ops)),
transform.AnnotateTarget("arm_compute_lib", False),
transform.PartitionGraph(),
]
Expand Down Expand Up @@ -128,7 +130,7 @@ def convert_conv(attrs, inputs, tinfos, desired_layouts):


@register_pattern_table("arm_compute_lib")
def arm_compute_lib_pattern_table():
def arm_compute_lib_pattern_table(disabled_ops=["concatenate"]):
"""Get the ACL pattern table."""

def conv_pattern():
Expand Down Expand Up @@ -220,6 +222,17 @@ def l2_pool2d_pattern():
pattern = is_op("sqrt")(pattern)
return pattern

def concatenate_pattern():
"""Create an concatenate pattern from equivalent relay operators.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the concatenate pattern.
"""
pattern = is_op("concatenate")(is_tuple(None))
return pattern

def check_conv(extract):
"""Check conv pattern is supported by ACL."""
call = extract
Expand Down Expand Up @@ -266,6 +279,19 @@ def check_l2_pool2d(extract):
pool = extract.args[0]
return avg_pool2d(pool)

def check_concatenate(expr):
"""Check concatenate pattern is supported by ACL."""
if "concatenate" in disabled_ops:
return False
attrs, type_args = expr.attrs, expr.type_args
for idx in range(len(type_args[0].fields)):
if type_args[0].fields[idx].dtype not in ["float32", "uint8"]:
return False
# ACL concatenate only supports maximum 4 dimensions input tensor
if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]:
return False
return True

return [
("arm_compute_lib.conv2d", conv_pattern(), check_conv),
("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv),
Expand All @@ -274,6 +300,7 @@ def check_l2_pool2d(extract):
("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv),
("arm_compute_lib.avg_pool2d", avg_pool2d_pattern(), check_avg_pool2d),
("arm_compute_lib.l2_pool2d", l2_pool2d_pattern(), check_l2_pool2d),
("arm_compute_lib.concatenate", concatenate_pattern(), check_concatenate),
]


Expand Down
26 changes: 26 additions & 0 deletions src/relay/backend/contrib/arm_compute_lib/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
json_node = CreateCompositeAvgPool2DJSONNode(cn);
} else if (name == "arm_compute_lib.l2_pool2d") {
json_node = CreateCompositeL2Pool2DJSONNode(cn);
} else if (name == "arm_compute_lib.concatenate") {
return AddCommonSingleJSONNode(cn, "concatenate");
} else {
LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name;
}
Expand Down Expand Up @@ -342,6 +344,30 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
SetCallNodeAttribute(json_node, avg_pool);
return json_node;
}

/*!
* \brief Create a JSON representation of a single operator.
* \param cn The call to be represented.
* \param name The name of the operator.
* \return A list of graph entry nodes.
*/
std::vector<JSONGraphNodeEntry> AddCommonSingleJSONNode(const CallNode* cn, std::string name) {
std::vector<JSONGraphNodeEntry> inputs;
for (const auto& arg : cn->args) {
auto res = VisitExpr(arg);
inputs.insert(inputs.end(), res.begin(), res.end());
}
auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
"kernel", /* op_type_ */
inputs, 1 /* num_outputs_ */);

const auto* fn = cn->op.as<FunctionNode>();
ICHECK(fn);
const auto* callNode = fn->body.as<CallNode>();
ICHECK(callNode);
SetCallNodeAttribute(node, callNode);
return AddNode(node, GetRef<Expr>(cn));
}
};

/*!
Expand Down
82 changes: 70 additions & 12 deletions src/runtime/contrib/arm_compute_lib/acl_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#ifdef TVM_GRAPH_EXECUTOR_ARM_COMPUTE_LIB
#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/NEON/functions/NEArithmeticAddition.h>
#include <arm_compute/runtime/NEON/functions/NEConcatenateLayer.h>
#include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h>
#include <arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h>
#include <arm_compute/runtime/NEON/functions/NEElementwiseOperations.h>
Expand Down Expand Up @@ -91,12 +92,21 @@ class ACLRuntime : public JSONRuntimeBase {
* \return Status of inference.
*/
void Run() override {
for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto nid = input_nodes_[i];
uint32_t eid = EntryID(nid, 0);
for (size_t nid_idx = 0; nid_idx < input_nodes_.size(); ++nid_idx) {
auto nid = input_nodes_[nid_idx];
if (nodes_[nid].GetOpType() == "input") {
void* data = data_entry_[eid]->data;
CheckACLError(layer_.inputs[i].allocator()->import_memory(data));
for (uint32_t eid_idx = 0; eid_idx < nodes_[nid].GetNumOutput(); eid_idx++) {
uint32_t eid = EntryID(nid, eid_idx);
void* data = data_entry_[eid]->data;
auto key = std::pair<uint32_t, uint32_t>(nid, eid_idx);
if (layer_.json_inputid_to_layer_inputid.count(key) > 0) {
CheckACLError(
layer_.inputs[layer_.json_inputid_to_layer_inputid[key]].allocator()->import_memory(
data));
} else {
CheckACLError(layer_.inputs[nid_idx].allocator()->import_memory(data));
}
}
}
}

Expand Down Expand Up @@ -149,6 +159,8 @@ class ACLRuntime : public JSONRuntimeBase {
CreateMaximumLayer(&layer_, node);
} else if ("add" == op_name || "qnn.add" == op_name) {
CreateAddLayer(&layer_, node);
} else if ("concatenate" == op_name) {
CreateConcatenateLayer(&layer_, node);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
Expand All @@ -166,6 +178,9 @@ class ACLRuntime : public JSONRuntimeBase {
std::shared_ptr<arm_compute::IFunction> function;
std::vector<arm_compute::Tensor> inputs;
std::vector<arm_compute::Tensor> outputs;
// maps the input index of JSON node to the index of the ACL layer's inputs
// this is optional (i.e.only when an operator uses the eid index)
std::map<std::pair<uint32_t, uint32_t>, uint32_t> json_inputid_to_layer_inputid;
};

/*!
Expand All @@ -175,17 +190,25 @@ class ACLRuntime : public JSONRuntimeBase {
* \param tensor The tensor to represent.
* \param scale (optional) The scale of the tensor as an input.
* \param offset (optional) The offset of the tensor as an input.
* \param apply_dim_correction (Optional) Flag to state whether apply dimension correction after
* setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but
* _num_dimensions should be 3 rather than 1.
* \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of
* dimensions of the shape.
* \return ACL Tensor.
*/
arm_compute::Tensor MakeACLTensorFromJSONEntry(const JSONGraphNodeEntry& tensor,
JSONGraphNodeEntry* scale = nullptr,
JSONGraphNodeEntry* offset = nullptr) {
JSONGraphNodeEntry* offset = nullptr,
bool apply_dim_correction = true,
bool increase_dim_unit = true) {
JSONGraphNode node = nodes_[tensor.id_];
void* node_data = nullptr;
if (node.GetOpType() == "const") {
node_data = data_entry_[EntryID(tensor)]->data;
}
return MakeACLTensorFromJSONNode(node, scale, offset, node_data);
return MakeACLTensorFromJSONNode(node, scale, offset, node_data, apply_dim_correction,
increase_dim_unit, tensor.index_);
}

/*!
Expand All @@ -196,19 +219,26 @@ class ACLRuntime : public JSONRuntimeBase {
* \param scale (optional) The scale of the tensor as an input.
* \param offset (optional) The offset of the tensor as an input.
* \param data (optional) Constant data of input node.
* \param apply_dim_correction (Optional) Flag to state whether apply dimension correction after
* setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but
* _num_dimensions should be 3 rather than 1.
* \param increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of
* dimensions of the shape.
* \param entry_index The entry index.
* \return ACL Tensor.
*/
arm_compute::Tensor MakeACLTensorFromJSONNode(const JSONGraphNode& node,
JSONGraphNodeEntry* scale = nullptr,
JSONGraphNodeEntry* offset = nullptr,
void* data = nullptr) {
arm_compute::Tensor MakeACLTensorFromJSONNode(
const JSONGraphNode& node, JSONGraphNodeEntry* scale = nullptr,
JSONGraphNodeEntry* offset = nullptr, void* data = nullptr, bool apply_dim_correction = true,
bool increase_dim_unit = true, uint32_t entry_index = 0) {
const DLTensor* scale_data = nullptr;
const DLTensor* offset_data = nullptr;
if (scale && offset) {
scale_data = data_entry_[EntryID(*scale)];
offset_data = data_entry_[EntryID(*offset)];
}
return MakeACLTensor(node, data, scale_data, offset_data);
return MakeACLTensor(node, data, scale_data, offset_data, apply_dim_correction,
increase_dim_unit, entry_index);
}

/*!
Expand Down Expand Up @@ -510,6 +540,34 @@ class ACLRuntime : public JSONRuntimeBase {
layer->function = f;
}

/*!
* \brief Create a Concatenate layer.
*
* \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.c
* \param node The JSON representation of the operator.
*/
void CreateConcatenateLayer(CachedLayer* layer, const JSONGraphNode& node) {
std::vector<std::string> axis = node.GetAttr<std::vector<std::string>>("axis");
std::vector<const arm_compute::ITensor*> inputs;
for (auto input : node.GetInputs()) {
layer->inputs.push_back(MakeACLTensorFromJSONEntry(input, nullptr, nullptr, false));
layer->json_inputid_to_layer_inputid[std::pair<uint32_t, uint32_t>(input.id_, input.index_)] =
layer->inputs.size() - 1;
}
for (size_t i = 0; i < layer->inputs.size(); i++) {
inputs.push_back(&layer->inputs[i]);
}
layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
int dimNum = layer->inputs[0].info()->num_dimensions();
auto function = std::make_shared<arm_compute::NEConcatenateLayer>();
// the shape of input tensor will be reversed after passing to ACL
// for example a tensor with shape [1, 2, 3, 4] will be changed to
// [4, 3, 2, 1] at ACL side. So the axis here should be preprocessed.
auto a = std::stoi(axis[0]);
function->configure(inputs, &layer->outputs[0], a < 0 ? -a - 1 : dimNum - a - 1);
layer->function = function;
}

/*! \brief Allow ACL functions to request auxiliary memory from TVM. */
ACLAllocator allocator_;
/*!
Expand Down
16 changes: 10 additions & 6 deletions src/runtime/contrib/arm_compute_lib/acl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ void CheckACLError(const arm_compute::Status& status) {
}

arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data,
const DLTensor* scale, const DLTensor* offset) {
const DLTensor* scale, const DLTensor* offset,
bool apply_dim_correction, bool increase_dim_unit,
uint32_t entry_index) {
arm_compute::Tensor tensor;
std::vector<int64_t> shape = tensor_rep.GetOpShape()[0];
DLDataType dtype = tensor_rep.GetOpDataType()[0];
arm_compute::TensorInfo info = MakeACLTensorInfo(shape, dtype, scale, offset);
std::vector<int64_t> shape = tensor_rep.GetOpShape()[entry_index];
DLDataType dtype = tensor_rep.GetOpDataType()[entry_index];
arm_compute::TensorInfo info =
MakeACLTensorInfo(shape, dtype, scale, offset, apply_dim_correction, increase_dim_unit);
info.set_is_resizable(false);
tensor.allocator()->init(info);
if (data != nullptr) {
Expand All @@ -55,10 +58,11 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data,

arm_compute::TensorInfo MakeACLTensorInfo(const std::vector<int64_t>& shape,
const DLDataType& dtype, const DLTensor* scale,
const DLTensor* offset) {
const DLTensor* offset, bool apply_dim_correction,
bool increase_dim_unit) {
arm_compute::TensorShape acl_shape;
for (unsigned int i = shape.size(); i > 0; --i) {
acl_shape.set(shape.size() - i, shape[i - 1]);
acl_shape.set(shape.size() - i, shape[i - 1], apply_dim_correction, increase_dim_unit);
}
arm_compute::DataType acl_dtype = MakeACLDataType(dtype);
arm_compute::TensorInfo info(acl_shape, 1, acl_dtype, arm_compute::DataLayout::NHWC);
Expand Down
9 changes: 6 additions & 3 deletions src/runtime/contrib/arm_compute_lib/acl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ void CheckACLError(const arm_compute::Status& status);
* \return arm_compute::Tensor.
*/
arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data = nullptr,
const DLTensor* scale = nullptr,
const DLTensor* offset = nullptr);
const DLTensor* scale = nullptr, const DLTensor* offset = nullptr,
bool apply_dim_correction = true, bool increase_dim_unit = true,
uint32_t entry_index = 0);

/*!
* \brief Make an acl tensor info object from JSON tensor
Expand All @@ -78,7 +79,9 @@ arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data =
*/
arm_compute::TensorInfo MakeACLTensorInfo(const std::vector<int64_t>& shape,
const DLDataType& dtype, const DLTensor* scale = nullptr,
const DLTensor* offset = nullptr);
const DLTensor* offset = nullptr,
bool apply_dim_correction = true,
bool increase_dim_unit = true);

/*!
* \brief Create a memory manager for use with a layer that
Expand Down
1 change: 1 addition & 0 deletions src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class JSONRuntimeBase : public ModuleNode {
for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
input_var_eid_.push_back(EntryID(nid, j));
}
nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size());
} else {
ICHECK_EQ(nodes_[nid].op_type_, "const");
auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
Expand Down
Loading

0 comments on commit 69d0da4

Please sign in to comment.