[Relay][Dyn] Dynamic TopK Op (#6008)
* add dynamic topk op

* add topk to dynamic_to_static pass

* fix TF test

* fix pylint
Matthew Brookhart authored Jul 10, 2020
1 parent ba04c6a commit 474d472
Showing 11 changed files with 353 additions and 59 deletions.
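In short, `relay.topk` now dispatches on the type of `k`: a Python int (or a `relay.Constant`) stays on the static `topk` op, while any other `relay.Expr` routes to the new `dyn.topk` op, whose output shape is computed at runtime. A minimal usage sketch (illustrative, not part of the commit; variable names are hypothetical):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(4, 8), dtype="float32")

# Static k: stays on the existing `topk` op, with k stored as an attribute.
values, indices = relay.topk(x, k=2)

# Dynamic k: any non-constant relay.Expr routes to `dyn.topk`, which takes
# k as a second runtime input.
k = relay.var("k", shape=(1,), dtype="int64")
dyn_values, dyn_indices = relay.topk(x, k=k)
```

Per the commit message, the DynamicToStatic pass can later rewrite `dyn.topk` back to the static op once `k` becomes constant.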
35 changes: 5 additions & 30 deletions python/tvm/relay/op/_algorithm.py
@@ -34,25 +34,6 @@
register_strategy("topk", strategy.topk_strategy)
register_pattern("topk", OpPattern.OPAQUE)

-@script
-def _topk_shape_func_input_data(data, k, axis):
-    ndim = len(data.shape)
-    val_out = output_tensor((ndim,), "int64")
-    indices_out = output_tensor((ndim,), "int64")
-
-    for i in const_range(ndim):
-        if i != axis:
-            val_out[i] = int64(data.shape[i])
-            indices_out[i] = int64(data.shape[i])
-        else:
-            if k[0] < 1:
-                val_out[i] = int64(data.shape[i])
-                indices_out[i] = int64(data.shape[i])
-            else:
-                val_out[i] = int64(k[0])
-                indices_out[i] = int64(k[0])
-    return val_out, indices_out

@script
def _topk_shape_func_input_shape(data_shape, k, axis):
    ndim = data_shape.shape[0]
@@ -72,22 +53,16 @@ def _topk_shape_func_input_shape(data_shape, k, axis):
                indices_out[i] = int64(k)
    return val_out, indices_out

-@_reg.register_shape_func("topk", True)
+@_reg.register_shape_func("topk", False)
def topk_shape_func(attrs, inputs, _):
    """
    Shape func for topk.
    """
    axis = attrs.axis
-    if attrs.k is not None:
-        if axis < 0:
-            axis += inputs[0].shape[0]
-        val_out, indices_out = \
-            _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
-    else:
-        if axis < 0:
-            axis += len(inputs[0].shape)
-        val_out, indices_out = \
-            _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+    if axis < 0:
+        axis += inputs[0].shape[0]
+    val_out, indices_out = \
+        _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
    ret_type = attrs.ret_type
    if ret_type == "both":
        ret = [val_out, indices_out]
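For intuition, a plain-Python analogue of `_topk_shape_func_input_shape` (a hypothetical helper for illustration, not the hybrid-script code above): the output shape matches the input except along `axis`, whose extent becomes `k`; `k < 1` means "keep all elements".

```python
def topk_out_shape(data_shape, k, axis):
    # Hypothetical analogue of the hybrid-script shape func.
    out = list(data_shape)
    if k >= 1:
        out[axis] = k  # k < 1 leaves the axis untouched ("return all")
    return out

assert topk_out_shape([4, 8], k=2, axis=1) == [4, 2]
assert topk_out_shape([4, 8], k=0, axis=1) == [4, 8]
```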
13 changes: 9 additions & 4 deletions python/tvm/relay/op/algorithm.py
@@ -16,8 +16,10 @@
# under the License.
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
+import numpy as np
from . import _make
-from ..expr import TupleWrapper, const
+from .dyn import _make as _dyn_make
+from ..expr import TupleWrapper, Expr, Constant

def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
    """Performs sorting along the given axis and returns an array of indices
@@ -82,9 +84,12 @@ def topk(data, k=1, axis=-1, ret_type="both",
    out : relay.Expr or List[relay.Expr]
        The computed result.
    """
-    if isinstance(k, int):
-        k = const(k, "int64")
-    out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
+    if isinstance(k, Constant):
+        k = np.asscalar(k.data.asnumpy())
+    if isinstance(k, Expr):
+        out = _dyn_make.topk(data, k, axis, ret_type, is_ascend, dtype)
+    else:
+        out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
    if ret_type == "both":
        return TupleWrapper(out, 2)
    return out
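A quick check of the dispatch above (a sketch assuming this module's public API): a `relay.Constant` `k` is unwrapped to a Python scalar first, so constant `k` still reaches the static op and keeps `k` as a compile-time attribute.

```python
from tvm import relay

x = relay.var("x", shape=(4, 8), dtype="float32")

# Constant k is unwrapped via k.data.asnumpy(), so this builds a static
# topk call with attrs.k == 3 rather than a dyn.topk call.
out = relay.topk(x, k=relay.const(3, "int64"), ret_type="values")
```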
1 change: 1 addition & 0 deletions python/tvm/relay/op/dyn/__init__.py
@@ -17,4 +17,5 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay namespace containing dynamic ops."""

+from . import _algorithm
from . import _transform
71 changes: 71 additions & 0 deletions python/tvm/relay/op/dyn/_algorithm.py
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"Definition of classic algorithms"
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import

from tvm.te.hybrid import script
from tvm.runtime import convert

from .. import strategy
from .. import op as _reg
from ..op import OpPattern, register_pattern
from ..op import register_strategy

# topk
register_strategy("dyn.topk", strategy.topk_strategy)
register_pattern("dyn.topk", OpPattern.OPAQUE)

@script
def _topk_shape_func_input_data(data, k, axis):
    ndim = len(data.shape)
    val_out = output_tensor((ndim,), "int64")
    indices_out = output_tensor((ndim,), "int64")

    for i in const_range(ndim):
        if i != axis:
            val_out[i] = int64(data.shape[i])
            indices_out[i] = int64(data.shape[i])
        else:
            if k[0] < 1:
                val_out[i] = int64(data.shape[i])
                indices_out[i] = int64(data.shape[i])
            else:
                val_out[i] = int64(k[0])
                indices_out[i] = int64(k[0])
    return val_out, indices_out

@_reg.register_shape_func("dyn.topk", True)
def topk_shape_func(attrs, inputs, _):
    """
    Shape func for topk.
    """
    axis = attrs.axis
    if axis < 0:
        axis += len(inputs[0].shape)
    val_out, indices_out = \
        _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))

    ret_type = attrs.ret_type
    if ret_type == "both":
        ret = [val_out, indices_out]
    elif ret_type == "values":
        ret = [val_out]
    else:
        ret = [indices_out]

    return ret
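The `True` flag in `register_shape_func("dyn.topk", True)` marks the shape function as data dependent: at runtime it receives the actual `k` tensor as `inputs[1]`, not just its shape. A plain-Python analogue (hypothetical, for illustration only):

```python
import numpy as np

def dyn_topk_out_shape(data, k_tensor, axis):
    # k arrives as a one-element tensor at runtime, not as an attribute.
    k = int(k_tensor[0])
    out = list(data.shape)
    if k >= 1:
        out[axis] = k
    return out

data = np.zeros((4, 8), dtype="float32")
assert dyn_topk_out_shape(data, np.array([3], dtype="int64"), axis=1) == [4, 3]
```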
3 changes: 2 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
@@ -656,9 +656,10 @@ def argsort_strategy(attrs, inputs, out_type, target):
def wrap_compute_topk(topi_compute):
    """Wrap topk compute"""
    def _compute_topk(attrs, inputs, out_type):
-        k = inputs[1]
+        if attrs.k is not None:
+            k = attrs.k
+        else:
+            k = inputs[1]
        axis = get_const_int(attrs.axis)
        ret_type = attrs.ret_type
        is_ascend = bool(get_const_int(attrs.is_ascend))
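The shared compute wrapper now serves both registrations: static `topk` carries `k` in its attributes, while `dyn.topk` leaves `attrs.k` unset and passes `k` as the second input. Schematically (a hypothetical standalone helper):

```python
def pick_k(attrs_k, inputs):
    # attrs_k is set for static topk; dyn.topk supplies k as inputs[1].
    return attrs_k if attrs_k is not None else inputs[1]
```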
9 changes: 1 addition & 8 deletions src/relay/analysis/util.cc
@@ -448,14 +448,7 @@ bool IsDataDependant(const CallNode* call) {
    return false;
  }

-  if (op->name == "topk") {
-    if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
-      if (attrs->k) {
-        // If k attribute exists, it isn't data dependant.
-        return false;
-      }
-    }
-  } else if (op->name == "strided_slice") {
+  if (op->name == "strided_slice") {
    if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
      if (attrs->begin && attrs->end && attrs->strides) {
        // not data dependant if begin, end and strides exist
24 changes: 9 additions & 15 deletions src/relay/op/algorithm/topk.cc
@@ -27,15 +27,14 @@

namespace tvm {
namespace relay {
-using tir::make_const;

TVM_REGISTER_NODE_TYPE(TopKAttrs);

bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
             const TypeReporter& reporter) {
  // `types` contains: [data, result]
  const TopKAttrs* param = attrs.as<TopKAttrs>();
-  CHECK_EQ(types.size(), 3);
+  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  CHECK(data);
  int ndim = data->shape.size();
@@ -48,53 +47,48 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
  for (int i = 0; i < ndim; ++i) {
    if (i != axis) {
      out_shape.push_back(data->shape[i]);
-    } else if (param->k) {
+    } else {
      const Integer& ck = param->k.value();
      if (ck->value < 1) {
        out_shape.push_back(data->shape[i]);
      } else {
        out_shape.push_back(ck);
      }
-    } else {
-      out_shape.push_back(Any());
    }
  }
  auto values_ty = TensorType(out_shape, data->dtype);
  auto indices_ty = TensorType(out_shape, param->dtype);
  if (param->ret_type == "both") {
-    reporter->Assign(types[2], TupleType({values_ty, indices_ty}));
+    reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
  } else if (param->ret_type == "values") {
-    reporter->Assign(types[2], values_ty);
+    reporter->Assign(types[1], values_ty);
  } else if (param->ret_type == "indices") {
-    reporter->Assign(types[2], indices_ty);
+    reporter->Assign(types[1], indices_ty);
  } else {
    LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
  }
  return true;
}

-Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) {
+Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) {
  auto attrs = make_object<TopKAttrs>();
-  if (const auto& ck = k.as<ConstantNode>()) {
-    attrs->k = tvm::Integer(reinterpret_cast<int*>(ck->data->data)[0]);
-  }
+  attrs->k = Integer(k);
  attrs->axis = axis;
  attrs->ret_type = ret_type;
  attrs->is_ascend = is_ascend;
  attrs->dtype = dtype;
  static const Op& op = Op::Get("topk");
-  return Call(op, {data, k}, Attrs(attrs), {});
+  return Call(op, {data}, Attrs(attrs), {});
}
}

TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);

RELAY_REGISTER_OP("topk")
.describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_num_inputs(1)
.set_attrs_type<TopKAttrs>()
.add_argument("data", "Tensor", "Input data.")
.add_argument("k", "Tensor", "Number of top elements.")
.set_support_level(6)
.add_type_rel("TopK", TopKRel);

107 changes: 107 additions & 0 deletions src/relay/op/dyn/algorithm/topk.cc
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file topk.cc
* \brief TopK operators
*/
#include <tvm/relay/attrs/algorithm.h>
#include <tvm/relay/op.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace relay {
namespace dyn {

bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
             const TypeReporter& reporter) {
  // `types` contains: [data, k, result]
  const TopKAttrs* param = attrs.as<TopKAttrs>();
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* k = types[1].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "topk: expect input type to be TensorType but get " << types[0];
    return false;
  }
  if (k == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "topk: expect input type to be TensorType but get " << types[1];
    return false;
  }
  CHECK(k->shape.size() <= 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )";
  if (k->shape.size() == 1) {
    const IntImmNode* k_shape = k->shape[0].as<IntImmNode>();
    CHECK(k_shape) << "Parameter k must have static shape";
    CHECK_EQ(k_shape->value, 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )";
  }
  int ndim = data->shape.size();
  int axis = param->axis;
  if (axis < 0) {
    axis += ndim;
  }
  CHECK(axis >= 0 && axis < ndim);
  Array<IndexExpr> out_shape;
  for (int i = 0; i < ndim; ++i) {
    if (i != axis) {
      out_shape.push_back(data->shape[i]);
    } else {
      out_shape.push_back(Any());
    }
  }
  auto values_ty = TensorType(out_shape, data->dtype);
  auto indices_ty = TensorType(out_shape, param->dtype);
  if (param->ret_type == "both") {
    reporter->Assign(types[2], TupleType({values_ty, indices_ty}));
  } else if (param->ret_type == "values") {
    reporter->Assign(types[2], values_ty);
  } else if (param->ret_type == "indices") {
    reporter->Assign(types[2], indices_ty);
  } else {
    LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
  }
  return true;
}

Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) {
  auto attrs = make_object<TopKAttrs>();
  attrs->axis = axis;
  attrs->ret_type = ret_type;
  attrs->is_ascend = is_ascend;
  attrs->dtype = dtype;
  static const Op& op = Op::Get("dyn.topk");
  return Call(op, {data, k}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.topk").set_body_typed(MakeTopK);

RELAY_REGISTER_OP("dyn.topk")
.describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_attrs_type<TopKAttrs>()
.add_argument("data", "Tensor", "Input data.")
.add_argument("k", "Tensor", "Number of top elements.")
.set_support_level(6)
.add_type_rel("DynTopK", TopKRel);

} // namespace dyn
} // namespace relay
} // namespace tvm
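An end-to-end sketch of the new op (assuming the Relay Python API of this era; the VM executor is used since it supports dynamic output shapes):

```python
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 5), dtype="float32")
k = relay.var("k", shape=(1,), dtype="int64")
out = relay.topk(x, k=k, ret_type="values")  # builds a dyn.topk call
mod = tvm.IRModule.from_expr(relay.Function([x, k], out))

ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
vals = ex.evaluate()(np.random.rand(2, 5).astype("float32"),
                     np.array([3], dtype="int64"))
print(vals.asnumpy().shape)  # (2, 3)
```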
