Skip to content

Commit

Permalink
compute and schedule updated for yolo reorg
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Nov 22, 2018
1 parent 5a81ae0 commit 4ed21d3
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 14 deletions.
5 changes: 4 additions & 1 deletion include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
}
};

/*! \brief Attributes used in yolo reorg operators */
struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
IndexExpr stride;
Integer stride;

TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") {
TVM_ATTR_FIELD(stride)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .multibox import *
from .nms import *
from .yolo import *
from ._yolo import *
9 changes: 9 additions & 0 deletions python/tvm/relay/op/vision/_yolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from ..op import register_schedule, register_pattern
from ..op import schedule_injective, OpPattern

# reorg
register_pattern("vision.yolo_reorg", OpPattern.INJECTIVE)
register_schedule("vision.yolo_reorg", schedule_injective)
13 changes: 11 additions & 2 deletions src/relay/op/vision/yolo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
#include <topi/vision/reorg.h>
#include <vector>
#include "../op_common.h"
#include "../type_relations.h"
Expand Down Expand Up @@ -42,7 +43,7 @@ bool YoloReorgRel(const Array<Type>& types,
}

Expr MakeYoloReorg(Expr data,
IndexExpr stride) {
Integer stride) {
auto attrs = make_node<YoloReorgAttrs>();
attrs->stride = stride;
static const Op& op = Op::Get("vision.yolo_reorg");
Expand All @@ -63,7 +64,15 @@ Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(5)
.set_attrs_type_key("relay.attrs.YoloReorgAttrs")
.add_type_rel("YoloReorg", YoloReorgRel);
.add_type_rel("YoloReorg", YoloReorgRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* params = attrs.as<YoloReorgAttrs>();
CHECK(params != nullptr);
return Array<Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
});

} // namespace relay
} // namespace tvm
48 changes: 37 additions & 11 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
""" Support level5 operator test cases.
"""
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list
import topi.testing

def test_resize_infer_type():
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
Expand Down Expand Up @@ -70,22 +73,45 @@ def test_nms():
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(
(n, num_anchors, 6), "float32")
def test_yolo_reorg():


def test_yolo_reorg_infer_shape():
def verify_yolo_reorg(shape, stride, out_shape):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.vision.yolo_reorg(x, stride=stride)
zz = relay.ir_pass.infer_type(z)
assert "stride=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")

n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, 20, 20), "float32"))
z = relay.vision.yolo_reorg(x, stride=10)
zz = relay.ir_pass.infer_type(z)
assert "stride=10" in z.astext()
assert zz.checked_type == relay.ty.TensorType((n, c*10*10, 2, 2), "float32")
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))

x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
z = relay.vision.yolo_reorg(x, stride=2)
assert "stride=2" in z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((n, c*2*2, h/2, w/2), "float32")
def test_yolo_reorg():
def verify_yolo_reorg(shape, stride):
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = topi.testing.reorg_python(x_data, stride)

x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.vision.yolo_reorg(x, stride=stride)
zz = relay.ir_pass.infer_type(z)
assert "stride=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")

func = relay.Function([x], z)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_yolo_reorg((1, 100, 20, 20), 10)
verify_yolo_reorg((1, 4, 6, 6), 2)

if __name__ == "__main__":
test_resize_infer_type()
test_multibox_prior()
test_nms()
test_yolo_reorg_infer_shape()
test_yolo_reorg()

0 comments on commit 4ed21d3

Please sign in to comment.