diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 3c7574562676..a91285a96339 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1463,6 +1463,21 @@ struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
   }
 };  // struct NLLLossAttrs
 
+/*! \brief Attributes used in Im2col operator */
+struct Im2colAttrs : public tvm::AttrsNode<Im2colAttrs> {
+  Array<IndexExpr> kernel_size;
+  Array<IndexExpr> dilation;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> stride;
+
+  TVM_DECLARE_ATTRS(Im2colAttrs, "relay.attrs.Im2colAttrs") {
+    TVM_ATTR_FIELD(kernel_size).set_default(Array<IndexExpr>({3, 3})).describe("The kernel size.");
+    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1})).describe("The dilation size.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({1, 1})).describe("The padding size.");
+    TVM_ATTR_FIELD(stride).set_default(Array<IndexExpr>({1, 1})).describe("The strides.");
+  }
+};  // struct Im2colAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_NN_H_
diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h
index 90c1c09a070b..63903aab9d2d 100644
--- a/include/tvm/topi/nn.h
+++ b/include/tvm/topi/nn.h
@@ -690,6 +690,130 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
     return T;
   }
 }
+
+/*!
+ * \brief Creates an operation that performs im2col with an NCHW layout
+ *
+ * \param data The 4-D input tensor
+ * \param kernel_size A static tuple for the kernel size, such as (3, 3)
+ * \param dilation A static tuple for the dilation, default is (1, 1)
+ * \param padding A static tuple for the padding; the padded area is filled with zero
+ * \param stride A static tuple for the strides, default is (1, 1)
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the im2col operation (NCHW layout)
+
+Pseudo code:
+  input_b, input_c, input_h, input_w = data_shape
+  kernel_h, kernel_w = kernel_size
+  dilation_h, dilation_w = dilation
+  padding_h, padding_w = padding
+  stride_h, stride_w = stride
+
+  dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+  dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+
+  output_h = (input_h + 2 * padding_h - dilated_kernel_h) // stride_h + 1
+  output_w = (input_w + 2 * padding_w - dilated_kernel_w) // stride_w + 1
+
+  N = input_b
+  K = input_c * kernel_h * kernel_w
+  L = output_h * output_w
+
+  h_offset = (k // kernel_w) % kernel_h
+  w_offset = k % kernel_w
+
+  im2col_data = te.compute(
+      (N, K, L),
+      lambda n, k, l: data[
+          n,
+          k // kernel_h // kernel_w,
+          stride_h * (l // output_w) + dilation_h * h_offset - padding_h,
+          stride_w * (l % output_w) + dilation_w * w_offset - padding_w,
+      ],
+      name="im2col_data",
+  )
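+
+Worked example (the numbers follow from the formulas above): for data of
+shape (2, 3, 32, 32) with kernel_size = (3, 3), dilation = (1, 1),
+padding = (1, 1) and stride = (1, 1):
+  dilated_kernel_h = dilated_kernel_w = (3 - 1) * 1 + 1 = 3
+  output_h = (32 + 2 * 1 - 3) // 1 + 1 = 32
+  output_w = 32
+so the result has shape (N, K, L) = (2, 3 * 3 * 3, 32 * 32) = (2, 27, 1024).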
+*/
+inline tvm::te::Tensor im2col(const tvm::te::Tensor& data,
+                              const tvm::Array<tvm::PrimExpr>& kernel_size,
+                              const tvm::Array<tvm::PrimExpr>& dilation,
+                              const tvm::Array<tvm::PrimExpr>& padding,
+                              const tvm::Array<tvm::PrimExpr>& stride,
+                              std::string name = "T_im2col", std::string tag = kInjective) {
+  ICHECK_EQ(4, data->shape.size());
+  ICHECK_EQ(2, kernel_size.size());
+  ICHECK_EQ(2, dilation.size());
+  ICHECK_EQ(2, padding.size());
+  ICHECK_EQ(2, stride.size());
+
+  auto input_b = data->shape[0];
+  auto input_c = data->shape[1];
+  auto input_h = data->shape[2];
+  auto input_w = data->shape[3];
+
+  auto kernel_h = tvm::cast(tvm::DataType::Int(32), kernel_size[0]);
+  auto kernel_w = tvm::cast(tvm::DataType::Int(32), kernel_size[1]);
+
+  auto dilation_h = tvm::cast(tvm::DataType::Int(32), dilation[0]);
+  auto dilation_w = tvm::cast(tvm::DataType::Int(32), dilation[1]);
+
+  auto padding_h = tvm::cast(tvm::DataType::Int(32), padding[0]);
+  auto padding_w = tvm::cast(tvm::DataType::Int(32), padding[1]);
+
+  auto stride_h = tvm::cast(tvm::DataType::Int(32), stride[0]);
+  auto stride_w = tvm::cast(tvm::DataType::Int(32), stride[1]);
+
+  auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
+  auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;
+
+  // Output spatial size after padding
+  auto output_h = (input_h + 2 * padding_h - dilated_kernel_h) / stride_h + 1;
+  auto output_w = (input_w + 2 * padding_w - dilated_kernel_w) / stride_w + 1;
+
+  // Result shape: (N, K, L)
+  auto N = input_b;
+  auto K = input_c * kernel_h * kernel_w;
+  auto L = output_h * output_w;
+  tvm::Array<tvm::PrimExpr> output_shape;
+  output_shape.push_back(N);
+  output_shape.push_back(K);
+  output_shape.push_back(L);
+
+  auto pad_value = tvm::tir::make_const(data->dtype, 0);
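+
+  // Note on the indexing below: the flattened K axis decomposes as
+  //   k = (s_c * kernel_h + h_offset) * kernel_w + w_offset,
+  // so indexdiv/indexmod recover the source channel and the intra-kernel
+  // offsets, and coordinates that fall into the virtual zero padding are
+  // redirected to pad_value by the if_then_else guard.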
+  auto fcompute = [&](const tvm::Array<tvm::tir::Var>& args) {
+    tvm::tir::Var n = args[0];
+    tvm::tir::Var k = args[1];
+    tvm::tir::Var l = args[2];
+
+    tvm::Array<tvm::PrimExpr> indices;
+    tvm::Array<tvm::PrimExpr> condition;
+
+    indices.push_back(n);  // B, source batch
+
+    // Source channel: s_c = k / kernel_h / kernel_w
+    tvm::PrimExpr s_c = indexdiv(indexdiv(k, kernel_h), kernel_w);
+    indices.push_back(s_c);  // C, source channel
+
+    // h_offset = (k / kernel_w) % kernel_h
+    // s_h = stride_h * (l / output_w) + dilation_h * h_offset - padding_h
+    tvm::PrimExpr h_offset = indexmod(indexdiv(k, kernel_w), kernel_h);
+    tvm::PrimExpr s_h = stride_h * indexdiv(l, output_w) + dilation_h * h_offset - padding_h;
+    indices.push_back(s_h);  // H, source height
+    condition.push_back(s_h >= 0);
+    condition.push_back(s_h < input_h);
+
+    // w_offset = k % kernel_w
+    // s_w = stride_w * (l % output_w) + dilation_w * w_offset - padding_w
+    tvm::PrimExpr w_offset = indexmod(k, kernel_w);
+    tvm::PrimExpr s_w = stride_w * indexmod(l, output_w) + dilation_w * w_offset - padding_w;
+    indices.push_back(s_w);  // W, source width
+    condition.push_back(s_w >= 0);
+    condition.push_back(s_w < input_w);
+
+    return tvm::if_then_else(
+        foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
+              const_true(1), condition),
+        data(indices), pad_value);
+  };
+
+  return tvm::te::compute(output_shape, fcompute, name, tag);
+}
 }  // namespace topi
 }  // namespace tvm
 #endif  // TVM_TOPI_NN_H_
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 4c874672445b..a7c8db67c533 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2329,6 +2329,24 @@ def flip(self, inputs, input_types):
         axis = inputs[1]
         return _op.transform.reverse(data, axis=axis[0])
 
+    def grid_sampler(self, inputs, input_types):
+        data = inputs[0]
+        grid = inputs[1]
+
+        # The Torch grid has shape [batch, out_height, out_width, 2], while the TVM
+        # grid is [batch, 2, out_height, out_width], so the grid needs to be transposed.
+        grid = _op.transform.transpose(grid, axes=[0, 3, 1, 2])
+        return _op.image.grid_sample(data, grid, method="bilinear", layout="NCHW")
+
+    def im2col(self, inputs, input_types):
+        data = inputs[0]
+        kernel_size = inputs[1]
+        dilation = inputs[2]
+        padding = inputs[3]
+        stride = inputs[4]
+
+        return _op.nn.im2col(data, kernel_size, dilation, padding, stride)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -2545,6 +2563,8 @@ def create_convert_map(self):
             "aten::nll_loss": self.nll_loss,
             "aten::nll_loss2d": self.nll_loss,
             "aten::flip": self.flip,
+            "aten::grid_sampler": self.grid_sampler,
+            "aten::im2col": self.im2col,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index b0b5a569f9e0..f8120f27dd8d 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1339,3 +1339,6 @@ def dilate_shape_func(attrs, inputs, _):
 reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
 reg.register_shape_func("nn.fast_softmax", False, elemwise_shape_func)
 reg.register_shape_func("nn.relu", False, elemwise_shape_func)
+
+
+reg.register_injective_schedule("nn.im2col")
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 4c94102275bb..b226b4c26daa 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -3646,3 +3646,38 @@ def batch_to_space_nd(data, block_shape, crops):
     """
     return _make.batch_to_space_nd(data, block_shape, crops)
+
+
+def im2col(data, kernel_size, dilation, padding, stride):
+    r"""im2col.
+
+    This operator converts 4-D NCHW data into column form, matching the
+    semantics of PyTorch's unfold.
+
+    See
+    https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The 4-D input data in NCHW format.
+
+    kernel_size : Tuple[int]
+        Specifies the sliding window height and width.
+
+    dilation : Tuple[int]
+        Specifies the dilation rate of the sliding window.
+
+    padding : Tuple[int]
+        Specifies the padding size on top/bottom and left/right.
+
+    stride : Tuple[int]
+        Specifies the strides of the sliding window.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result in 3-D NKL format, where K is
+        ``C * kernel_height * kernel_width`` and L is
+        ``output_height * output_width``.
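+
+    Examples
+    --------
+    A minimal usage sketch; the expected shape follows from
+    ``output = (input + 2 * padding - dilated_kernel) // stride + 1``:
+
+    .. code-block:: python
+
+        # Unfold a 1x1x3x3 input with a 2x2 window: K = 1 * 2 * 2 = 4 and
+        # L = 2 * 2 = 4, so the result has shape (1, 4, 4).
+        data = relay.var("data", shape=(1, 1, 3, 3))
+        out = relay.nn.im2col(data, kernel_size=(2, 2), dilation=(1, 1),
+                              padding=(0, 0), stride=(1, 1))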
+    """
+    return _make.im2col(data, kernel_size, dilation, padding, stride)
diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc
index 365873d2fd51..2f5e61cb6d44 100644
--- a/src/relay/op/nn/pad.cc
+++ b/src/relay/op/nn/pad.cc
@@ -272,5 +272,91 @@ RELAY_REGISTER_OP("nn.mirror_pad")
     .add_type_rel("MirrorPad", MirrorPadRel)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+Array<te::Tensor> Im2colCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
+  const auto* param = attrs.as<Im2colAttrs>();
+  ICHECK(param != nullptr);
+
+  return Array<te::Tensor>{
+      topi::im2col(inputs[0], param->kernel_size, param->dilation, param->padding, param->stride)};
+}
+
+bool Im2colRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // types: [input, output]
+  ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";
+
+  const auto* input = types[0].as<TensorTypeNode>();
+  if (input == nullptr) return false;
+
+  if (input->shape.size() != 4) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "Im2colRel: input data should be 4 dimensions, NxCxHxW.");
+    return false;
+  }
+
+  const Im2colAttrs* param = attrs.as<Im2colAttrs>();
+  if (param == nullptr) return false;
+
+  ICHECK_EQ(param->kernel_size.size(), 2) << "Expects two parameters for kernel height and width";
+  ICHECK_EQ(param->dilation.size(), 2) << "Expects two parameters for dilation height and width";
+  ICHECK_EQ(param->padding.size(), 2) << "Expects two parameters for padding height and width";
+  ICHECK_EQ(param->stride.size(), 2) << "Expects two parameters for stride height and width";
+
+  // Calculate the output shape
+  auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]);
+  auto kernel_w = tvm::cast(tvm::DataType::Int(32), param->kernel_size[1]);
+  auto dilation_h = tvm::cast(tvm::DataType::Int(32), param->dilation[0]);
+  auto dilation_w = tvm::cast(tvm::DataType::Int(32), param->dilation[1]);
+  auto padding_h = tvm::cast(tvm::DataType::Int(32), param->padding[0]);
+  auto padding_w = tvm::cast(tvm::DataType::Int(32), param->padding[1]);
+  auto stride_h = tvm::cast(tvm::DataType::Int(32), param->stride[0]);
+  auto stride_w = tvm::cast(tvm::DataType::Int(32), param->stride[1]);
+  auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
+  auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;
+  // Output size after padding
+  auto output_h = (input->shape[2] + 2 * padding_h - dilated_kernel_h) / stride_h + 1;
+  auto output_w = (input->shape[3] + 2 * padding_w - dilated_kernel_w) / stride_w + 1;
+
+  tvm::Array<IndexExpr> output_shape;
+  output_shape.push_back(input->shape[0]);                        // N
+  output_shape.push_back(input->shape[1] * kernel_h * kernel_w);  // K
+  output_shape.push_back(output_h * output_w);                    // L
+
+  // Assign the output type
+  reporter->Assign(types[1], TensorType(output_shape, input->dtype));
+
+  return true;
+}
+
+// Handler to create a call to the im2col op used by front-end FFI
+Expr MakeIm2col(Expr data, Array<IndexExpr> kernel_size, Array<IndexExpr> dilation,
+                Array<IndexExpr> padding, Array<IndexExpr> stride) {
+  auto attrs = make_object<Im2colAttrs>();
+
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->dilation = std::move(dilation);
+  attrs->padding = std::move(padding);
+  attrs->stride = std::move(stride);
+
+  static const Op& op = Op::Get("nn.im2col");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_NODE_TYPE(Im2colAttrs);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.im2col").set_body_typed(MakeIm2col);
+
+RELAY_REGISTER_OP("nn.im2col")
+    .describe(R"code(Im2col for a 4-D NCHW tensor.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<Im2colAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "Input data.")
+    .set_support_level(2)
+    .add_type_rel("Im2col", Im2colRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<FTVMCompute>("FTVMCompute", Im2colCompute);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f76ea9a5d324..a0cb90bea794 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2153,7 +2153,6 @@ def _get_default_vm_targets():
 
 def verify_script_model(pt_model, ishapes, targets, idtype=None):
     script_module = torch.jit.script(pt_model)
-
     verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets)
 
 
@@ -3912,6 +3911,59 @@ def forward(self, x):
     verify_model(Flip(axis=-1), input_data=input)
 
 
+@tvm.testing.uses_gpu
+def test_forward_im2col():
+    torch.set_grad_enabled(False)
+
+    class Im2col(Module):
+        def __init__(self, kernel_size, dilation, padding, stride):
+            super(Im2col, self).__init__()
+            self.kernel_size = (kernel_size, kernel_size)
+            self.dilation = (dilation, dilation)
+            self.padding = (padding, padding)
+            self.stride = (stride, stride)
+
+        def forward(self, x):
+            return F.unfold(x, self.kernel_size, self.dilation, self.padding, self.stride)
+
+    input = torch.randn(2, 3, 32, 32)
+    verify_model(Im2col(5, 1, 1, 2), input_data=input)
+    verify_model(Im2col(3, 1, 2, 1), input_data=input)
+    verify_model(Im2col(5, 1, 2, 2), input_data=input)
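+    # Expected unfold result shapes for the three cases above, from
+    # output = (input + 2 * padding - dilated_kernel) // stride + 1:
+    #   Im2col(5, 1, 1, 2) on (2, 3, 32, 32) -> (2, 75, 225)
+    #   Im2col(3, 1, 2, 1) on (2, 3, 32, 32) -> (2, 27, 1156)
+    #   Im2col(5, 1, 2, 2) on (2, 3, 32, 32) -> (2, 75, 256)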
+
+
+@tvm.testing.uses_gpu
+def test_forward_grid_sampler():
+    torch.set_grad_enabled(False)
+
+    class GridSampler(Module):
+        def __init__(self, output_h, output_w):
+            super(GridSampler, self).__init__()
+            self.output_h = output_h
+            self.output_w = output_w
+
+            # Normalize the sampling coordinates to [-1.0, 1.0]
+            h = torch.arange(0, output_h) / (output_h - 1.0) * 2.0 - 1.0
+            w = torch.arange(0, output_w) / (output_w - 1.0) * 2.0 - 1.0
+            grid = torch.zeros(output_h, output_w, 2)
+            grid[:, :, 0] = w.unsqueeze(0).repeat(output_h, 1)
+            grid[:, :, 1] = h.unsqueeze(0).repeat(output_w, 1).transpose(0, 1)
+            self.grid = grid.unsqueeze(0)
+
+        def forward(self, input):
+            batch = input.size(0)
+            grid = self.grid.repeat(batch, 1, 1, 1).to(input.device)
+
+            # Torch grid_sample defaults are mode='bilinear', padding_mode='zeros',
+            # align_corners=False; TVM appears to align corners, so pass True here.
+            return torch.grid_sampler(input, grid, 0, 0, True)
+
+    model = GridSampler(16, 32)
+    input = torch.randn(2, 3, 32, 32)
+    verify_model(model, input_data=input)
+
+
 if __name__ == "__main__":
     # some structural tests
     test_forward_traced_function()
@@ -4055,6 +4107,9 @@ def forward(self, x):
     test_hard_sigmoid()
     test_forward_nll_loss()
     test_forward_flip()
+    test_forward_grid_sampler()
+    test_forward_im2col()
+    test_forward_float()
 
     # Model tests
     test_resnet18()