Skip to content

Commit

Permalink
Add schedule for conv3d NDHWC layout (apache#4775)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexgl-github authored and alexwong committed Feb 28, 2020
1 parent 5024fd5 commit af1cb07
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 7 deletions.
14 changes: 7 additions & 7 deletions topi/python/topi/nn/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
pad_before = [0, pad_front, pad_top, pad_left, 0]
pad_after = [0, pad_back, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
rd = tvm.reduce_axis((0, kernel_d), name='rd')
rh = tvm.reduce_axis((0, kernel_h), name='rh')
rw = tvm.reduce_axis((0, kernel_w), name='rw')
rc = tvm.reduce_axis((0, in_channel), name='rc')
rz = tvm.reduce_axis((0, kernel_d), name='rz')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
Output = tvm.compute(
(batch, out_depth, out_height, out_width, out_channel),
lambda nn, zz, yy, xx, ff: tvm.sum(
PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]),
lambda nn, dd, hh, ww, cc: tvm.sum(
PaddedInput[nn, dd * stride_d + rd * dilation_d, hh * stride_h + rh * dilation_h,
ww * stride_w + rw * dilation_w, rc].astype(out_dtype) *
Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
name="Conv3dOutput", tag="conv3d_ndhwc")
return Output
1 change: 1 addition & 0 deletions topi/python/topi/x86/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .conv1d import schedule_conv1d_nwc
from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
from .conv3d import schedule_conv3d_ndhwc
from .binarize_pack import schedule_binarize_pack
from .binary_dense import schedule_binary_dense
from .nn import *
Expand Down
82 changes: 82 additions & 0 deletions topi/python/topi/x86/conv3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin, no-else-return
"""Conv3D operators"""
import tvm
from .. import generic, tag
from ..util import traverse_inline

@generic.schedule_conv3d_ndhwc.register("cpu")
def schedule_conv3d_ndhwc(outs):
"""TOPI schedule callback for conv3d
Parameters
----------
outs: Array of Tensor
The computation graph description of conv3d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv3d.
"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op

def _traverse(op):
"""Traverse operators from computation graph"""
if op in s.outputs and tag.is_broadcast(op.tag) and len(op.axis) == 5:
# schedule bias + bn + relu
n, d, h, w, c = op.axis
fused = s[op].fuse(n, d, h, w)
s[op].parallel(fused)
s[op].vectorize(c)

if 'conv3d_ndhwc' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
# dilation stage
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()

# padding stage
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
# fuse pad h and w
data_pad = data
data = data_pad.op.input_tensors[0]
_, _, h_pad, w_pad, _ = data_pad.op.axis
pad_fused = s[data_pad].fuse(h_pad, w_pad)
s[data_pad].parallel(pad_fused)

# compute conv
C = conv
n, d, h, w, c = s[C].op.axis
s[C].vectorize(c)
if op != output_op: # fuse bias + bn + activation
_, _, _, _, c_out = output_op.axis
s[C].compute_at(s[output_op], c_out)
else:
# fuse batch, depth, height axes
fused = s[C].fuse(n, d, h)
s[C].parallel(fused)

traverse_inline(s, output_op, _traverse)
return s

0 comments on commit af1cb07

Please sign in to comment.