Skip to content

Commit

Permalink
Split adaptive_pool2d_avg into sum and div
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun committed Oct 23, 2019
1 parent 5408d3a commit 6bb4e25
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
1 change: 1 addition & 0 deletions python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def round(x):
Parameters
----------
x : Expr
x : Expr
Input argument.
Expand Down
18 changes: 15 additions & 3 deletions topi/include/topi/nn/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
return tvm::max(x(indices), { dheight, dwidth }); // NOLINT(*)
}, "tensor", "adaptive_pool_max");
} else if (pool_type == kAvgPool) {
return tvm::compute(out_shape, [&](const Array<Var>& output) {
auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
Expand All @@ -505,8 +505,20 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
indices.Set(height_axis, i_start_h + dheight);
indices.Set(width_axis, i_start_w + dwidth);
return tvm::sum(div(x(indices), divide_factor), { dheight, dwidth });
}, "tensor", "adaptive_pool_avg");
return tvm::sum(x(indices), { dheight, dwidth });
}, "tensor", "adaptive_pool_sum");

return tvm::compute(out_shape, [&](const Array<Var>& output) {
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
auto i_end_h = end_index(output[height_axis], out_height, height);
auto i_start_w = start_index(output[width_axis], out_width, width);
auto i_end_w = end_index(output[width_axis], out_width, width);
auto divide_factor = tvm::cast(x->dtype, (i_end_h - i_start_h)
* (i_end_w - i_start_w));
return div(pool_sum(indices), divide_factor);
}, "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
Expand Down
5 changes: 5 additions & 0 deletions topi/python/topi/x86/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def traverse(OP):
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('adaptive_pool'):
if OP != outs[0].op:
output = outs[0]
output_fused = s[output].fuse(output.op.axis[0], output.op.axis[1])
s[output].parallel(output_fused)

Pool = OP.output(0)
_parallel_sch(s[Pool], outs[0].shape)
else:
Expand Down

0 comments on commit 6bb4e25

Please sign in to comment.