From 4d23916dd396958524c07ea8805bfb56baa3d233 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sat, 10 Jul 2021 21:24:20 +0800 Subject: [PATCH 01/22] Add grid_sampler test functions. --- tests/python/frontend/pytorch/test_forward.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index f76ea9a5d324..e1546f745d57 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3912,6 +3912,48 @@ def forward(self, x): verify_model(Flip(axis=-1), input_data=input) +@tvm.testing.uses_gpu +def test_forward_grid_sampler(): + torch.set_grad_enabled(False) + + class GridSampler(Module): + def __init__(self, output_h, output_w): + super(GridSampler, self).__init__() + self.output_h = output_h + self.output_w = output_w + + # normalize to [-1.0, 1.0] + h = torch.arange(0, output_h) / (output_h - 1.0) * 2.0 - 1.0 + w = torch.arange(0, output_w) / (output_w - 1.0) * 2.0 - 1.0 + grid = torch.zeros(output_h, output_w, 2) + grid[:, :, 0] = w.unsqueeze(0).repeat(output_h, 1) + grid[:, :, 1] = h.unsqueeze(0).repeat(output_w, 1).transpose(0, 1) + self.grid = grid.unsqueeze(0) + + def forward(self, input): + batch = input.size(0) + grid = self.grid.repeat(batch, 1, 1, 1).to(input.device) + + # Torch grid_sample default: mode='bilinear', padding_mode='zeros', align_corners=False + # tvm seems align corners as True + + # *********************************************************************************** + # + # !!! DO NOT USE !!! + # F.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True) + # for it broken TVM "if conditional expression" in torch script mode + # + # *********************************************************************************** + + return torch.grid_sampler(input, grid, 0, 0, True) + + model = GridSampler(16, 32) + input = torch.randn(2, 3, 32, 32) + + verify_model(model, input_data=input) + verify_script_model(model.eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -4055,6 +4097,7 @@ def forward(self, x): test_hard_sigmoid() test_forward_nll_loss() test_forward_flip() + test_forward_grid_sampler() # Model tests test_resnet18() From 80b22628612299dfdc895e47c0deff1d9ce83d43 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sat, 10 Jul 2021 21:25:16 +0800 Subject: [PATCH 02/22] Support grid_sampler.py --- python/tvm/relay/frontend/pytorch.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 4c874672445b..a8816ca607cc 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2329,6 +2329,15 @@ def flip(self, inputs, input_types): axis = inputs[1] return _op.transform.reverse(data, axis=axis[0]) + def grid_sampler(self, inputs, input_types): + data = inputs[0] + grid = inputs[1] + + # Torch grid shape is like [batch, out_height, out_width, 2], but + # TVM grid is [batch, 2, out_height, out_width], so here grid need to be converted + grid = _op.transform.transpose(grid, axes=[0, 3, 1, 2]) + return _op.image.grid_sample(data, grid, method="bilinear", layout="NCHW") + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2545,6 +2554,7 @@ def create_convert_map(self): "aten::nll_loss": self.nll_loss, "aten::nll_loss2d": self.nll_loss, "aten::flip": self.flip, + "aten::grid_sampler": self.grid_sampler, } def update_convert_map(self, custom_map): From fe46820b3c46de905324d290b947f6d93f685211 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sat, 10 Jul 2021 21:55:01 +0800 Subject: [PATCH 03/22] Support aten::im2col --- include/tvm/relay/attrs/nn.h | 23 ++++ include/tvm/topi/nn.h | 125 ++++++++++++++++++ python/tvm/relay/frontend/pytorch.py | 15 +++ python/tvm/relay/op/nn/_nn.py | 3 + python/tvm/relay/op/nn/nn.py | 36 +++++ src/relay/op/nn/pad.cc | 85 ++++++++++++ tests/python/frontend/pytorch/test_forward.py | 38 ++++++ 7 files changed, 325 insertions(+) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 3c7574562676..6ccbce97e641 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1463,6 +1463,29 @@ struct NLLLossAttrs : public tvm::AttrsNode { } }; // struct NLLLossAttrs +/*! \brief Attributes used in Im2col operator */ +struct Im2colAttrs : public tvm::AttrsNode { + Array kernel_size; + Array dilation; + Array padding; + Array stride; + + TVM_DECLARE_ATTRS(Im2colAttrs, "relay.attrs.Im2colAttrs") { + TVM_ATTR_FIELD(kernel_size) + .set_default(Array({3, 3})) + .describe("The kernel size."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) + .describe("The dilation size."); + TVM_ATTR_FIELD(padding) + .set_default(Array({1, 1})) + .describe("The padding size."); + TVM_ATTR_FIELD(stride) + .set_default(Array({1, 1})) + .describe("The strides."); + } +}; // struct Im2colAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_NN_H_ diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 90c1c09a070b..95d56d514830 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -690,6 +690,131 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T return T; } } + +/*! + * \brief Creates an operation that performs im2col with an NCHW-layout + * + * \param data The 4-D input tensor + * \param kernel_size A static tuple for kernel size, such as (3,3) + * \param dilation A static tuple for dilation, default is (1,1) + * \param padding A static tuple for padding, padding value is zero + * \param stride A static tuple for strides, default is (1,1) + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the im2col operation (NCHW layout) + + Pseudo code: + input_b, input_c, input_h, input_w = data_shape + dilation_h, dilation_w = dilation + padding_h, padding_w = padding + + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + + output_h = (input_h + 2 * padding_h - dilated_kernel_h)//stride_h + 1 + output_w = (input_w + 2 * padding_w - dilated_kernel_w)//stride_w + 1 + + h_offset = (k // input_w) % kernel_h + w_offset = k % kernel_w + + im2col_data = te.compute( + (N, K, L), + lambda n, k, l: data[ + n, + k / kernel_h / kernel_w, + stride_h * (l / output_w) + dilation_h * h_offset - padding_h, + stride_w * (l % output_w) + dilation_w * w_offset - padding_w, + ], + name="im2col_data", + ) + */ +inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, + const tvm::Array& kernel_size, + const tvm::Array& dilation, + const tvm::Array& padding, + const tvm::Array& stride, + std::string name = "T_im2col", + std::string tag = kElementWise) { // ElementWise + ICHECK_EQ(4, data->shape.size()); + ICHECK_EQ(2, kernel_size.size()); + ICHECK_EQ(2, dilation.size()); + ICHECK_EQ(2, padding.size()); + ICHECK_EQ(2, stride.size()); + + auto input_b = data->shape[0]; + auto input_c = data->shape[1]; + auto input_h = data->shape[2]; + auto input_w = data->shape[3]; + + auto kernel_h = tvm::cast(tvm::DataType::Int(32), kernel_size[0]); + auto kernel_w = tvm::cast(tvm::DataType::Int(32), kernel_size[1]); + + auto dilation_h = tvm::cast(tvm::DataType::Int(32), dilation[0]); + auto dilation_w = tvm::cast(tvm::DataType::Int(32), dilation[1]); + + auto padding_h = tvm::cast(tvm::DataType::Int(32), padding[0]); + auto padding_w = tvm::cast(tvm::DataType::Int(32), padding[1]); + + auto stride_h = tvm::cast(tvm::DataType::Int(32), stride[0]); + auto stride_w = tvm::cast(tvm::DataType::Int(32), stride[1]); + + auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1; + auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1; + + // Output size after padding + auto output_h = (input_h + 2 * padding_h - dilated_kernel_h)/stride_h + 1; + auto output_w = (input_w + 2 * padding_w - dilated_kernel_w)/stride_w + 1; + + // Result output size + auto N = input_b; + auto K = input_c * kernel_h * kernel_w; + auto L = output_h * output_w; + tvm::Array output_shape; + output_shape.push_back(N); + output_shape.push_back(K); + output_shape.push_back(L); + + auto pad_value = tvm::tir::make_const(data->dtype, 0); + + auto l = [&](tvm::Array args) { + tvm::tir::Var n = args[0]; + tvm::tir::Var k = args[1]; + tvm::tir::Var l = args[2]; + + tvm::Array indices; + tvm::Array condition; + + indices.push_back(n); // B, souce batch + + // source chanel s_c = k / kernel_h / kernel_w + tvm::PrimExpr s_c = indexdiv(indexdiv(k, kernel_h), kernel_w); + indices.push_back(s_c); // C, source channel + + // (k / kernel_w) % kernel_h + // stride_h * (l / output_w) + dilation_h * h_offset, + tvm::PrimExpr h_offset = indexmod(indexdiv(k, kernel_w), kernel_h); + tvm::PrimExpr s_h = stride_h * indexdiv(l, output_w) + dilation_h * h_offset - padding_h; + indices.push_back(s_h); // H, souce height + condition.push_back(s_h >= 0); + condition.push_back(s_h < input_h); + + // k % kernel_w; + // stride_w * (l % output_w) + dilation_w * w_offset, + tvm::PrimExpr w_offset = indexmod(k, kernel_w); + tvm::PrimExpr s_w = stride_w * indexmod(l, output_w) + dilation_w * w_offset - padding_w; + indices.push_back(s_w); // W, source width + condition.push_back(s_w >= 0); + condition.push_back(s_w < input_w); + + return tvm::if_then_else( + foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); }, const_true(1), condition), + data(indices), pad_value); + }; + + return tvm::te::compute(output_shape, l, name, tag); +} + } // namespace topi } // namespace tvm #endif // TVM_TOPI_NN_H_ diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a8816ca607cc..27677640bdb3 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2338,6 +2338,20 @@ def grid_sampler(self, inputs, input_types): grid = _op.transform.transpose(grid, axes=[0, 3, 1, 2]) return _op.image.grid_sample(data, grid, method="bilinear", layout="NCHW") + def im2col(self, inputs, input_types): + # torch F.unfold set kerenl_size, dilation, padding, stride as pairs before calling im2col + # but it brokern TVM "if condition expression", so please USE torch._C._nn.im2col instead + # of F.unfold and make sure giving paired parameters. Please reference test_forward_im2col + # in file tests/python/frontend/pytorch/test_forward.py. + + data = inputs[0] + kernel_size = inputs[1] + dilation = inputs[2] + padding = inputs[3] + stride = inputs[4] + + return _op.nn.im2col(data, kernel_size, dilation, padding, stride) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2555,6 +2569,7 @@ def create_convert_map(self): "aten::nll_loss2d": self.nll_loss, "aten::flip": self.flip, "aten::grid_sampler": self.grid_sampler, + "aten::im2col": self.im2col, } def update_convert_map(self, custom_map): diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 753a17605667..7b384db491f6 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1308,3 +1308,6 @@ def dilate_shape_func(attrs, inputs, _): reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) reg.register_shape_func("nn.softmax", False, elemwise_shape_func) reg.register_shape_func("nn.relu", False, elemwise_shape_func) + + +reg.register_broadcast_schedule("nn.im2col") diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 4c94102275bb..fac212d85cd1 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -3646,3 +3646,39 @@ def batch_to_space_nd(data, block_shape, crops): """ return _make.batch_to_space_nd(data, block_shape, crops) + + +def im2col(data, kernel_size, dilation, padding, stride): + r"""im2col. + + This operator convert 4-D NCHW data with im2col for PyTorch unfold. + + Please reference + `https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold` + + Parameters + ---------- + data : tvm.relay.Expr + The input data is 4-D NCHW format. + + kernel_size : Tuple[int] + Specified the slide window height and width. + + dilation : Tuple[int] + Specifies the dilation rate for slide window. + + padding : Tuple[int] + Specifies the padding size with top/bottom and left/right. + + strides : Tuple[int] + Specifies the strides for slide window. + + Returns + ------- + result : tvm.relay.Expr + The computed result, 3-D NKL format, K is "C * kernel_height * kernel_width". + + """ + + return _make.im2col(data, kernel_size, dilation, padding, stride) + diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 365873d2fd51..6bb5fbc0baaa 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -272,5 +272,90 @@ RELAY_REGISTER_OP("nn.mirror_pad") .add_type_rel("MirrorPad", MirrorPadRel) .set_attr("TOpPattern", kInjective); + +Array Im2colCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + ICHECK(param != nullptr); + + return Array{topi::im2col(inputs[0], param->kernel_size, param->dilation, param->padding, param->stride)}; +} + + +bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types: [input, output] + ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output"; + + const auto* input = types[0].as(); + if (input == nullptr) + return false; + + if (input->shape.size() != 4) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Im2lossRel: input data should be 4 dimensions, NxCxHxW."); + return false; + } + + const Im2colAttrs* param = attrs.as(); + if (param == nullptr) + return false; + + // Calculate outout shape + auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]); // tvm::PrimExpr + auto kernel_w = tvm::cast(tvm::DataType::Int(32), param->kernel_size[1]); + auto dilation_h = tvm::cast(tvm::DataType::Int(32), param->dilation[0]); + auto dilation_w = tvm::cast(tvm::DataType::Int(32), param->dilation[1]); + auto padding_h = tvm::cast(tvm::DataType::Int(32), param->padding[0]); + auto padding_w = tvm::cast(tvm::DataType::Int(32), param->padding[1]); + auto stride_h = tvm::cast(tvm::DataType::Int(32), param->stride[0]); + auto stride_w = tvm::cast(tvm::DataType::Int(32), param->stride[1]); + auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1; + auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1; + // Output size after padding + auto output_h = (input->shape[2] + 2 * padding_h - dilated_kernel_h)/stride_h + 1; + auto output_w = (input->shape[3] + 2 * padding_w - dilated_kernel_w)/stride_w + 1; + + tvm::Array output_shape; + output_shape.push_back(input->shape[0]); // N + output_shape.push_back(input->shape[1] * kernel_h * kernel_w); // K + output_shape.push_back(output_h * output_w); // L + + // assign output type + reporter->Assign(types[1], TensorType(output_shape, input->dtype)); + + return true; +} + +// Handler to create a call to the im2col op used by front-end FFI +Expr MakeIm2col(Expr data, Array kernel_size, Array dilation, + Array padding, Array stride) { + auto attrs = make_object(); + + attrs->kernel_size = std::move(kernel_size); + attrs->dilation = std::move(dilation); + attrs->padding = std::move(padding); + attrs->stride = std::move(stride); + + static const Op& op = Op::Get("nn.im2col"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_NODE_TYPE(Im2colAttrs); +TVM_REGISTER_GLOBAL("relay.op.nn._make.im2col") + .set_body_typed(MakeIm2col); + +RELAY_REGISTER_OP("nn.im2col") + .describe(R"code(Im2col for 4-D NCHW tensor. + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input data.") + .set_support_level(2) + .add_type_rel("Im2col", Im2colRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("FTVMCompute", Im2colCompute); + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e1546f745d57..be975f5f723e 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3912,6 +3912,44 @@ def forward(self, x): verify_model(Flip(axis=-1), input_data=input) +@tvm.testing.uses_gpu +def test_forward_im2col(): + torch.set_grad_enabled(False) + + class Im2col3x3(Module): + def __init__(self): + super(Im2col3x3, self).__init__() + + def forward(self, x): + # *********************************************************************************** + # + # !!! DO NOT USE !!! + # F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1) + # for it broken TVM "if conditional expression" in torch script mode + # + # *********************************************************************************** + + return torch._C._nn.im2col(x, (3, 3), (1,1), (1,1), (1,1)) + + class Im2col5x5(Module): + def __init__(self): + super(Im2col5x5, self).__init__() + + def forward(self, x): + # *********************************************************************************** + # + # !!! DO NOT USE !!! + # F.unfold(x, kernel_size=5, dilation=1, padding=1, stride=2) + # for it broken TVM "if conditional expression" in torch script mode + # + # *********************************************************************************** + + return torch._C._nn.im2col(x, (5,5), (1,1), (1,1), (2,2)) + + input = torch.randn(2, 3, 32, 32) + verify_script_model(Im2col5x5().eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) + + @tvm.testing.uses_gpu def test_forward_grid_sampler(): torch.set_grad_enabled(False) From bc8e1017aa92c530eb75c0930ebbdaaa8cb6254c Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sat, 10 Jul 2021 23:28:20 +0800 Subject: [PATCH 04/22] fix source format with lint. --- include/tvm/topi/nn.h | 7 ++++--- src/relay/op/nn/pad.cc | 15 ++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 95d56d514830..2a12a008ff09 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -729,7 +729,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T name="im2col_data", ) */ -inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, +inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, const tvm::Array& kernel_size, const tvm::Array& dilation, const tvm::Array& padding, @@ -802,13 +802,14 @@ inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, // k % kernel_w; // stride_w * (l % output_w) + dilation_w * w_offset, tvm::PrimExpr w_offset = indexmod(k, kernel_w); - tvm::PrimExpr s_w = stride_w * indexmod(l, output_w) + dilation_w * w_offset - padding_w; + tvm::PrimExpr s_w = stride_w * indexmod(l, output_w) + dilation_w * w_offset - padding_w; indices.push_back(s_w); // W, source width condition.push_back(s_w >= 0); condition.push_back(s_w < input_w); return tvm::if_then_else( - foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); }, const_true(1), condition), + foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); }, + const_true(1), condition), data(indices), pad_value); }; diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 6bb5fbc0baaa..450cce1da60f 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -278,7 +278,8 @@ Array Im2colCompute(const Attrs& attrs, const Array& inp const auto* param = attrs.as(); ICHECK(param != nullptr); - return Array{topi::im2col(inputs[0], param->kernel_size, param->dilation, param->padding, param->stride)}; + return Array{topi::im2col(inputs[0], param->kernel_size, + param->dilation, param->padding, param->stride)}; } @@ -301,8 +302,8 @@ bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, if (param == nullptr) return false; - // Calculate outout shape - auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]); // tvm::PrimExpr + // Calculate output shape + auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]); auto kernel_w = tvm::cast(tvm::DataType::Int(32), param->kernel_size[1]); auto dilation_h = tvm::cast(tvm::DataType::Int(32), param->dilation[0]); auto dilation_w = tvm::cast(tvm::DataType::Int(32), param->dilation[1]); @@ -317,9 +318,9 @@ bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, auto output_w = (input->shape[3] + 2 * padding_w - dilated_kernel_w)/stride_w + 1; tvm::Array output_shape; - output_shape.push_back(input->shape[0]); // N - output_shape.push_back(input->shape[1] * kernel_h * kernel_w); // K - output_shape.push_back(output_h * output_w); // L + output_shape.push_back(input->shape[0]); // N + output_shape.push_back(input->shape[1] * kernel_h * kernel_w); // K + output_shape.push_back(output_h * output_w); // L // assign output type reporter->Assign(types[1], TensorType(output_shape, input->dtype)); @@ -328,7 +329,7 @@ bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, } // Handler to create a call to the im2col op used by front-end FFI -Expr MakeIm2col(Expr data, Array kernel_size, Array dilation, +Expr MakeIm2col(Expr data, Array kernel_size, Array dilation, Array padding, Array stride) { auto attrs = make_object(); From 23df38b5991f7701ef1d6a8dd50a80db719c5451 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sun, 11 Jul 2021 00:43:01 +0800 Subject: [PATCH 05/22] Fix source format. --- include/tvm/topi/nn.h | 4 ++-- src/relay/op/nn/pad.cc | 33 ++++++++++++++------------------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 2a12a008ff09..163f60a6478d 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -728,14 +728,14 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T ], name="im2col_data", ) - */ +*/ inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, const tvm::Array& kernel_size, const tvm::Array& dilation, const tvm::Array& padding, const tvm::Array& stride, std::string name = "T_im2col", - std::string tag = kElementWise) { // ElementWise + std::string tag = kElementWise) { ICHECK_EQ(4, data->shape.size()); ICHECK_EQ(2, kernel_size.size()); ICHECK_EQ(2, dilation.size()); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 450cce1da60f..18a2cc4295c9 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -272,35 +272,31 @@ RELAY_REGISTER_OP("nn.mirror_pad") .add_type_rel("MirrorPad", MirrorPadRel) .set_attr("TOpPattern", kInjective); - Array Im2colCompute(const Attrs& attrs, const Array& inputs, - const Type& out_type) { + const Type& out_type) { const auto* param = attrs.as(); ICHECK(param != nullptr); - return Array{topi::im2col(inputs[0], param->kernel_size, - param->dilation, param->padding, param->stride)}; + return Array{ + topi::im2col(inputs[0], param->kernel_size, param->dilation, param->padding, param->stride)}; } - bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { // types: [input, output] ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output"; const auto* input = types[0].as(); - if (input == nullptr) - return false; + if (input == nullptr) return false; if (input->shape.size() != 4) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Im2lossRel: input data should be 4 dimensions, NxCxHxW."); + << "Im2lossRel: input data should be 4 dimensions, NxCxHxW."); return false; } const Im2colAttrs* param = attrs.as(); - if (param == nullptr) - return false; + if (param == nullptr) return false; // Calculate output shape auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]); @@ -314,13 +310,13 @@ bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1; auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1; // Output size after padding - auto output_h = (input->shape[2] + 2 * padding_h - dilated_kernel_h)/stride_h + 1; - auto output_w = (input->shape[3] + 2 * padding_w - dilated_kernel_w)/stride_w + 1; + auto output_h = (input->shape[2] + 2 * padding_h - dilated_kernel_h) / stride_h + 1; + auto output_w = (input->shape[3] + 2 * padding_w - dilated_kernel_w) / stride_w + 1; tvm::Array output_shape; - output_shape.push_back(input->shape[0]); // N - output_shape.push_back(input->shape[1] * kernel_h * kernel_w); // K - output_shape.push_back(output_h * output_w); // L + output_shape.push_back(input->shape[0]); // N + output_shape.push_back(input->shape[1] * kernel_h * kernel_w); // K + output_shape.push_back(output_h * output_w); // L // assign output type reporter->Assign(types[1], TensorType(output_shape, input->dtype)); @@ -330,7 +326,7 @@ bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, // Handler to create a call to the im2col op used by front-end FFI Expr MakeIm2col(Expr data, Array kernel_size, Array dilation, - Array padding, Array stride) { + Array padding, Array stride) { auto attrs = make_object(); attrs->kernel_size = std::move(kernel_size); @@ -343,8 +339,7 @@ Expr MakeIm2col(Expr data, Array kernel_size, Array dilati } TVM_REGISTER_NODE_TYPE(Im2colAttrs); -TVM_REGISTER_GLOBAL("relay.op.nn._make.im2col") - .set_body_typed(MakeIm2col); +TVM_REGISTER_GLOBAL("relay.op.nn._make.im2col").set_body_typed(MakeIm2col); RELAY_REGISTER_OP("nn.im2col") .describe(R"code(Im2col for 4-D NCHW tensor. From 1c64c44627997f892463962fba1149b7778c85d5 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sun, 11 Jul 2021 01:15:28 +0800 Subject: [PATCH 06/22] Fix source format. --- include/tvm/topi/nn.h | 26 ++++++++++++-------------- python/tvm/relay/frontend/pytorch.py | 11 ++++++----- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 163f60a6478d..8a53dbbc67ec 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -730,12 +730,11 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T ) */ inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, - const tvm::Array& kernel_size, - const tvm::Array& dilation, - const tvm::Array& padding, - const tvm::Array& stride, - std::string name = "T_im2col", - std::string tag = kElementWise) { + const tvm::Array& kernel_size, + const tvm::Array& dilation, + const tvm::Array& padding, + const tvm::Array& stride, + std::string name = "T_im2col", std::string tag = kElementWise) { ICHECK_EQ(4, data->shape.size()); ICHECK_EQ(2, kernel_size.size()); ICHECK_EQ(2, dilation.size()); @@ -763,8 +762,8 @@ inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1; // Output size after padding - auto output_h = (input_h + 2 * padding_h - dilated_kernel_h)/stride_h + 1; - auto output_w = (input_w + 2 * padding_w - dilated_kernel_w)/stride_w + 1; + auto output_h = (input_h + 2 * padding_h - dilated_kernel_h) / stride_h + 1; + auto output_w = (input_w + 2 * padding_w - dilated_kernel_w) / stride_w + 1; // Result output size auto N = input_b; @@ -785,17 +784,17 @@ inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, tvm::Array indices; tvm::Array condition; - indices.push_back(n); // B, souce batch + indices.push_back(n); // B, souce batch // source chanel s_c = k / kernel_h / kernel_w tvm::PrimExpr s_c = indexdiv(indexdiv(k, kernel_h), kernel_w); - indices.push_back(s_c); // C, source channel + indices.push_back(s_c); // C, source channel // (k / kernel_w) % kernel_h // stride_h * (l / output_w) + dilation_h * h_offset, tvm::PrimExpr h_offset = indexmod(indexdiv(k, kernel_w), kernel_h); tvm::PrimExpr s_h = stride_h * indexdiv(l, output_w) + dilation_h * h_offset - padding_h; - indices.push_back(s_h); // H, souce height + indices.push_back(s_h); // H, souce height condition.push_back(s_h >= 0); condition.push_back(s_h < input_h); @@ -803,19 +802,18 @@ inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, // stride_w * (l % output_w) + dilation_w * w_offset, tvm::PrimExpr w_offset = indexmod(k, kernel_w); tvm::PrimExpr s_w = stride_w * indexmod(l, output_w) + dilation_w * w_offset - padding_w; - indices.push_back(s_w); // W, source width + indices.push_back(s_w); // W, source width condition.push_back(s_w >= 0); condition.push_back(s_w < input_w); return tvm::if_then_else( foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); }, - const_true(1), condition), + const_true(1), condition), data(indices), pad_value); }; return tvm::te::compute(output_shape, l, name, tag); } - } // namespace topi } // namespace tvm #endif // TVM_TOPI_NN_H_ diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 27677640bdb3..2f7d0dcabbe8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2339,11 +2339,12 @@ def grid_sampler(self, inputs, input_types): return _op.image.grid_sample(data, grid, method="bilinear", layout="NCHW") def im2col(self, inputs, input_types): - # torch F.unfold set kerenl_size, dilation, padding, stride as pairs before calling im2col - # but it brokern TVM "if condition expression", so please USE torch._C._nn.im2col instead - # of F.unfold and make sure giving paired parameters. Please reference test_forward_im2col - # in file tests/python/frontend/pytorch/test_forward.py. - + r""" + Torch F.unfold set kerenl_size, dilation, padding, stride as pairs before calling im2col + but it brokern TVM "if condition expression", so please USE torch._C._nn.im2col instead + of F.unfold and make sure giving paired parameters. Please reference test_forward_im2col + in file tests/python/frontend/pytorch/test_forward.py. + """ data = inputs[0] kernel_size = inputs[1] dilation = inputs[2] From 6514f74fc41948f96863eadd3d0a24de1cea7044 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sun, 11 Jul 2021 01:32:55 +0800 Subject: [PATCH 07/22] fix source format. --- python/tvm/relay/op/nn/nn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index fac212d85cd1..b226b4c26daa 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -3681,4 +3681,3 @@ def im2col(data, kernel_size, dilation, padding, stride): """ return _make.im2col(data, kernel_size, dilation, padding, stride) - From bc02b4cdcb9f7f1c9df05cda8f8724a07ec53cde Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sun, 11 Jul 2021 01:54:18 +0800 Subject: [PATCH 08/22] format include/tvm/relay/attrs/nn.h --- include/tvm/relay/attrs/nn.h | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 6ccbce97e641..a91285a96339 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1471,18 +1471,10 @@ struct Im2colAttrs : public tvm::AttrsNode { Array stride; TVM_DECLARE_ATTRS(Im2colAttrs, "relay.attrs.Im2colAttrs") { - TVM_ATTR_FIELD(kernel_size) - .set_default(Array({3, 3})) - .describe("The kernel size."); - TVM_ATTR_FIELD(dilation) - .set_default(Array({1, 1})) - .describe("The dilation size."); - TVM_ATTR_FIELD(padding) - .set_default(Array({1, 1})) - .describe("The padding size."); - TVM_ATTR_FIELD(stride) - .set_default(Array({1, 1})) - .describe("The strides."); + TVM_ATTR_FIELD(kernel_size).set_default(Array({3, 3})).describe("The kernel size."); + TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})).describe("The dilation size."); + TVM_ATTR_FIELD(padding).set_default(Array({1, 1})).describe("The padding size."); + TVM_ATTR_FIELD(stride).set_default(Array({1, 1})).describe("The strides."); } }; // struct Im2colAttrs From d086e289103f4ba9ef343fd1c6aad1decc70360f Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sun, 11 Jul 2021 02:10:19 +0800 Subject: [PATCH 09/22] format tests/python/frontend/pytorch/test_forward.py --- tests/python/frontend/pytorch/test_forward.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index be975f5f723e..84890af48d36 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3929,7 +3929,7 @@ def forward(self, x): # # *********************************************************************************** - return torch._C._nn.im2col(x, (3, 3), (1,1), (1,1), (1,1)) + return torch._C._nn.im2col(x, (3, 3), (1, 1), (1, 1), (1, 1)) class Im2col5x5(Module): def __init__(self): @@ -3944,7 +3944,7 @@ def forward(self, x): # # *********************************************************************************** - return torch._C._nn.im2col(x, (5,5), (1,1), (1,1), (2,2)) + return torch._C._nn.im2col(x, (5, 5), (1, 1), (1, 1), (2, 2)) input = torch.randn(2, 3, 32, 32) verify_script_model(Im2col5x5().eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) @@ -3974,7 +3974,7 @@ def forward(self, input): # Torch grid_sample default: mode='bilinear', padding_mode='zeros', align_corners=False # tvm seems align corners as True - + # *********************************************************************************** # # !!! DO NOT USE !!! From 303c70a1e48022790c8bc41dd2f859b01d160e27 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Mon, 12 Jul 2021 01:33:36 +0800 Subject: [PATCH 10/22] Fix bug for aten::Float --- python/tvm/relay/frontend/pytorch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2f7d0dcabbe8..9baf97d6b38c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1875,7 +1875,9 @@ def Bool(self, inputs, input_types): def Float(self, inputs, input_types): assert len(inputs) == 1 - return _op.cast(inputs[0], "float32") + if isinstance(inputs[0], _expr.Expr): + return inputs[0] + return float(inputs[0]) def bitwise_not(self, inputs, input_types): data = inputs[0] From 8798175110d5a978bd71a71b160894b5c46f9b71 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Mon, 12 Jul 2021 01:34:44 +0800 Subject: [PATCH 11/22] Test aten::Float for bug fixed --- tests/python/frontend/pytorch/test_forward.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 84890af48d36..e9684c6eb292 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3992,6 +3992,26 @@ def forward(self, input): verify_script_model(model.eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) +@tvm.testing.uses_gpu +def test_forward_float(): + torch.set_grad_enabled(False) + + def convert_i(i: int) -> float: + return float(i) + + class FloatModel(Module): + def __init__(self): + super(FloatModel, self).__init__() + + def forward(self, x): + f = convert_i(10) + return f * x + + model = FloatModel() + + verify_script_model(model.eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -4136,6 +4156,7 @@ def forward(self, input): test_forward_nll_loss() test_forward_flip() test_forward_grid_sampler() + test_forward_float() # Model tests test_resnet18() From 7a1e24e4a3e2ed24e2109753bf203c2a84e68dda Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Mon, 12 Jul 2021 12:06:07 +0800 Subject: [PATCH 12/22] Add test_forward_im2col --- tests/python/frontend/pytorch/test_forward.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e9684c6eb292..fa26623d3bb2 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4156,6 +4156,7 @@ def forward(self, x): test_forward_nll_loss() test_forward_flip() test_forward_grid_sampler() + test_forward_im2col() test_forward_float() # Model tests From f25c34be2890201b9e93c6059108f32479c3f0fd Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Mon, 12 Jul 2021 16:18:59 +0800 Subject: [PATCH 13/22] Add test_im2col for 3x3 kernel --- tests/python/frontend/pytorch/test_forward.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index fa26623d3bb2..211709a81212 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3946,7 +3946,10 @@ def forward(self, x): return torch._C._nn.im2col(x, (5, 5), (1, 1), (1, 1), (2, 2)) + model = Im2col3x3() input = torch.randn(2, 3, 32, 32) + verify_model(model, input_data=input) + verify_script_model(Im2col5x5().eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) From c41c30eff631c8456339d0c0af20cac03a732854 Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Mon, 12 Jul 2021 17:17:00 +0800 Subject: [PATCH 14/22] Add more checks for parameters. --- src/relay/op/nn/pad.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 18a2cc4295c9..674efbc78fea 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -298,6 +298,11 @@ bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, const Im2colAttrs* param = attrs.as(); if (param == nullptr) return false; + ICHECK_EQ(param->kernel_size.size(), 2) << "Expects two parameters for kernel height and width"; + ICHECK_EQ(param->dilation.size(), 2) << "Expects two parameters for dilation height and width"; + ICHECK_EQ(param->padding.size(), 2) << "Expects two parameters for padding height and width"; + ICHECK_EQ(param->stride.size(), 2) << "Expects two parameters for stride height and width"; + // Calculate output shape auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]); auto kernel_w = tvm::cast(tvm::DataType::Int(32), param->kernel_size[1]); From c9704d9846a7b09f529b79fe7728d8dfe027c0c6 Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Mon, 12 Jul 2021 17:18:16 +0800 Subject: [PATCH 15/22] Delete white space in front of line. --- include/tvm/topi/nn.h | 72 +++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 8a53dbbc67ec..63903aab9d2d 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -692,42 +692,42 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T } /*! - * \brief Creates an operation that performs im2col with an NCHW-layout - * - * \param data The 4-D input tensor - * \param kernel_size A static tuple for kernel size, such as (3,3) - * \param dilation A static tuple for dilation, default is (1,1) - * \param padding A static tuple for padding, padding value is zero - * \param stride A static tuple for strides, default is (1,1) - * \param name The name of the operation - * \param tag The tag to mark the operation - * - * \return A Tensor whose op member is the im2col operation (NCHW layout) - - Pseudo code: - input_b, input_c, input_h, input_w = data_shape - dilation_h, dilation_w = dilation - padding_h, padding_w = padding - - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - - output_h = (input_h + 2 * padding_h - dilated_kernel_h)//stride_h + 1 - output_w = (input_w + 2 * padding_w - dilated_kernel_w)//stride_w + 1 - - h_offset = (k // input_w) % kernel_h - w_offset = k % kernel_w - - im2col_data = te.compute( - (N, K, L), - lambda n, k, l: data[ - n, - k / kernel_h / kernel_w, - stride_h * (l / output_w) + dilation_h * h_offset - padding_h, - stride_w * (l % output_w) + dilation_w * w_offset - padding_w, - ], - name="im2col_data", - ) +* \brief Creates an operation that performs im2col with an NCHW-layout +* +* \param data The 4-D input tensor +* \param kernel_size A static tuple for kernel size, such as (3,3) +* \param dilation A static tuple for dilation, default is (1,1) +* \param padding A static tuple for padding, padding value is zero +* \param stride A static tuple for strides, default is (1,1) +* \param name The name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor whose op member is the im2col operation (NCHW layout) + +Pseudo code: + input_b, input_c, input_h, input_w = data_shape + dilation_h, dilation_w = dilation + padding_h, padding_w = padding + + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + + output_h = (input_h + 2 * padding_h - dilated_kernel_h)//stride_h + 1 + output_w = (input_w + 2 * padding_w - dilated_kernel_w)//stride_w + 1 + + h_offset = (k // input_w) % kernel_h + w_offset = k % kernel_w + + im2col_data = te.compute( + (N, K, L), + lambda n, k, l: data[ + n, + k / kernel_h / kernel_w, + stride_h * (l / output_w) + dilation_h * h_offset - padding_h, + stride_w * (l % output_w) + dilation_w * w_offset - padding_w, + ], + name="im2col_data", + ) */ inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, const tvm::Array& kernel_size, From fbf98d2f1ba5f38e1f7e1dfa237ac15927654ea5 Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Wed, 14 Jul 2021 14:38:48 +0800 Subject: [PATCH 16/22] Simple test for im2col via removing script mode. --- tests/python/frontend/pytorch/test_forward.py | 38 +++++++------------ 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 211709a81212..a6e916707c9b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2153,7 +2153,6 @@ def _get_default_vm_targets(): def verify_script_model(pt_model, ishapes, targets, idtype=None): script_module = torch.jit.script(pt_model) - verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets) @@ -3916,9 +3915,13 @@ def forward(self, x): def test_forward_im2col(): torch.set_grad_enabled(False) - class Im2col3x3(Module): - def __init__(self): - super(Im2col3x3, self).__init__() + class Im2col(Module): + def __init__(self, kernel_size, dilation, padding, stride): + super(Im2col, self).__init__() + self.kernel_size = (kernel_size, kernel_size) + self.dilation = (dilation, dilation) + self.padding = (padding, padding) + self.stride = (stride, stride) def forward(self, x): # *********************************************************************************** @@ -3928,29 +3931,14 @@ def forward(self, x): # for it broken TVM "if conditional expression" in torch script mode # # *********************************************************************************** + return torch._C._nn.im2col( + x, self.kernel_size, self.dilation, self.padding, self.stride + ) - return torch._C._nn.im2col(x, (3, 3), (1, 1), (1, 1), (1, 1)) - - class Im2col5x5(Module): - def __init__(self): - super(Im2col5x5, self).__init__() - - def forward(self, x): - # *********************************************************************************** - # - # !!! DO NOT USE !!! - # F.unfold(x, kernel_size=5, dilation=1, padding=1, stride=2) - # for it broken TVM "if conditional expression" in torch script mode - # - # *********************************************************************************** - - return torch._C._nn.im2col(x, (5, 5), (1, 1), (1, 1), (2, 2)) - - model = Im2col3x3() input = torch.randn(2, 3, 32, 32) - verify_model(model, input_data=input) - - verify_script_model(Im2col5x5().eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) + verify_model(Im2col(5, 1, 1, 2), input_data=input) + verify_model(Im2col(3, 1, 2, 1), input_data=input) + verify_model(Im2col(5, 1, 2, 2), input_data=input) @tvm.testing.uses_gpu From 65bbf3028401373a3b401ba72e4e595fb99d600d Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Thu, 15 Jul 2021 16:13:01 +0800 Subject: [PATCH 17/22] Remove init constructor from FloatModel. --- tests/python/frontend/pytorch/test_forward.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index a6e916707c9b..0fa4b8067e10 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3991,9 +3991,6 @@ def convert_i(i: int) -> float: return float(i) class FloatModel(Module): - def __init__(self): - super(FloatModel, self).__init__() - def forward(self, x): f = convert_i(10) return f * x From 43c5d01caba8986f1c5a999b7c4688bfd7183e70 Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Thu, 15 Jul 2021 16:45:41 +0800 Subject: [PATCH 18/22] Dropout script --- python/tvm/relay/frontend/pytorch.py | 4 +-- tests/python/frontend/pytorch/test_forward.py | 34 ------------------- 2 files changed, 1 insertion(+), 37 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9baf97d6b38c..2f7d0dcabbe8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1875,9 +1875,7 @@ def Bool(self, inputs, input_types): def Float(self, inputs, input_types): assert len(inputs) == 1 - if isinstance(inputs[0], _expr.Expr): - return inputs[0] - return float(inputs[0]) + return _op.cast(inputs[0], "float32") def bitwise_not(self, inputs, input_types): data = inputs[0] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0fa4b8067e10..604bd3a53e0a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3924,13 +3924,6 @@ def __init__(self, kernel_size, dilation, padding, stride): self.stride = (stride, stride) def forward(self, x): - # *********************************************************************************** - # - # !!! DO NOT USE !!! - # F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1) - # for it broken TVM "if conditional expression" in torch script mode - # - # *********************************************************************************** return torch._C._nn.im2col( x, self.kernel_size, self.dilation, self.padding, self.stride ) @@ -3966,38 +3959,11 @@ def forward(self, input): # Torch grid_sample default: mode='bilinear', padding_mode='zeros', align_corners=False # tvm seems align corners as True - # *********************************************************************************** - # - # !!! DO NOT USE !!! - # F.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True) - # for it broken TVM "if conditional expression" in torch script mode - # - # *********************************************************************************** - return torch.grid_sampler(input, grid, 0, 0, True) model = GridSampler(16, 32) input = torch.randn(2, 3, 32, 32) - verify_model(model, input_data=input) - verify_script_model(model.eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) - - -@tvm.testing.uses_gpu -def test_forward_float(): - torch.set_grad_enabled(False) - - def convert_i(i: int) -> float: - return float(i) - - class FloatModel(Module): - def forward(self, x): - f = convert_i(10) - return f * x - - model = FloatModel() - - verify_script_model(model.eval(), [(2, 3, 32, 32)], _get_default_vm_targets()) if __name__ == "__main__": From 0387aeff9db646ce32fa5768d187cc4823ab0e6c Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Thu, 15 Jul 2021 21:19:50 +0800 Subject: [PATCH 19/22] Change im2col schudule from broadcast to injective, torch._C_.nn.im2col to F.unfold. --- python/tvm/relay/op/nn/_nn.py | 2 +- src/relay/op/nn/pad.cc | 2 +- tests/python/frontend/pytorch/test_forward.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 7b384db491f6..ac95ab45c08c 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1310,4 +1310,4 @@ def dilate_shape_func(attrs, inputs, _): reg.register_shape_func("nn.relu", False, elemwise_shape_func) -reg.register_broadcast_schedule("nn.im2col") +reg.register_injective_schedule("nn.im2col") diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 674efbc78fea..e4723dad475b 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -355,7 +355,7 @@ RELAY_REGISTER_OP("nn.im2col") .add_argument("data", "Tensor", "Input data.") .set_support_level(2) .add_type_rel("Im2col", Im2colRel) - .set_attr("TOpPattern", kOpaque) + .set_attr("TOpPattern", kInjective) .set_attr("FTVMCompute", Im2colCompute); } // namespace relay diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 604bd3a53e0a..722d897b479a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3924,7 +3924,7 @@ def __init__(self, kernel_size, dilation, padding, stride): self.stride = (stride, stride) def forward(self, x): - return torch._C._nn.im2col( + return F.unfold( x, self.kernel_size, self.dilation, self.padding, self.stride ) From 7ca82a967ab459aa3f5230d0803ddaba561a41cd Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Thu, 15 Jul 2021 23:31:40 +0800 Subject: [PATCH 20/22] Reformat source code --- tests/python/frontend/pytorch/test_forward.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 722d897b479a..a0cb90bea794 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3924,9 +3924,7 @@ def __init__(self, kernel_size, dilation, padding, stride): self.stride = (stride, stride) def forward(self, x): - return F.unfold( - x, self.kernel_size, self.dilation, self.padding, self.stride - ) + return F.unfold(x, self.kernel_size, self.dilation, self.padding, self.stride) input = torch.randn(2, 3, 32, 32) verify_model(Im2col(5, 1, 1, 2), input_data=input) From 5667b941ec1875d8d1876efc0d17d08ee0b43afa Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Fri, 16 Jul 2021 00:16:06 +0800 Subject: [PATCH 21/22] Fix typo mistakes. --- src/relay/op/nn/pad.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index e4723dad475b..2f5e61cb6d44 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -291,7 +291,7 @@ bool Im2colRel(const Array& types, int num_inputs, const Attrs& attrs, if (input->shape.size() != 4) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Im2lossRel: input data should be 4 dimensions, NxCxHxW."); + << "Im2colRel: input data should be 4 dimensions, NxCxHxW."); return false; } From 97a5df6740aab5822ae04a932782b3272cbfbc59 Mon Sep 17 00:00:00 2001 From: delldu <18588220928@163.com> Date: Fri, 16 Jul 2021 17:00:59 +0800 Subject: [PATCH 22/22] Remove comments. --- python/tvm/relay/frontend/pytorch.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2f7d0dcabbe8..a7c8db67c533 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2339,12 +2339,6 @@ def grid_sampler(self, inputs, input_types): return _op.image.grid_sample(data, grid, method="bilinear", layout="NCHW") def im2col(self, inputs, input_types): - r""" - Torch F.unfold set kerenl_size, dilation, padding, stride as pairs before calling im2col - but it brokern TVM "if condition expression", so please USE torch._C._nn.im2col instead - of F.unfold and make sure giving paired parameters. Please reference test_forward_im2col - in file tests/python/frontend/pytorch/test_forward.py. - """ data = inputs[0] kernel_size = inputs[1] dilation = inputs[2]