Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Bug in Bilinear Interpolation and Add Deform Conv to PT FrontEnd #7397

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 55 additions & 40 deletions include/tvm/topi/detail/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/te/operation.h>

#include <vector>
namespace tvm {
namespace topi {
namespace detail {
Expand Down Expand Up @@ -64,29 +65,36 @@ inline bool is_empty_shape(const Array<PrimExpr>& x) {
*/
inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
                                     const PrimExpr max_y, const PrimExpr max_x) {
  // indices is read as (batch, channel, y, x); y and x are the fractional
  // sampling coordinates that get interpolated.
  auto batch_id = indices[0];
  auto channel_id = indices[1];
  auto in_y = indices[2];
  auto in_x = indices[3];

  // Integer neighbours on each axis: floor(coord) and floor(coord) + 1.
  auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
  auto y_high = y_low + 1;

  auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
  auto x_high = x_low + 1;

  // Interpolation weights: the "high" weight is the fractional part of the
  // coordinate, the "low" weight is its complement.
  auto wy_h = in_y - y_low;
  auto wx_h = in_x - x_low;
  auto wy_l = 1 - wy_h;
  auto wx_l = 1 - wx_h;

  // Accumulate the four corner contributions. A corner whose coordinates fall
  // outside [0, max_y] x [0, max_x] contributes zero instead of being read, so
  // the sampled tensor is never indexed out of range.
  PrimExpr val = 0;
  std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
  std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
  for (const auto& wx_xp_ele : wx_xp) {
    for (const auto& wy_yp_ele : wy_yp) {
      auto wx = wx_xp_ele[0];
      auto xp = wx_xp_ele[1];
      auto wy = wy_yp_ele[0];
      auto yp = wy_yp_ele[1];
      val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
                               wx * wy * input(batch_id, channel_id, yp, xp), 0);
    }
  }
  return val;
}

/*!
Expand All @@ -101,29 +109,36 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>&
*/
inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
                                     const PrimExpr max_y, const PrimExpr max_x) {
  // indices is read as (batch, y, x, channel); y and x are the fractional
  // sampling coordinates that get interpolated.
  auto batch_id = indices[0];
  auto channel_id = indices[3];
  auto in_y = indices[1];
  auto in_x = indices[2];

  // Integer neighbours on each axis: floor(coord) and floor(coord) + 1.
  auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
  auto y_high = y_low + 1;

  auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
  auto x_high = x_low + 1;

  // Interpolation weights: the "high" weight is the fractional part of the
  // coordinate, the "low" weight is its complement.
  auto wy_h = in_y - y_low;
  auto wx_h = in_x - x_low;
  auto wy_l = 1 - wy_h;
  auto wx_l = 1 - wx_h;

  // Accumulate the four corner contributions. A corner whose coordinates fall
  // outside [0, max_y] x [0, max_x] contributes zero instead of being read, so
  // the sampled tensor is never indexed out of range.
  PrimExpr val = 0;
  std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
  std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
  for (const auto& wx_xp_ele : wx_xp) {
    for (const auto& wy_yp_ele : wy_yp) {
      auto wx = wx_xp_ele[0];
      auto xp = wx_xp_ele[1];
      auto wy = wy_yp_ele[0];
      auto yp = wy_yp_ele[1];
      val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
                               wx * wy * input(batch_id, yp, xp, channel_id), 0);
    }
  }
  return val;
}

} // namespace detail
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,6 +1928,32 @@ def roi_align(self, inputs, input_types):

return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio)

def deform_conv2d(self, inputs, input_types):
    """Convert a torchvision::deform_conv2d node to relay nn.deformable_conv2d.

    Positional layout read from ``inputs``: data, weight, offset, bias,
    then the stride / padding / dilation pairs, groups and the number of
    deformable (offset) groups.
    """
    data = inputs[0]
    weight = inputs[1]
    offset = inputs[2]
    # NOTE(review): inputs[3] is taken to be the bias tensor, as in the
    # torchvision op schema without a mask argument. Newer torchvision
    # releases insert a mask before bias, which shifts every later index --
    # confirm the layout against the torchvision version in use.
    bias = inputs[3]
    strides = (inputs[4], inputs[5])
    padding = (inputs[6], inputs[7])
    dilation = (inputs[8], inputs[9])
    groups = inputs[10]
    deformable_groups = inputs[11]

    # Output channel count and kernel size are derived from the weight shape.
    weight_shape = self.infer_shape(weight)
    output_channels = weight_shape[0]
    kernel_size = (weight_shape[2], weight_shape[3])

    conv_out = _op.nn.deformable_conv2d(
        data,
        offset,
        weight,
        strides,
        padding,
        dilation,
        deformable_groups,
        groups,
        output_channels,
        kernel_size,
    )

    # Previously the bias input was dropped, so a non-zero bias produced a
    # result that differed from PyTorch.
    if bias is None:
        return conv_out
    return _op.nn.bias_add(conv_out, bias)

def unbind(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
Expand Down Expand Up @@ -2292,6 +2318,7 @@ def create_convert_map(self):
"torchvision::nms": self.nms,
"aten::logsumexp": self.logsumexp,
"torchvision::roi_align": self.roi_align,
"torchvision::deform_conv2d": self.deform_conv2d,
"aten::unbind": self.unbind,
"aten::__and__": self.logical_and,
"aten::logical_and": self.logical_and,
Expand Down
26 changes: 17 additions & 9 deletions python/tvm/topi/testing/deformable_conv2d_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Deformable convolution in python"""
import itertools
import math
import numpy as np
from tvm.topi.nn.utils import get_pad_tuple

Expand Down Expand Up @@ -80,15 +81,22 @@ def deformable_conv2d_nchw_python(
dilation_h, dilation_w = dilation

def _bilinear(n, c, h, w):
    """Bilinearly sample a_np[n, c] at the fractional location (h, w).

    Corners that fall outside the input (rows outside [0, in_height) or
    columns outside [0, in_width)) contribute zero.
    """
    # floor (not int()) so that negative coordinates round down rather than
    # toward zero.
    y_low = int(math.floor(h))
    x_low = int(math.floor(w))
    y_high = y_low + 1
    x_high = x_low + 1

    # "high" weight is the fractional part, "low" weight its complement.
    wy_h = h - y_low
    wx_h = w - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            # Skip out-of-range corners instead of clamping or wrapping them.
            if 0 <= yp < in_height and 0 <= xp < in_width:
                val += wx * wy * a_np[n, c, yp, xp]
    return val

a_deform = np.zeros((batch, in_channel, out_height, out_width, kernel_h, kernel_w), dtype=dtype)
for n, h, w in itertools.product(range(batch), range(out_height), range(out_width)):
Expand Down
34 changes: 19 additions & 15 deletions python/tvm/topi/testing/roi_align_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,29 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
else:
pooled_size_h, pooled_size_w = pooled_size

def _bilinear(n, c, y, x):
    """Bilinearly sample a_np[n, c] at the fractional location (y, x).

    Returns 0 when the location is more than one pixel outside the feature
    map; otherwise the location is clamped into the map before sampling.
    """
    if y < -1 or y > height or x < -1 or x > width:
        return 0

    # Clamp on both sides so the sample point lies inside the feature map.
    y = min(max(y, 0), height - 1)
    x = min(max(x, 0), width - 1)

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    # "high" weight is the fractional part, "low" weight its complement.
    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            # The high neighbour can be one past the last row/column after
            # clamping; such corners are skipped.
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * a_np[n, c, yp, xp]
    return val

for i in range(num_roi):
roi = rois_np[i]
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/vision/rcnn/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):

def _bilinear(i, c, y, x):
    """Sample data[i, c] at (y, x), yielding 0.0 for locations flagged as outside."""
    # Decide "outside" from the unclamped coordinates, before they are adjusted.
    outside = tvm.tir.any(y < -1.0, x < -1.0, y > height, x > width)
    # Clamp on both sides; the lower-bound-only clamps this replaces were
    # redundant next to these.
    y = tvm.te.min(tvm.te.max(y, 0.0), height - 1)
    x = tvm.te.min(tvm.te.max(x, 0.0), width - 1)
    val = bilinear_sample_nchw(data, (i, c, y, x), height - 1, width - 1)
    return tvm.tir.if_then_else(outside, 0.0, val)

Expand Down
88 changes: 83 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at

assert_shapes_match(baseline_output, compiled_output)
tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)

del model_name
del baseline_model
torch.cuda.empty_cache()
Expand Down Expand Up @@ -924,6 +923,85 @@ def test_forward_conv_transpose():
verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)


def test_forward_deform_conv():
    """Compare torchvision.ops.deform_conv2d with its converted model for two configurations."""
    torch.set_grad_enabled(False)

    class DeformConv2D(Module):
        def forward(self, *args):
            return torchvision.ops.deform_conv2d(args[0], args[1], args[2])

    def _check(cfg):
        kh, kw = cfg["kh"], cfg["kw"]
        data_shape = [cfg["batch_size"], cfg["in_channels"], cfg["in_height"], cfg["in_width"]]
        off_shape = [
            cfg["batch_size"],
            2 * cfg["offset_groups"] * kh * kw,
            cfg["out_height"],
            cfg["out_width"],
        ]
        w_shape = [cfg["out_channels"], cfg["in_channels"] // cfg["groups"], kh, kw]

        # Random tensors are drawn in the order: data, offset, weight.
        data = torch.rand(data_shape)
        off = torch.rand(off_shape)
        wgt = torch.rand(w_shape)

        verify_model(
            DeformConv2D().float().eval(),
            input_data=[data, off, wgt],
            rtol=1e-4,
            atol=1e-4,
        )

    # Settings shared by both runs; only batch size and offset groups differ.
    common = dict(
        in_channels=4,
        out_channels=6,
        in_height=10,
        in_width=10,
        out_height=8,
        out_width=8,
        kh=3,
        kw=3,
        groups=1,
    )
    for batch_size, offset_groups in ((4, 2), (5, 1)):
        _check(dict(common, batch_size=batch_size, offset_groups=offset_groups))


@tvm.testing.uses_gpu
def test_forward_threshold():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -1700,7 +1778,7 @@ def test_forward_roi_align():
"""ROI align"""
torch.set_grad_enabled(False)

class ROIAlgin(Module):
class ROIAlign(Module):
def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1):
super().__init__()
self.spatial_scale = spatial_scale
Expand All @@ -1721,9 +1799,9 @@ def forward(self, *args):
in_batch = torch.zeros((35, 1), dtype=torch.float)
in_boxes = torch.cat([in_batch, in_boxes], dim=1)

verify_model(ROIAlgin(7), [in_data, in_boxes])
verify_model(ROIAlgin((10, 10), 0.7, 5), [in_data, in_boxes])
verify_model(ROIAlgin(15, 0.9, 3), [in_data, in_boxes])
verify_model(ROIAlign(7), [in_data, in_boxes])
verify_model(ROIAlign((10, 10), 0.7, 5), [in_data, in_boxes])
verify_model(ROIAlign(15, 0.9, 3), [in_data, in_boxes])


@tvm.testing.uses_gpu
Expand Down
Loading