Skip to content

Commit

Permalink
[RELAY]Vision ops for yolo
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Oct 20, 2018
1 parent 4300bbc commit 89d965d
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ This level enables additional math and transform operators.
:nosignatures:

tvm.relay.image.resize
tvm.relay.vision.yolo_regorg
tvm.relay.vision.yolo_region
tvm.relay.vision.yolov3_yolo


Level 1 Definitions
Expand Down Expand Up @@ -192,3 +195,6 @@ Level 4 Definitions
Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize
autofunction:: tvm.relay.vision.yolo_regorg
autofunction:: tvm.relay.vision.yolo_region
autofunction:: tvm.relay.vision.yolov3_yolo
11 changes: 11 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
}
};

/*! \brief Attributes used in yolo reorg operators */
struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
IndexExpr stride;

TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") {
TVM_ATTR_FIELD(stride)
.set_default(1)
.describe("Stride value for yolo reorg");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_
1 change: 1 addition & 0 deletions python/tvm/relay/op/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from __future__ import absolute_import as _abs

from .multibox import *
from .yolo import *
53 changes: 53 additions & 0 deletions python/tvm/relay/op/vision/yolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Yolo operations."""
from . import _make

def yolo_reorg(data, stride=1):
"""Yolo reorg operation. This layer reorganize the output based on the stride value.
Its function is mostly shape transform.
Parameters
----------
data : relay.Expr
The input data tensor.
stride : int
The stride value for reorganisation.
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.yolo_reorg(data, stride)


def yolo_region(data):
"""Yolo region operation used for detection.
Parameters
----------
data : relay.Expr
The input data tensor.
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.yolo_region(data)


def yolov3_yolo(data):
"""Yolo operation used for detection
Parameters
----------
data : relay.Expr
The input data tensor.
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.yolov3_yolo(data)
111 changes: 111 additions & 0 deletions src/relay/op/vision/yolo.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*!
* Copyright (c) 2018 by Contributors
* \file yolo.cc
* \brief Yolo related operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
#include <vector>
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(YoloReorgAttrs);

/*!
* \brief YoloReorgRel Output type and shape relation evaluation function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved. true if this relation has been resolved.
*/
bool YoloReorgRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

const YoloReorgAttrs* param = attrs.as<YoloReorgAttrs>();
CHECK(param != nullptr);

CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension.";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[1] = oshape[1] * param->stride * param->stride;
oshape[2] = oshape[2] / param->stride;
oshape[3] = oshape[3] / param->stride;
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}

Expr MakeYoloReorg(Expr data,
IndexExpr stride) {
auto attrs = make_node<YoloReorgAttrs>();
attrs->stride = stride;
static const Op& op = Op::Get("vision.yolo_reorg");
return CallNode::make(op, {data}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.vision._make.yolo_reorg")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeYoloReorg, args, rv);
});


RELAY_REGISTER_OP("vision.yolo_reorg")
.describe(R"doc("Yolo reorg operation. This layer reorganize the output.
Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
.add_argument("data", "Tensor", "The input tensor.")
.set_num_inputs(1)
.set_support_level(5)
.set_attrs_type_key("relay.attrs.YoloReorgAttrs")
.add_type_rel("YoloReorg", YoloReorgRel);


Expr MakeYoloRegion(Expr data) {
static const Op& op = Op::Get("vision.yolo_region");
return CallNode::make(op, {data}, Attrs(), {});
}


TVM_REGISTER_API("relay.op.vision._make.yolo_region")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 1>(MakeYoloRegion, args, rv);
});


RELAY_REGISTER_OP("vision.yolo_region")
.describe(R"doc("Yolo region operation used for detection."
)doc" TVM_ADD_FILELINE)
.add_argument("data", "Tensor", "The input tensor.")
.set_num_inputs(1)
.set_support_level(5)
.add_type_rel("Identity", IdentityRel);


Expr MakeYolov3Yolo(Expr data) {
static const Op& op = Op::Get("vision.yolov3_yolo");
return CallNode::make(op, {data}, Attrs(), {});
}


TVM_REGISTER_API("relay.op.vision._make.yolov3_yolo")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 1>(MakeYolov3Yolo, args, rv);
});


RELAY_REGISTER_OP("vision.yolov3_yolo")
.describe(R"doc("Yolov3 operation used for detection."
)doc" TVM_ADD_FILELINE)
.add_argument("data", "Tensor", "The input tensor.")
.set_num_inputs(1)
.set_support_level(5)
.add_type_rel("Identity", IdentityRel);

} // namespace relay
} // namespace tvm
49 changes: 49 additions & 0 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,55 @@ def test_multibox_prior():
(1, h * w, 4), "float32")


def test_yolo_reorg():
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.yolo_reorg(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")

ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))

with ib.function(x) as func:
ib.ret(relay.vision.yolo_reorg(x, stride=2))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c*2*2, h/2, w/2), "float32")


def test_yolo_region():
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.yolo_region(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")


def test_yolov3_yolo():
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.yolov3_yolo(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")


if __name__ == "__main__":
test_resize_infer_type()
test_multibox_prior()
test_yolo_reorg()
test_yolo_region()
test_yolov3_yolo()

0 comments on commit 89d965d

Please sign in to comment.