Support aten::grid_sampler, aten::im2col #8443

Closed
wants to merge 24 commits into from
Changes from 19 commits
15 changes: 15 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -1463,6 +1463,21 @@ struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
}
}; // struct NLLLossAttrs

/*! \brief Attributes used in Im2col operator */
struct Im2colAttrs : public tvm::AttrsNode<Im2colAttrs> {
Array<IndexExpr> kernel_size;
Array<IndexExpr> dilation;
Array<IndexExpr> padding;
Array<IndexExpr> stride;
Member:

IndexExpr should be Integer. If you make this change you can remove all the casts.

@delldu (Contributor Author), Jul 15, 2021:

Sorry, let me explain a little.

  1. Im2colCompute must follow the TVM RELAY_REGISTER_OP interface, so it is:
Array<te::Tensor> Im2colCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  const auto* param = attrs.as<Im2colAttrs>();
  ICHECK(param != nullptr);

  return Array<te::Tensor>{
      topi::im2col(inputs[0], param->kernel_size, param->dilation, param->padding, param->stride)};
}
  2. We have two choices: 1) parse the parameters in Im2colCompute, or 2) parse them in topi::im2col. Here we chose 2); the two make no difference. A third option, 3), is to change Im2colAttrs.
  3. struct Conv2DAttrs also uses IndexExpr instead of Integer, so Im2colAttrs follows its style. In fact, kernel_size means kernel_size_h/w, dilation means dilation_h/w, and so on; they are not plain Integers.

So with RELAY_REGISTER_OP ==> Im2colCompute ==> im2col and Im2colAttrs, we have almost no choice.

@masahi (Member), Jul 15, 2021:

I don't see why you can't use Array<Integer>. There is no need to follow Conv2DAttrs. If all attributes of the im2col op are integer constants, there is no point in using Array<IndexExpr>; it just makes your implementation unnecessarily complicated.

Contributor Author:

OK, I see. Will modify.

@delldu (Contributor Author), Jul 15, 2021:

Hello @masahi, @wyc-ruiker,

First of all, thank you for helping us improve this PR. It sounds like one important issue is left, so I would like to give my understanding.

Second, please remember that using IndexExpr = ::tvm::PrimExpr;

1. Conclusion
Changing Array<IndexExpr> to Array<Integer> in Im2colAttrs is not a good idea.

2. Function Patch
inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, ...)
{
  auto input_b = data->shape[0];
  ...
  auto kernel_h = tvm::cast(tvm::DataType::Int(32), kernel_size[0]);
  ...
  tvm::Array<tvm::PrimExpr> output_shape;
  output_shape.push_back(N);
  ...
  auto pad_value = tvm::tir::make_const(data->dtype, 0);
  auto l = [&](tvm::Array<tvm::tir::Var> args) {
    ...
    tvm::PrimExpr s_c = indexdiv(indexdiv(k, kernel_h), kernel_w);
    indices.push_back(s_c); // C, source channel
    ...
    return tvm::if_then_else(..., data(indices), pad_value);
  };
  return tvm::te::compute(output_shape, l, name, tag);
}

3. Cause analysis

  1. Input
    The shape of the input data is Array<PrimExpr>:
    inline tvm::te::Tensor im2col(const tvm::te::Tensor& data
  2. Output
    The output shape is also Array<PrimExpr>:
    TVM_DLL Tensor compute(Array<PrimExpr> shape, ...);
  3. Process
    3.1) Many expressions in the lambda function, such as indices, are Array<PrimExpr>, NOT Array<Integer>.
    3.2) Some function parameters are PrimExpr.
    For example, the prototype of indexdiv is:
TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span());

If we pass Integer parameters to indexdiv, we only get fatal compile warnings and errors ...

4. Test Result
We tried changing Array<IndexExpr> to Array<Integer> and modified the related source code. After trying many approaches again and again, we still ended up with a great many build (compile) errors.

5. Inference
tvm::cast is a better choice than as_const_int for parsing parameters.
The reason comes from their prototypes:

inline const int64_t* as_const_int(const PrimExpr& x);
TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span());

Given the results above, that is why we can't use Array<Integer>.


TVM_DECLARE_ATTRS(Im2colAttrs, "relay.attrs.Im2colAttrs") {
TVM_ATTR_FIELD(kernel_size).set_default(Array<IndexExpr>({3, 3})).describe("The kernel size.");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1})).describe("The dilation size.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({1, 1})).describe("The padding size.");
TVM_ATTR_FIELD(stride).set_default(Array<IndexExpr>({1, 1})).describe("The strides.");
}
}; // struct Im2colAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
124 changes: 124 additions & 0 deletions include/tvm/topi/nn.h
@@ -690,6 +690,130 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T
return T;
}
}

/*!
* \brief Creates an operation that performs im2col with an NCHW-layout
*
* \param data The 4-D input tensor
* \param kernel_size A static tuple for kernel size, such as (3,3)
* \param dilation A static tuple for dilation, default is (1,1)
* \param padding A static tuple for padding, padding value is zero
* \param stride A static tuple for strides, default is (1,1)
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the im2col operation (NCHW layout)

Pseudo code:
input_b, input_c, input_h, input_w = data_shape
dilation_h, dilation_w = dilation
padding_h, padding_w = padding

dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1

output_h = (input_h + 2 * padding_h - dilated_kernel_h)//stride_h + 1
output_w = (input_w + 2 * padding_w - dilated_kernel_w)//stride_w + 1

h_offset = (k // kernel_w) % kernel_h
w_offset = k % kernel_w

im2col_data = te.compute(
(N, K, L),
lambda n, k, l: data[
n,
k / kernel_h / kernel_w,
stride_h * (l / output_w) + dilation_h * h_offset - padding_h,
stride_w * (l % output_w) + dilation_w * w_offset - padding_w,
],
name="im2col_data",
)
*/
inline tvm::te::Tensor im2col(const tvm::te::Tensor& data,
const tvm::Array<tvm::PrimExpr>& kernel_size,
const tvm::Array<tvm::PrimExpr>& dilation,
const tvm::Array<tvm::PrimExpr>& padding,
const tvm::Array<tvm::PrimExpr>& stride,
std::string name = "T_im2col", std::string tag = kElementWise) {
ICHECK_EQ(4, data->shape.size());
ICHECK_EQ(2, kernel_size.size());
ICHECK_EQ(2, dilation.size());
ICHECK_EQ(2, padding.size());
ICHECK_EQ(2, stride.size());

auto input_b = data->shape[0];
auto input_c = data->shape[1];
auto input_h = data->shape[2];
auto input_w = data->shape[3];

auto kernel_h = tvm::cast(tvm::DataType::Int(32), kernel_size[0]);
auto kernel_w = tvm::cast(tvm::DataType::Int(32), kernel_size[1]);

auto dilation_h = tvm::cast(tvm::DataType::Int(32), dilation[0]);
auto dilation_w = tvm::cast(tvm::DataType::Int(32), dilation[1]);

auto padding_h = tvm::cast(tvm::DataType::Int(32), padding[0]);
auto padding_w = tvm::cast(tvm::DataType::Int(32), padding[1]);

auto stride_h = tvm::cast(tvm::DataType::Int(32), stride[0]);
auto stride_w = tvm::cast(tvm::DataType::Int(32), stride[1]);
Comment on lines +749 to +759
@wyc-ruiker (Contributor), Jul 12, 2021:

ditto.

@delldu (Contributor Author), Jul 12, 2021:

Explained as before. tvm::cast would be a better choice than tir::as_const_int if we run in script mode, since PyTorch script mode can generate complex expressions.

Member:

We don't really support script mode. Please don't make your implementation complicated to support script mode.

Contributor Author:

It is just copied from PadCompute.

@delldu (Contributor Author), Jul 15, 2021:

Please refer to lines 155~245 of the same file:

tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array<tvm::PrimExpr>& pad_before,
                           tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
                           PrimExpr pad_value = PrimExpr(), std::string name = "T_pad",
                           std::string tag = kElementWise, std::string pad_mode = "constant",
                           const Array<PrimExpr>* dyn_output_shape = nullptr)

It uses the same tvm::cast function to parse its parameters.
One more thing: this has nothing to do with script mode; it is just for parsing.

Contributor Author:

@delldu Please drop support for script mode. Just because your model has control flow doesn't mean these ops also need to be scripted. Remember, everyone can use these op conversions. You can trace and selectively script certain parts of your model using torch.jit._script_if_tracing. This is how MaskRCNN in torchvision is implemented, for example. See https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py#L454

Thanks. That is a really good reference. Very valuable!

@masahi (Member), Jul 15, 2021:

There is no need to copy old code. kernel_size etc. should be constant integers; please remove all casts.


auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;

// Output size after padding
auto output_h = (input_h + 2 * padding_h - dilated_kernel_h) / stride_h + 1;
auto output_w = (input_w + 2 * padding_w - dilated_kernel_w) / stride_w + 1;

// Result output size
auto N = input_b;
auto K = input_c * kernel_h * kernel_w;
auto L = output_h * output_w;
tvm::Array<tvm::PrimExpr> output_shape;
output_shape.push_back(N);
output_shape.push_back(K);
output_shape.push_back(L);

auto pad_value = tvm::tir::make_const(data->dtype, 0);

auto l = [&](tvm::Array<tvm::tir::Var> args) {
tvm::tir::Var n = args[0];
tvm::tir::Var k = args[1];
tvm::tir::Var l = args[2];

tvm::Array<tvm::PrimExpr> indices;
tvm::Array<tvm::PrimExpr> condition;

indices.push_back(n); // B, source batch

// source channel s_c = k / kernel_h / kernel_w
tvm::PrimExpr s_c = indexdiv(indexdiv(k, kernel_h), kernel_w);
indices.push_back(s_c); // C, source channel

// (k / kernel_w) % kernel_h
// stride_h * (l / output_w) + dilation_h * h_offset,
tvm::PrimExpr h_offset = indexmod(indexdiv(k, kernel_w), kernel_h);
tvm::PrimExpr s_h = stride_h * indexdiv(l, output_w) + dilation_h * h_offset - padding_h;
indices.push_back(s_h); // H, source height
condition.push_back(s_h >= 0);
condition.push_back(s_h < input_h);

// k % kernel_w;
// stride_w * (l % output_w) + dilation_w * w_offset,
tvm::PrimExpr w_offset = indexmod(k, kernel_w);
tvm::PrimExpr s_w = stride_w * indexmod(l, output_w) + dilation_w * w_offset - padding_w;
indices.push_back(s_w); // W, source width
condition.push_back(s_w >= 0);
condition.push_back(s_w < input_w);

return tvm::if_then_else(
foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
const_true(1), condition),
data(indices), pad_value);
};

return tvm::te::compute(output_shape, l, name, tag);
}
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_NN_H_
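
For readers who want to check the index arithmetic in the doc comment and lambda above without building TVM, the following is a minimal pure-Python sketch of the same mapping. It is an illustration, not the PR's code: `im2col_reference` is a hypothetical name, the input is assumed to be a nested list in NCHW layout, and out-of-range reads return 0 to stand in for the zero padding value.

```python
def im2col_reference(data, kernel_size, dilation, padding, stride):
    """Pure-Python im2col over a nested-list NCHW tensor (illustrative only)."""
    kernel_h, kernel_w = kernel_size
    dilation_h, dilation_w = dilation
    padding_h, padding_w = padding
    stride_h, stride_w = stride

    input_b = len(data)
    input_c = len(data[0])
    input_h = len(data[0][0])
    input_w = len(data[0][0][0])

    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    output_h = (input_h + 2 * padding_h - dilated_kernel_h) // stride_h + 1
    output_w = (input_w + 2 * padding_w - dilated_kernel_w) // stride_w + 1

    K = input_c * kernel_h * kernel_w  # rows: one per (channel, kernel position)
    L = output_h * output_w            # columns: one per sliding-window position

    def element(n, k, l):
        c = k // kernel_h // kernel_w          # source channel
        h_offset = (k // kernel_w) % kernel_h  # row inside the kernel window
        w_offset = k % kernel_w                # column inside the kernel window
        h = stride_h * (l // output_w) + dilation_h * h_offset - padding_h
        w = stride_w * (l % output_w) + dilation_w * w_offset - padding_w
        # Reads that fall into the padded border yield zero.
        if 0 <= h < input_h and 0 <= w < input_w:
            return data[n][c][h][w]
        return 0

    return [
        [[element(n, k, l) for l in range(L)] for k in range(K)]
        for n in range(input_b)
    ]
```

For a 1x1x3x3 input holding the values 1 to 9 and a 2x2 kernel with no padding, row k = 0 collects the top-left element of each of the four windows, that is, 1, 2, 4 and 5.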
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -2329,6 +2329,30 @@ def flip(self, inputs, input_types):
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])

def grid_sampler(self, inputs, input_types):
data = inputs[0]
grid = inputs[1]

# Torch grid shape is like [batch, out_height, out_width, 2], but
# TVM grid is [batch, 2, out_height, out_width], so here grid need to be converted
grid = _op.transform.transpose(grid, axes=[0, 3, 1, 2])
return _op.image.grid_sample(data, grid, method="bilinear", layout="NCHW")
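
The layout conversion described in the comment above (axes=[0, 3, 1, 2]) can be sketched in plain Python on nested lists. This is only an illustration of what the transpose does to the grid; `grid_nhwc_to_nchw` is a hypothetical helper and is not part of TVM or this PR.

```python
def grid_nhwc_to_nchw(grid):
    """Rearrange a grid from [batch, out_height, out_width, 2]
    to [batch, 2, out_height, out_width], i.e. out[b][c][h][w] = grid[b][h][w][c]."""
    return [
        [[[point[c] for point in row] for row in sample] for c in range(2)]
        for sample in grid
    ]


# One sample, one output row, two output columns, each holding a coordinate pair.
torch_style_grid = [[[[0.1, 0.2], [0.3, 0.4]]]]
tvm_style_grid = grid_nhwc_to_nchw(torch_style_grid)
```

After the conversion, the first coordinates of all points sit together in channel 0 and the second coordinates in channel 1.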

def im2col(self, inputs, input_types):
r"""
Torch F.unfold set kerenl_size, dilation, padding, stride as pairs before calling im2col
but it brokern TVM "if condition expression", so please USE torch._C._nn.im2col instead
of F.unfold and make sure giving paired parameters. Please reference test_forward_im2col
in file tests/python/frontend/pytorch/test_forward.py.
Member:

I don't understand these sentences.

Contributor Author:

From [https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#unfold]:

def unfold(
    input: Tensor, kernel_size: BroadcastingList2[int],
    dilation: BroadcastingList2[int] = 1,
    padding: BroadcastingList2[int] = 0,
    stride: BroadcastingList2[int] = 1
) -> Tensor:
    if input.dim() == 4:
        msg = "{} must be int or 2-tuple for 4D input"
        assert_int_or_pair(kernel_size, "kernel_size", msg)
        assert_int_or_pair(dilation, "dilation", msg)
        assert_int_or_pair(padding, "padding", msg)
        assert_int_or_pair(stride, "stride", msg)

        return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
    else:
        raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim()))

You can see:

  1. _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)
  2. if input.dim() == 4: ... else raise ...

First, we hope users call im2col with paired parameters. Second, we hope users do not call F.unfold directly, because it breaks the "if return tensor1 else return tensor2" condition and type checking will fail.
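
One possible way to avoid asking users for paired parameters is to normalize each argument before the op is built, mirroring what `_pair` does in the quoted `unfold` source. The sketch below is an assumption about how that could look: `to_pair` is a hypothetical helper, not an existing TVM or PyTorch function, and it only handles plain Python ints and 2-element sequences.

```python
def to_pair(value, name):
    """Return (value, value) for an int, or the pair itself for a 2-element sequence."""
    if isinstance(value, int):
        return (value, value)
    if (
        isinstance(value, (tuple, list))
        and len(value) == 2
        and all(isinstance(v, int) for v in value)
    ):
        return (value[0], value[1])
    raise ValueError(f"{name} must be an int or a 2-element sequence, got {value!r}")


kernel_size = to_pair(3, "kernel_size")  # a single int is expanded to a pair
stride = to_pair([1, 2], "stride")       # a pair is passed through unchanged
```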

@delldu (Contributor Author), Jul 15, 2021:

Removed the useless comments from the source code.

@masahi (Member), Jul 15, 2021:

First of all, please fix the typos and grammar errors, otherwise the comments do not make sense. I don't understand what "if condition expression" is. Please revisit this problem assuming you don't need to support script mode. torch._C._nn.im2col is not supposed to be used by users; we should support the usage of torch.unfold.

"""
data = inputs[0]
kernel_size = inputs[1]
dilation = inputs[2]
padding = inputs[3]
stride = inputs[4]

return _op.nn.im2col(data, kernel_size, dilation, padding, stride)

# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -2545,6 +2569,8 @@ def create_convert_map(self):
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
"aten::flip": self.flip,
"aten::grid_sampler": self.grid_sampler,
"aten::im2col": self.im2col,
}

def update_convert_map(self, custom_map):
3 changes: 3 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -1308,3 +1308,6 @@ def dilate_shape_func(attrs, inputs, _):
reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
reg.register_shape_func("nn.relu", False, elemwise_shape_func)


reg.register_broadcast_schedule("nn.im2col")
Member:

Should not be broadcast if you set TOpPattern to kOpaque.

@delldu (Contributor Author), Jul 15, 2021:

Shall we change kOpaque to kInjective? Frankly, I do not understand it. Please teach me. Thanks.

Member:

Let's make it kInjective and use register_injective_schedule here.

Contributor Author:

Thanks, we have changed kOpaque to kInjective, and the tests passed.

35 changes: 35 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -3646,3 +3646,38 @@ def batch_to_space_nd(data, block_shape, crops):
"""

return _make.batch_to_space_nd(data, block_shape, crops)


def im2col(data, kernel_size, dilation, padding, stride):
r"""im2col.

This operator convert 4-D NCHW data with im2col for PyTorch unfold.

Please reference
`https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold`

Member:

PyTorch specific stuff shouldn't be in this file. Please explain what this function does instead.

Contributor Author:

If you click "[SOURCE]" link, you will find

    def forward(self, input: Tensor) -> Tensor:
        return F.unfold(input, self.kernel_size, self.dilation,
                        self.padding, self.stride)

torch.nn.Unfold's forward calls F.unfold, and F.unfold calls im2col. If you call F.unfold or nn.Unfold's forward, from_pytorch will raise an exception.

Contributor Author:

Removed the useless comments from the source code.

Member:

Please just remove references to PyTorch and describe what it does. This is not frontend code.

Parameters
----------
data : tvm.relay.Expr
The input data, in 4-D NCHW format.

kernel_size : Tuple[int]
Specifies the sliding window height and width.

dilation : Tuple[int]
Specifies the dilation rate for the sliding window.

padding : Tuple[int]
Specifies the padding size for top/bottom and left/right.

stride : Tuple[int]
Specifies the strides for the sliding window.

Returns
-------
result : tvm.relay.Expr
The computed result in 3-D NKL format, where K is "C * kernel_height * kernel_width" and L is "output_height * output_width".

"""

return _make.im2col(data, kernel_size, dilation, padding, stride)
86 changes: 86 additions & 0 deletions src/relay/op/nn/pad.cc
@@ -272,5 +272,91 @@ RELAY_REGISTER_OP("nn.mirror_pad")
.add_type_rel("MirrorPad", MirrorPadRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);

Array<te::Tensor> Im2colCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<Im2colAttrs>();
ICHECK(param != nullptr);

return Array<te::Tensor>{
topi::im2col(inputs[0], param->kernel_size, param->dilation, param->padding, param->stride)};
}

bool Im2colRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [input, output]
ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";

masahi marked this conversation as resolved.
const auto* input = types[0].as<TensorTypeNode>();
if (input == nullptr) return false;

if (input->shape.size() != 4) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Im2colRel: input data should be 4 dimensions, NxCxHxW.");
return false;
}

const Im2colAttrs* param = attrs.as<Im2colAttrs>();
if (param == nullptr) return false;

ICHECK_EQ(param->kernel_size.size(), 2) << "Expects two parameters for kernel height and width";
ICHECK_EQ(param->dilation.size(), 2) << "Expects two parameters for dilation height and width";
ICHECK_EQ(param->padding.size(), 2) << "Expects two parameters for padding height and width";
ICHECK_EQ(param->stride.size(), 2) << "Expects two parameters for stride height and width";

// Calculate output shape
auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]);
auto kernel_w = tvm::cast(tvm::DataType::Int(32), param->kernel_size[1]);
auto dilation_h = tvm::cast(tvm::DataType::Int(32), param->dilation[0]);
auto dilation_w = tvm::cast(tvm::DataType::Int(32), param->dilation[1]);
auto padding_h = tvm::cast(tvm::DataType::Int(32), param->padding[0]);
auto padding_w = tvm::cast(tvm::DataType::Int(32), param->padding[1]);
auto stride_h = tvm::cast(tvm::DataType::Int(32), param->stride[0]);
auto stride_w = tvm::cast(tvm::DataType::Int(32), param->stride[1]);
Comment on lines +307 to +314
Contributor:

It seems that we can use tir::as_const_int?

@delldu (Contributor Author), Jul 12, 2021:

We checked the source code and found that tvm::cast can deal with more complex situations than tir::as_const_int.
I am not sure, though; I am a newcomer to TVM development, so please feel free to point out my mistakes ^-^, thanks.

Member:

I don't understand why you need all these casts. kernel_size etc should be already int. Supporting script mode is not an acceptable excuse.

Contributor Author:

It is just copied from PadCompute.

@delldu (Contributor Author), Jul 15, 2021:

Sorry, let me paste the source code.

bool Im2colRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  // types: [input, output]
  ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";

  const auto* input = types[0].as<TensorTypeNode>();
  if (input == nullptr) return false;

  if (input->shape.size() != 4) {
    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                     << "Im2colRel: input data should be 4 dimensions, NxCxHxW.");
    return false;
  }

  const Im2colAttrs* param = attrs.as<Im2colAttrs>();
  if (param == nullptr) return false;

  ICHECK_EQ(param->kernel_size.size(), 2) << "Expects two parameters for kernel height and width";
  ICHECK_EQ(param->dilation.size(), 2) << "Expects two parameters for dilation height and width";
  ICHECK_EQ(param->padding.size(), 2) << "Expects two parameters for padding height and width";
  ICHECK_EQ(param->stride.size(), 2) << "Expects two parameters for stride height and width";

  // Calculate output shape
  auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]);
  auto kernel_w = tvm::cast(tvm::DataType::Int(32), param->kernel_size[1]);
  auto dilation_h = tvm::cast(tvm::DataType::Int(32), param->dilation[0]);
  auto dilation_w = tvm::cast(tvm::DataType::Int(32), param->dilation[1]);
  auto padding_h = tvm::cast(tvm::DataType::Int(32), param->padding[0]);
  auto padding_w = tvm::cast(tvm::DataType::Int(32), param->padding[1]);
  auto stride_h = tvm::cast(tvm::DataType::Int(32), param->stride[0]);
  auto stride_w = tvm::cast(tvm::DataType::Int(32), param->stride[1]);
  auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
  auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;
  // Output size after padding
  auto output_h = (input->shape[2] + 2 * padding_h - dilated_kernel_h) / stride_h + 1;
  auto output_w = (input->shape[3] + 2 * padding_w - dilated_kernel_w) / stride_w + 1;

  tvm::Array<tvm::PrimExpr> output_shape;
  output_shape.push_back(input->shape[0]);                        // N
  output_shape.push_back(input->shape[1] * kernel_h * kernel_w);  // K
  output_shape.push_back(output_h * output_w);                    // L

  // assign output type
  reporter->Assign(types[1], TensorType(output_shape, input->dtype));

  return true;
}
  1. RELAY_REGISTER_OP needs Im2colRel to have the interface: bool Im2colRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter);
  2. We must parse attrs to check whether they are valid;
  3. We must calculate the output shape from the parsed results.
    Unfortunately, attrs are not integers, so we have to do such stupid things.
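
To make the shape calculation in the pasted Im2colRel concrete, here is a small sketch that repeats the same arithmetic on plain Python integers, which is what the reviewers suggest the attributes should be. `im2col_output_shape` is a hypothetical name used only for this illustration; it is not a TVM function.

```python
def im2col_output_shape(input_shape, kernel_size, dilation, padding, stride):
    """Compute the (N, K, L) output shape of im2col for an NCHW input shape."""
    if len(input_shape) != 4:
        raise ValueError("input data should be 4 dimensions, NxCxHxW")
    for name, attr in (
        ("kernel_size", kernel_size),
        ("dilation", dilation),
        ("padding", padding),
        ("stride", stride),
    ):
        if len(attr) != 2:
            raise ValueError(f"{name} expects two values, for height and width")

    n, c, input_h, input_w = input_shape
    kernel_h, kernel_w = kernel_size
    dilation_h, dilation_w = dilation
    padding_h, padding_w = padding
    stride_h, stride_w = stride

    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    # Output size after padding
    output_h = (input_h + 2 * padding_h - dilated_kernel_h) // stride_h + 1
    output_w = (input_w + 2 * padding_w - dilated_kernel_w) // stride_w + 1

    return (n, c * kernel_h * kernel_w, output_h * output_w)


shape = im2col_output_shape((1, 3, 10, 10), (3, 3), (1, 1), (1, 1), (1, 1))
```

With a 3x3 kernel, unit dilation and stride, and padding of 1, a 1x3x10x10 input keeps its 10x10 spatial grid, so K = 3 * 3 * 3 = 27 and L = 10 * 10 = 100.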

auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;
// Output size after padding
auto output_h = (input->shape[2] + 2 * padding_h - dilated_kernel_h) / stride_h + 1;
auto output_w = (input->shape[3] + 2 * padding_w - dilated_kernel_w) / stride_w + 1;

tvm::Array<tvm::PrimExpr> output_shape;
output_shape.push_back(input->shape[0]); // N
output_shape.push_back(input->shape[1] * kernel_h * kernel_w); // K
output_shape.push_back(output_h * output_w); // L

// assign output type
reporter->Assign(types[1], TensorType(output_shape, input->dtype));

return true;
}

// Handler to create a call to the im2col op used by front-end FFI
Expr MakeIm2col(Expr data, Array<IndexExpr> kernel_size, Array<IndexExpr> dilation,
Array<IndexExpr> padding, Array<IndexExpr> stride) {
auto attrs = make_object<Im2colAttrs>();

attrs->kernel_size = std::move(kernel_size);
attrs->dilation = std::move(dilation);
attrs->padding = std::move(padding);
attrs->stride = std::move(stride);

static const Op& op = Op::Get("nn.im2col");
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_NODE_TYPE(Im2colAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.im2col").set_body_typed(MakeIm2col);

RELAY_REGISTER_OP("nn.im2col")
.describe(R"code(Im2col for 4-D NCHW tensor.

)code" TVM_ADD_FILELINE)
.set_attrs_type<Im2colAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(2)
.add_type_rel("Im2col", Im2colRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<FTVMCompute>("FTVMCompute", Im2colCompute);

} // namespace relay
} // namespace tvm