From f7c3201968b0c4bb6d9868b5d2c9d16911959d58 Mon Sep 17 00:00:00 2001
From: Wang Yao
Date: Wed, 23 Oct 2019 15:59:50 -0700
Subject: [PATCH] Split adaptive_pool2d_avg into sum and div

---
 topi/include/topi/nn/pooling.h  | 18 +++++++++++++++---
 topi/python/topi/x86/pooling.py |  5 +++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h
index 289452e26869..ca35e6e43498 100644
--- a/topi/include/topi/nn/pooling.h
+++ b/topi/include/topi/nn/pooling.h
@@ -492,7 +492,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
       return tvm::max(x(indices), { dheight, dwidth });  // NOLINT(*)
     }, "tensor", "adaptive_pool_max");
   } else if (pool_type == kAvgPool) {
-    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+    auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
       Array<Expr> indices;
       for (const Var& var : output) indices.push_back(var);
       auto i_start_h = start_index(output[height_axis], out_height, height);
@@ -505,8 +505,20 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
       auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
       indices.Set(height_axis, i_start_h + dheight);
       indices.Set(width_axis, i_start_w + dwidth);
-      return tvm::sum(div(x(indices), divide_factor), { dheight, dwidth });
-    }, "tensor", "adaptive_pool_avg");
+      return tvm::sum(x(indices), { dheight, dwidth });
+    }, "tensor", "adaptive_pool_sum");
+
+    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+      Array<Expr> indices;
+      for (const Var& var : output) indices.push_back(var);
+      auto i_start_h = start_index(output[height_axis], out_height, height);
+      auto i_end_h = end_index(output[height_axis], out_height, height);
+      auto i_start_w = start_index(output[width_axis], out_width, width);
+      auto i_end_w = end_index(output[width_axis], out_width, width);
+      auto divide_factor = tvm::cast(x->dtype, (i_end_h - i_start_h)
+                                     * (i_end_w - i_start_w));
+      return div(pool_sum(indices), divide_factor);
+    }, "tensor", kElementWise);
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
     return x;
diff --git a/topi/python/topi/x86/pooling.py b/topi/python/topi/x86/pooling.py
index ac19b19de28d..e9f832dde902 100644
--- a/topi/python/topi/x86/pooling.py
+++ b/topi/python/topi/x86/pooling.py
@@ -147,6 +147,11 @@ def traverse(OP):
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('adaptive_pool'):
+            if OP != outs[0].op:
+                output = outs[0]
+                output_fused = s[output].fuse(output.op.axis[0], output.op.axis[1])
+                s[output].parallel(output_fused)
+
             Pool = OP.output(0)
             _parallel_sch(s[Pool], outs[0].shape)
         else: