diff --git a/topi/include/topi/nn/batch_norm.h b/topi/include/topi/nn/batch_norm.h deleted file mode 100644 index be3e31d216d0..000000000000 --- a/topi/include/topi/nn/batch_norm.h +++ /dev/null @@ -1,65 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \brief Batch normalization op constructions - * \file nn/batch_norm.h - */ -#ifndef TOPI_NN_BATCH_NORM_H_ -#define TOPI_NN_BATCH_NORM_H_ - -#include - -#include "topi/tags.h" -#include "tvm/tvm.h" - -namespace topi { -namespace nn { -using namespace tvm; - -/*! -* \brief Batch normalization inference operator with NCHW layout -* -* \param x The input tensor. 4-D with shape [batch, channel, height, width] -* \param gamma 1-D with shape [channel] -* \param beta 1-D with shape [channel] -* \param moving_mean 1-D with shape [channel] -* \param moving_var 1-D with shape [channel] -* \param eps Epsilon to prevent div by 0 -* \param fix_gamma Fix gamma while training -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the batch normalization operation -*/ -inline Tensor batch_norm_inference(const Tensor& x, - const Tensor& gamma, - const Tensor& beta, - const Tensor& moving_mean, - const Tensor& moving_var, - float eps, - bool fix_gamma, - std::string name = "tensor", - std::string tag = kBroadcast) { - CHECK_EQ(x->shape.size(), 4) << "Batch norm requires 4-D input"; - - Tensor out; - if (fix_gamma) { - out = tvm::compute( - x->shape, - [&](const Array& indices) { - auto c = Array({ indices[1] }); - return (x(indices) - moving_mean(c)) / tvm::sqrt(moving_var(c) + eps) + beta(c); - }, name, tag); - } else { - out = tvm::compute( - x->shape, - [&](const Array& indices) { - auto c = Array({ indices[1] }); - return (x(indices) - moving_mean(c)) / tvm::sqrt(moving_var(c) + eps) * gamma(c) + beta(c); - }, name, tag); - } - return out; -} - -} // namespace nn -} // namespace topi -#endif // TOPI_NN_BATCH_NORM_H_ diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index 690379135e06..cfb9e566279a 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -2,7 +2,6 @@ """Neural network operators""" from __future__ import absolute_import as _abs -from .batch_norm import * from .conv2d import * from .depthwise_conv2d import * from .elemwise import * diff --git a/topi/python/topi/nn/batch_norm.py b/topi/python/topi/nn/batch_norm.py deleted file mode 100644 index 551a6280b312..000000000000 --- a/topi/python/topi/nn/batch_norm.py +++ /dev/null @@ -1,56 +0,0 @@ -"""TVM operator batch normalization compute.""" -from __future__ import absolute_import -import tvm -from .. import tag - -@tvm.tag_scope(tag=tag.BROADCAST) -def batch_norm_inference(data, gamma, beta, moving_mean, moving_var, eps, fix_gamma): - """Batch normalization inference operator in NCHW layout. - - Parameters - ---------- - data : tvm.Tensor - 4-D with shape [batch, channel, height, width] - - gamma : tvm.Tensor - 1-D with shape [channel] - - beta : tvm.Tensor - 1-D with shape [channel] - - moving_mean : tvm.Tensor - 1-D with shape [channel] - - moving_var : tvm.Tensor - 1-D with shape [channel] - - eps : float - Epsilon to prevent div 0. - - fix_gamma : boolean - Fix gamma while training - - Returns - ------- - output : tvm.Tensor - 4-D with shape [batch, channel, height, width] - - mean : tvm.Tensor - 1-D with shape [channel] - - var : tvm.Tensor - 1-D with shape [channel] - """ - assert len(data.shape) == 4, "only support 4-dim batch norm" - batch, channel, height, width = data.shape - if fix_gamma: - out = tvm.compute((batch, channel, height, width), \ - lambda b, c, h, w: (data[b, c, h, w] - moving_mean[c]) / \ - tvm.intrin.sqrt(moving_var[c] + eps) + beta[c]) - else: - out = tvm.compute((batch, channel, height, width), \ - lambda b, c, h, w: (data[b, c, h, w] - moving_mean[c]) / \ - tvm.intrin.sqrt(moving_var[c] + eps) * gamma[c] + beta[c]) - mean = tvm.compute((C, ), lambda c: moving_mean[c]) - var = tvm.compute((C, ), lambda c: moving_var[c]) - return [out, mean, var] diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 7adcb11c5656..d56174fda5c5 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -17,7 +17,6 @@ #include #include -#include #include #include #include @@ -328,18 +327,6 @@ TVM_REGISTER_GLOBAL("topi.nn.upsampling") *rv = nn::upsampling(args[0], args[1], args[2], args[3]); }); -/* Ops from nn/batch_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.batch_norm_inference") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::batch_norm_inference(args[0], - args[1], - args[2], - args[3], - args[4], - static_cast(args[5]), - args[6]); - }); - /* Ops from nn/bnn.h */ TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") .set_body([](TVMArgs args, TVMRetValue *rv) {