diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py index bf3a16cdb23e..399a0508fd40 100644 --- a/nnvm/python/nnvm/frontend/darknet.py +++ b/nnvm/python/nnvm/frontend/darknet.py @@ -310,12 +310,19 @@ def _darknet_region(inputs, attrs): def _darknet_yolo(inputs, attrs): """Process the yolo operation.""" - op_name, new_attrs = 'yolov3_yolo', {} - if 'n' in attrs: - new_attrs['n'] = attrs.get('n', 1) - if 'classes' in attrs: - new_attrs['classes'] = attrs.get('classes', 1) - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + num = attrs.get('n', 1) + classes = attrs.get('classes', 1) + input_shape = attrs.get('shape') + split_size = classes + 5 + intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3]) + data_block = _sym.reshape(inputs[0], shape=intermediate_shape) + split_indices = (2, 4) + split_res = _sym.split(data_block, indices_or_sections=split_indices, axis=2) + split_res0 = _sym.sigmoid(split_res[0]) + split_res2 = _sym.sigmoid(split_res[2]) + concat_list = [split_res0, split_res[1], split_res2] + out = _sym.concatenate(*concat_list, axis=2) + return _sym.reshape(out, shape=input_shape), None def _darknet_activations(inputs, attrs): """Process the activation function.""" @@ -627,6 +634,7 @@ def _get_darknet_attrs(self, layer, layer_num): elif LAYERTYPE.YOLO == layer.type: attr.update({'n' : layer.n}) attr.update({'classes' : layer.classes}) + attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) elif LAYERTYPE.UPSAMPLE == layer.type: attr.update({'scale' : layer.stride}) diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py index e59b2bdfe6d9..f2e12c0f367a 100644 --- a/nnvm/python/nnvm/top/vision.py +++ b/nnvm/python/nnvm/top/vision.py @@ -38,21 +38,6 @@ def schedule_region(attrs, outs, target): reg.register_pattern("yolo_region", OpPattern.OPAQUE) -@reg.register_compute("yolov3_yolo") -def compute_yolo(attrs, inputs, _): - """Compute definition of yolo""" - n = attrs.get_int("n") - classes = attrs.get_int("classes") - return topi.vision.yolo.yolo(inputs[0], n, classes) - -@reg.register_schedule("yolov3_yolo") -def schedule_yolo(attrs, outs, target): - """Schedule definition of yolo""" - with tvm.target.create(target): - return topi.generic.schedule_injective(outs) - -reg.register_pattern("yolov3_yolo", OpPattern.OPAQUE) - # multibox_prior @reg.register_schedule("multibox_prior") def schedule_multibox_prior(_, outs, target): diff --git a/nnvm/src/top/vision/yolo/yolo.cc b/nnvm/src/top/vision/yolo/yolo.cc deleted file mode 100644 index 4800f4371f9d..000000000000 --- a/nnvm/src/top/vision/yolo/yolo.cc +++ /dev/null @@ -1,33 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file yolo.cc - * \brief Property def of yolo operators. - */ -#include -#include -#include -#include -#include "../../elemwise_op_common.h" - -namespace nnvm { -namespace top { - -NNVM_REGISTER_OP(yolov3_yolo) -.describe(R"code(Yolo layer -)code" NNVM_ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(1) -.set_support_level(5) -.add_argument("data", "Tensor", "Input data") -.set_attr("FInferShape", ElemwiseShape<1, 1>) -.set_attr( - "FInplaceOption", - [](const NodeAttrs &attrs) { - return std::vector>{{0, 0}, {1, 0}}; - }) -.set_attr("FGradient", [](const NodePtr &n, - const std::vector &ograds) { - return std::vector{ograds[0], ograds[0]}; -}); -} // namespace top -} // namespace nnvm diff --git a/topi/include/topi/vision/yolo/yolo.h b/topi/include/topi/vision/yolo/yolo.h deleted file mode 100644 index d2e24c01b253..000000000000 --- a/topi/include/topi/vision/yolo/yolo.h +++ /dev/null @@ -1,58 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \brief YOLO op constructions - * \file vision/yolo/yolo.h - */ -#ifndef TOPI_VISION_YOLO_YOLO_H_ -#define TOPI_VISION_YOLO_YOLO_H_ - -#include -#include - -#include "topi/detail/constant_utils.h" -#include "topi/tags.h" -#include "topi/transform.h" -#include "tvm/tvm.h" - - -namespace topi { -namespace vision { -namespace yolo { -using namespace tvm; -using namespace nn; - -/*! -* \brief yolo operation -* -* \param data The input tensor. -* \param num Darknet layer parameter n -* \param classes number of classes in the yolo model -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the yolo operation -*/ -inline Tensor yolo(const Tensor &data, - int num, - int classes, - std::string name = "tensor", - std::string tag = "yolo_output") { - auto input_shape = data->shape; - int split_size = classes + 5; - Array intermediate_shape = {input_shape[0], - num, - split_size, - input_shape[2], - input_shape[3]}; - auto data_block = reshape(data, intermediate_shape); - Array split_indices = {2, 4}; - Array split_res = split(data_block, split_indices, 2); - split_res.Set(0, sigmoid(split_res[0])); - split_res.Set(2, sigmoid(split_res[2])); - Tensor out = concatenate(split_res, 2); - return reshape(out, input_shape); -} -} // namespace yolo -} // namespace vision -} // namespace topi -#endif // TOPI_VISION_YOLO_YOLO_H_ diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index c9d995a38686..c91eea7958ea 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -15,7 +15,6 @@ from .bilinear_resize_python import bilinear_resize_python from .reorg_python import reorg_python from .region_python import region_python -from .yolo_python import yolo_python from .shortcut_python import shortcut_python from .lrn_python import lrn_python from .l2_normalize_python import l2_normalize_python diff --git a/topi/python/topi/testing/yolo_python.py b/topi/python/topi/testing/yolo_python.py deleted file mode 100644 index a6b3a41203c6..000000000000 --- a/topi/python/topi/testing/yolo_python.py +++ /dev/null @@ -1,43 +0,0 @@ -# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals -"""Yolo operator in python""" -import numpy as np - -def entry_index(batch, w, h, outputs, classes, coords, location, entry): - n = int(location/(w*h)) - loc = location%(w*h) - return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc - -def yolo_python(a_np, N, classes): - """Yolo operator - Parameters - ---------- - a_np : numpy.ndarray - 4-D with shape [batch, in_channel, in_height, in_width] - - N : int - Darknet layer parameter n - - classes : int - Darknet layer parameter classes - - Returns - ------- - b_np : np.ndarray - 4-D with shape [batch, out_channel, out_height, out_width] - """ - - batch, in_channel, in_height, in_width = a_np.shape - a_np_temp = np.reshape(a_np, batch*in_channel*in_height*in_width) - outputs = batch*in_channel*in_height*in_width - b_np = np.zeros(batch*in_channel*in_height*in_width) - for i in range(batch*in_channel*in_height*in_width): - b_np[i] = a_np_temp[i] - for b in range(batch): - for n in range(N): - index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 0) - b_np[index: index+2*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+2*in_width*in_height])) - index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 4) - b_np[index: index+(1+classes)*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+(1+classes)*in_width*in_height])) - - b_np = np.reshape(b_np, (batch, in_channel, in_height, in_width)) - return b_np diff --git a/topi/python/topi/vision/yolo/__init__.py b/topi/python/topi/vision/yolo/__init__.py index 2c0a165f8aac..c0e9899a41aa 100644 --- a/topi/python/topi/vision/yolo/__init__.py +++ b/topi/python/topi/vision/yolo/__init__.py @@ -3,4 +3,3 @@ from __future__ import absolute_import as _abs from .region import * -from .yolo import * diff --git a/topi/python/topi/vision/yolo/yolo.py b/topi/python/topi/vision/yolo/yolo.py deleted file mode 100644 index 6ae630a86d8f..000000000000 --- a/topi/python/topi/vision/yolo/yolo.py +++ /dev/null @@ -1,30 +0,0 @@ -# pylint: disable=invalid-name, unused-variable -""" -YOLO Operator -============= -YOLO operator, used in darknet. -""" -from __future__ import absolute_import as _abs -import tvm -from ... import cpp - -@tvm.target.generic_func -def yolo(data, num, classes): - """YOLO forward operators. - Parameters - ---------- - data : tvm.Tensor - 4-D with shape [batch, c_in, h_in, w_in] - - num : int - Darknet layer parameter n - - classes : int - Darknet layer parameter classes - - Returns - ------- - out : tvm.Tensor - 4-D with shape [batch, c_in, h_in, w_in] - """ - return cpp.yolo.yolo(data, num, classes) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index ae1ad57551cb..2d9f2fd6c6b2 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -32,7 +32,6 @@ #include #include #include -#include #include #include #include @@ -413,11 +412,6 @@ TVM_REGISTER_GLOBAL("topi.vision.yolo.region") *rv = vision::yolo::region(args[0], args[1], args[2], args[3], args[4], args[5]); }); -TVM_REGISTER_GLOBAL("topi.vision.yolo.yolo") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = vision::yolo::yolo(args[0], args[1], args[2]); - }); - /* Ops from image/resize.h */ TVM_REGISTER_GLOBAL("topi.image.resize") .set_body([](TVMArgs args, TVMRetValue *rv) { diff --git a/topi/tests/python_cpp/test_topi_yolo.py b/topi/tests/python_cpp/test_topi_yolo.py deleted file mode 100644 index 293de4fca087..000000000000 --- a/topi/tests/python_cpp/test_topi_yolo.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Test code for yolo op""" -import logging -import numpy as np -import tvm -import topi -import topi.testing -from topi.util import get_const_tuple - -def verify_yolo(ishape, n, classes): - '''Verify yolo operator by comparing outputs from tvm and numpy implementation''' - - A = tvm.placeholder(ishape, name='A') - B = topi.cpp.yolo.yolo(A, n, classes) - dtype = A.dtype - - def get_ref_data_yolo(): - '''Randomly initialize the data variables and get refernce output for the yolo operation''' - a_np = np.random.uniform(size=ishape).astype(dtype) - b_np = topi.testing.yolo_python(a_np, n, classes) - return a_np, b_np - - a_np, b_np = get_ref_data_yolo() - def check_device(device): - '''Check the device is available and if so, build and run the program''' - if not tvm.module.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - target = topi.cpp.TEST_create_target(device) - if device == "llvm": - s = topi.cpp.generic.default_schedule(target, [B], False) - else: - s = topi.cpp.cuda.schedule_injective(target, [B]) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - func = tvm.build(s, [A, B], device, name="yolo") - func(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm', 'vulkan']: - check_device(device) - -def test_yolo(): - verify_yolo((1, 425, 19, 19), 5, 80) - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - test_yolo()