Commit 57f4722 · 1 parent 9bd2c7b · Showing 3 changed files with 101 additions and 7 deletions.
@@ -0,0 +1,93 @@
import tvm
from tvm import autotvm
from .. import generic, tag
from ..nn.conv3d import conv3d, conv3d_ndhwc, conv3d_ncdhw
from ..generic.nn import schedule_conv3d_ndhwc


@autotvm.register_topi_compute(conv3d, 'cpu', ['direct'])
def conv3d_x86(cfg, input, filter, strides, padding, dilation, layout='NCDHW', out_dtype=None):
    if layout == 'NCDHW':
        return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype)
    elif layout == 'NDHWC':
        return conv3d_ndhwc(input, filter, strides, padding, dilation, out_dtype)
    raise ValueError("Unsupported layout {}".format(layout))

@autotvm.register_topi_schedule(schedule_conv3d_ndhwc, 'cpu', ['direct'])
def schedule_conv3d_ndhwc_x86(cfg, outs):
    """TOPI schedule callback for conv3d

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template
    outs: Array of Tensor
        The computation graph description of conv3d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for conv3d.
    """
    s = tvm.create_schedule([x.op for x in outs])
    output_op = outs[0].op
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            else:  # inject custom schedule
                if len(op.axis) == 5:  # schedule bias + bn + relu
                    n, d, h, w, c = op.axis
                    fused = s[op].fuse(n, d, h, w)
                    s[op].parallel(fused)
                    s[op].vectorize(c)
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        if 'conv3d_ndhwc' in op.tag:
            conv = op.output(0)
            kernel = op.input_tensors[1]
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()

            data = op.input_tensors[0]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]
                n_pad, d_pad, h_pad, w_pad, c_pad = data_pad.op.axis
                pad_fused = s[data_pad].fuse(h_pad, w_pad)
                s[data_pad].parallel(pad_fused)

            C = conv
            # data axes
            n, d, h, w, c = s[C].op.axis

            # tile data h and w
            ho, wo, hi, wi = s[C].tile(h, w, 2, 2)
            # kernel axes
            kd, ky, kx, kc = s[C].op.reduce_axis
            # split returns (outer, inner)
            kxo, kxi = s[C].split(kx, factor=2)
            kco, kci = s[C].split(kc, factor=2)
            s[C].reorder(n, d, ho, wo, hi, wi, c, kxo, kco, kxi, kci)
            s[C].unroll(kci)

            s[C].vectorize(c)
            if op != output_op:
                _, _, _, _, c_out = output_op.axis
                s[C].compute_at(s[output_op], c_out)
            else:
                fused = s[C].fuse(n, d)
                s[C].parallel(fused)

        scheduled_ops.append(op)

    traverse(output_op)
    return s
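For context, a minimal sketch of how the registered compute and schedule might be exercised through the generic TOPI entry points on CPU. This assumes the pre-0.7 TVM/TOPI Python API that this commit targets; the data shape, the DHWIO kernel layout, and the stride/padding/dilation values are illustrative assumptions, not part of this commit.

import tvm
import topi

# illustrative shapes (not from the commit): batch 1, 8x56x56 volume, 16 -> 32 channels
data = tvm.placeholder((1, 8, 56, 56, 16), name='data')      # NDHWC
kernel = tvm.placeholder((3, 3, 3, 16, 32), name='kernel')   # assumed DHWIO

with tvm.target.create('llvm'):
    # under the 'cpu'/'llvm' target, autotvm dispatches to conv3d_x86 above
    out = topi.nn.conv3d(data, kernel, strides=1, padding=1, dilation=1,
                         layout='NDHWC')
    # and the generic schedule resolves to schedule_conv3d_ndhwc_x86
    s = topi.generic.schedule_conv3d_ndhwc([out])
    func = tvm.build(s, [data, kernel, out], target='llvm')

The point of routing through topi.nn.conv3d and topi.generic.schedule_conv3d_ndhwc rather than calling the x86 functions directly is that the autotvm registrations added in this file make the CPU implementation the one selected whenever the active target is a CPU target.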