Skip to content

Commit

Permalink
[TOPI] deformable_conv2d in NHWC
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Nov 30, 2020
1 parent 408f4ff commit ade70a5
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 4 deletions.
37 changes: 37 additions & 0 deletions include/tvm/topi/detail/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,43 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>&
D * x_lerp * y_lerp;
}

/*!
* \brief Sample a point in a tensor using bilinear interpolation.
*
* \param input The input tensor.
* \param indices The index of the target point, which can be fractional
* \param max_y The maximum of y dimension
* \param max_x The maximum of x dimension
*
* \return The interpolated value in the given index.
*/
inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
const PrimExpr max_y, const PrimExpr max_x) {
auto in_y = indices[1];
auto yf = tvm::floor(in_y);
auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y));

auto y0 = tvm::cast(DataType::Int(32), tvm::floor(in_y));
auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
auto y_lerp = in_y - yf;

auto in_x = indices[2];
auto xf = tvm::floor(in_x);
auto xc = tvm::cast(DataType::Int(32), tvm::ceil(in_x));

auto x0 = tvm::cast(DataType::Int(32), tvm::floor(in_x));
auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
auto x_lerp = in_x - xf;

auto A = input(indices[0], indices[3], y0, x0);
auto B = input(indices[0], indices[3], y0, x1);
auto C = input(indices[0], indices[3], y1, x0);
auto D = input(indices[0], indices[3], y1, x1);

return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp +
D * x_lerp * y_lerp;
}

} // namespace detail
} // namespace topi
} // namespace tvm
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,23 @@ def schedule_deformable_conv2d_nchw(outs):
return _default_schedule(outs, False)


def schedule_deformable_conv2d_nhwc(outs):
"""Schedule for deformable_conv2d_nhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of deformable_conv2d_nhwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_bitserial_conv2d_nchw(outs):
"""Schedule for bitserial_conv2d_nchw
Expand Down
104 changes: 103 additions & 1 deletion python/tvm/topi/nn/deformable_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .utils import get_pad_tuple
from ..utils import get_const_tuple
from ..cpp.utils import bilinear_sample_nchw
from ..cpp.utils import bilinear_sample_nchw, bilinear_sample_nhwc


def deformable_conv2d_nchw(
Expand Down Expand Up @@ -130,3 +130,105 @@ def _bilinear(n, c, h, w):
),
tag="deformable_conv2d_nchw",
)


def deformable_conv2d_nhwc(
data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype
):
"""Deformable conv2D operator in NHWC layout.
The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
Parameters
----------
data : tvm.te.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
offset : tvm.te.Tensor
4-D with shape [batch, out_height, out_width,
deformable_groups * filter_height * filter_width * 2].
kernel : tvm.te.Tensor
4-D with shape [filter_height, filter_width, in_channel, num_filter]
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
dilation : int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
deformable_groups : int
number of deformable groups
groups : int
number of groups
Returns
-------
output : tvm.te.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
if out_dtype is None:
out_dtype = data.dtype

if isinstance(strides, int):
stride_h = stride_w = strides
else:
stride_h, stride_w = strides

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation

batch, in_height, in_width, in_channel = get_const_tuple(data.shape)
kernel_h, kernel_w, channel, out_channel = get_const_tuple(kernel.shape)
_, out_height, out_width, _ = get_const_tuple(offset.shape)
assert in_channel % deformable_groups == 0, "Input cahnnels must divide deformable group size"
assert groups == 1, "deformable_conv2d_nchw does not support groups > 1"

ic_per_dgroup = channel // deformable_groups

dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, _, _ = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
rc = te.reduce_axis((0, in_channel), name="rc")
ry = te.reduce_axis((0, kernel_h), name="ry")
rx = te.reduce_axis((0, kernel_w), name="rx")

zero = tvm.tir.const(0.0, data.dtype)

def _bilinear(n, h, w, c):
outside = tvm.tir.any(h < 0, w < 0, h >= in_height, w >= in_width)
val = bilinear_sample_nhwc(data, (n, h, w, c), in_height - 1, in_width - 1)
return tvm.tir.if_then_else(outside, zero, val)

data_deform = te.compute(
(batch, kernel_h, kernel_w, in_channel, out_height, out_width),
lambda n, kh, kw, c, y, x: _bilinear(
n,
y * stride_h - pad_top + kh * dilation_h
+ offset[
n, y, x, c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2
],
x * stride_w - pad_left + kw * dilation_w
+ offset[
n, y, x, c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2
+ 1,
],
c
),
tag="data_deform",
)
return te.compute(
(batch, out_height, out_width, out_channel),
lambda n, y, x, f: te.sum(
data_deform[n, ry, rx, rc, y, x].astype(out_dtype)
* kernel[ry, rx, rc, f].astype(out_dtype),
axis=[ry, rx, rc],
),
tag="deformable_conv2d_nhwc",
)
2 changes: 1 addition & 1 deletion python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
from .correlation_nchw_python import correlation_nchw_python
from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
from .deformable_conv2d_python import deformable_conv2d_nchw_python, deformable_conv2d_nhwc_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,51 @@ def _bilinear(n, c, h, w):
b_np[n, f, h, w] += np.tensordot(a_deform[n, c, h, w], w_np[f, c])

return b_np


def deformable_conv2d_nhwc_python(
a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
):
"""Deformable convolution operator in NCHW layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_height, in_width, in_channel]
offset_np : numpy.ndarray
4-D with shape [batch, out_height, out_width,
deformable_groups * filter_height * filter_width * 2]
w_np : numpy.ndarray
4-D with shape [filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str or a list/tuple of 2 or 4 ints
Padding size, or ['VALID', 'SAME'], or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 2 ints
dilation : int or a list/tuple of two ints
Dilation size, or [dilate_height, dilate_width]
deformable_groups : int
Number of deformable groups
groups : int
Number of groups
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
a_np = np.transpose(a_np, [0, 3, 1, 2]) # NHWC -> NCHW
offset_np = np.transpose(offset_np, [0, 3, 1, 2]) # NHWC -> NCHW
w_np = np.transpose(w_np, [3, 2, 0, 1]) # HWIO -> OIHW
b_np = deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilation,
deformable_groups, groups)
b_np = np.transpose(b_np, [0, 2, 3, 1]) # NCHW -> NHWC
return b_np
4 changes: 4 additions & 0 deletions src/topi/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw").set_body([](TVMArgs args,
*rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
});

TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = detail::bilinear_sample_nhwc(args[0], args[1], args[2], args[3]);
});

/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<tvm::te::Schedule(
const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
Expand Down
90 changes: 88 additions & 2 deletions tests/python/topi/python/test_topi_deformable_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
import tvm.testing


_deformable_conv2d_implement = {
_deformable_conv2d_nchw_implement = {
"generic": (topi.nn.deformable_conv2d_nchw, topi.generic.schedule_deformable_conv2d_nchw),
"cuda": (topi.cuda.deformable_conv2d_nchw, topi.cuda.schedule_deformable_conv2d_nchw),
}

_deformable_conv2d_nhwc_implement = {
"generic": (topi.nn.deformable_conv2d_nhwc, topi.generic.schedule_deformable_conv2d_nhwc),
}

def verify_deformable_conv2d_nchw(
batch,
Expand Down Expand Up @@ -94,7 +97,7 @@ def check_device(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_implement)
fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nchw_implement)
with tvm.target.Target(device):
C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype)
s = fschedule([C])
Expand All @@ -112,11 +115,94 @@ def check_device(device):
check_device(device)


def verify_deformable_conv2d_nhwc(
batch,
in_channel,
in_size,
num_filter,
kernel,
stride,
padding,
dilation=1,
deformable_groups=1,
groups=1,
):
print(
"Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)"
% (
batch,
in_channel,
in_size,
num_filter,
kernel,
stride,
padding,
dilation,
deformable_groups,
groups,
)
)

A = te.placeholder((batch, in_size, in_size, in_channel), name="A")
out_size = (in_size - (kernel - 1) * dilation - 1 + 2 * padding) // stride + 1
Offset = te.placeholder(
(batch, out_size, out_size, deformable_groups * kernel * kernel * 2), name="offset"
)
W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
bias = te.placeholder((num_filter,), name="bias")

a_shape = get_const_tuple(A.shape)
offset_shape = get_const_tuple(Offset.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_deformable_conv2d_nchw.verify_deformable_conv2d_nhwc")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
offset_np = np.random.randn(*offset_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np = tvm.topi.testing.deformable_conv2d_nhwc_python(
a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
)

return a_np, offset_np, w_np, c_np

a_np, offset_np, w_np, c_np = get_ref_data()

def check_device(device):
ctx = tvm.context(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nhwc_implement)
with tvm.target.Target(device):
C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype)
s = fschedule([C])

a = tvm.nd.array(a_np, ctx)
offset = tvm.nd.array(offset_np, ctx)
w = tvm.nd.array(w_np, ctx)
c = tvm.nd.empty(c_np.shape, dtype=c_np.dtype, ctx=ctx)

func = tvm.build(s, [A, Offset, W, C], device)
func(a, offset, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

for device in ["llvm"]:
check_device(device)


@tvm.testing.uses_gpu
def test_deformable_conv2d_nchw():
verify_deformable_conv2d_nchw(1, 16, 7, 16, 1, 1, 0, deformable_groups=4)
verify_deformable_conv2d_nchw(1, 16, 7, 16, 3, 1, 1, dilation=2, deformable_groups=4)
verify_deformable_conv2d_nchw(1, 16, 7, 16, 3, 1, 2, dilation=2)
verify_deformable_conv2d_nhwc(1, 16, 7, 16, 1, 1, 0, deformable_groups=4)
verify_deformable_conv2d_nhwc(1, 16, 7, 16, 3, 1, 1, dilation=2, deformable_groups=4)
verify_deformable_conv2d_nhwc(1, 16, 7, 16, 3, 1, 2, dilation=2)


if __name__ == "__main__":
Expand Down

0 comments on commit ade70a5

Please sign in to comment.