Skip to content

Commit

Permalink
Add schedule for conv3d NDHWC layout
Browse files Browse the repository at this point in the history
  • Loading branch information
alexgl-github committed Jan 28, 2020
1 parent 9bd2c7b commit 3546e95
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 7 deletions.
14 changes: 7 additions & 7 deletions topi/python/topi/nn/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
pad_before = [0, pad_front, pad_top, pad_left, 0]
pad_after = [0, pad_back, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
rd = tvm.reduce_axis((0, kernel_d), name='rd')
rh = tvm.reduce_axis((0, kernel_h), name='rh')
rw = tvm.reduce_axis((0, kernel_w), name='rw')
rc = tvm.reduce_axis((0, in_channel), name='rc')
rz = tvm.reduce_axis((0, kernel_d), name='rz')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
Output = tvm.compute(
(batch, out_depth, out_height, out_width, out_channel),
lambda nn, zz, yy, xx, ff: tvm.sum(
PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]),
lambda nn, dd, hh, ww, cc: tvm.sum(
PaddedInput[nn, dd * stride_d + rd * dilation_d, hh * stride_h + rh * dilation_h,
ww * stride_w + rw * dilation_w, rc].astype(out_dtype) *
Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
name="Conv3dOutput", tag="conv3d_ndhwc")
return Output
1 change: 1 addition & 0 deletions topi/python/topi/x86/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .conv1d import schedule_conv1d_nwc
from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
from .conv3d import schedule_conv3d_ndhwc
from .binarize_pack import schedule_binarize_pack
from .binary_dense import schedule_binary_dense
from .nn import *
Expand Down
88 changes: 88 additions & 0 deletions topi/python/topi/x86/conv3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin, no-else-return
"""Conv3D operators"""
import tvm
from .. import generic, tag

@generic.schedule_conv3d_ndhwc.register("cpu")
def schedule_conv3d_ndhwc(outs):
"""TOPI schedule callback for conv2d
Parameters
----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor
The computation graph description of conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d.
"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
scheduled_ops = []

def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 5: # schedule bias + bn + relu
n, d, h, w, c = op.axis
fused = s[op].fuse(n, d, h, w)
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv3d_ndhwc' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()

data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
n_pad, d_pad, h_pad, w_pad, c_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(h_pad, w_pad)
s[data_pad].parallel(pad_fused)
C = conv
n, d, h, w, c = s[C].op.axis
s[C].vectorize(c)
if op != output_op:
n_out, d_out, h_out, w_out, c_out = output_op.axis
s[C].compute_at(s[output_op], c_out)
else:
fused = s[C].fuse(n, d, h)
s[C].parallel(fused)

scheduled_ops.append(op)

traverse(output_op)
return s

0 comments on commit 3546e95

Please sign in to comment.