-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support aten::grid_sampler, aten::im2col #8443
Changes from 12 commits
4d23916
80b2262
fe46820
bc8e101
23df38b
1c64c44
6514f74
bc02b4c
d086e28
4935e65
303c70a
8798175
7a1e24e
f25c34b
c41c30e
c9704d9
fbf98d2
65bbf30
43c5d01
0387aef
7ca82a9
5667b94
97a5df6
8d88fc0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -690,6 +690,130 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T | |
return T; | ||
} | ||
} | ||
|
||
/*! | ||
* \brief Creates an operation that performs im2col with an NCHW-layout | ||
* | ||
* \param data The 4-D input tensor | ||
* \param kernel_size A static tuple for kernel size, such as (3,3) | ||
* \param dilation A static tuple for dilation, default is (1,1) | ||
* \param padding A static tuple for padding, padding value is zero | ||
* \param stride A static tuple for strides, default is (1,1) | ||
masahi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* \param name The name of the operation | ||
* \param tag The tag to mark the operation | ||
* | ||
* \return A Tensor whose op member is the im2col operation (NCHW layout) | ||
|
||
Pseudo code: | ||
input_b, input_c, input_h, input_w = data_shape | ||
dilation_h, dilation_w = dilation | ||
padding_h, padding_w = padding | ||
|
||
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 | ||
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 | ||
|
||
output_h = (input_h + 2 * padding_h - dilated_kernel_h)//stride_h + 1 | ||
output_w = (input_w + 2 * padding_w - dilated_kernel_w)//stride_w + 1 | ||
|
||
h_offset = (k // input_w) % kernel_h | ||
w_offset = k % kernel_w | ||
|
||
im2col_data = te.compute( | ||
(N, K, L), | ||
lambda n, k, l: data[ | ||
n, | ||
k / kernel_h / kernel_w, | ||
stride_h * (l / output_w) + dilation_h * h_offset - padding_h, | ||
stride_w * (l % output_w) + dilation_w * w_offset - padding_w, | ||
], | ||
name="im2col_data", | ||
) | ||
*/ | ||
inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, | ||
const tvm::Array<tvm::PrimExpr>& kernel_size, | ||
const tvm::Array<tvm::PrimExpr>& dilation, | ||
const tvm::Array<tvm::PrimExpr>& padding, | ||
const tvm::Array<tvm::PrimExpr>& stride, | ||
std::string name = "T_im2col", std::string tag = kElementWise) { | ||
ICHECK_EQ(4, data->shape.size()); | ||
ICHECK_EQ(2, kernel_size.size()); | ||
ICHECK_EQ(2, dilation.size()); | ||
ICHECK_EQ(2, padding.size()); | ||
ICHECK_EQ(2, stride.size()); | ||
|
||
auto input_b = data->shape[0]; | ||
auto input_c = data->shape[1]; | ||
auto input_h = data->shape[2]; | ||
auto input_w = data->shape[3]; | ||
|
||
auto kernel_h = tvm::cast(tvm::DataType::Int(32), kernel_size[0]); | ||
auto kernel_w = tvm::cast(tvm::DataType::Int(32), kernel_size[1]); | ||
|
||
auto dilation_h = tvm::cast(tvm::DataType::Int(32), dilation[0]); | ||
auto dilation_w = tvm::cast(tvm::DataType::Int(32), dilation[1]); | ||
|
||
auto padding_h = tvm::cast(tvm::DataType::Int(32), padding[0]); | ||
auto padding_w = tvm::cast(tvm::DataType::Int(32), padding[1]); | ||
|
||
auto stride_h = tvm::cast(tvm::DataType::Int(32), stride[0]); | ||
auto stride_w = tvm::cast(tvm::DataType::Int(32), stride[1]); | ||
Comment on lines
+749
to
+759
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Explained as before. tvm::cast would be a better choice than tir::as_const_int if we run script mode, PyTorch script mode could generate complex expression. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't really support script mode. Please don't make your implementation complicated to support script mode. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just copy from PadCompute. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please reference on same file about 155~245 lines.
with same function tvm::cast to parsing parameters. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thanks. It is real good reference. High value ! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no need to copy old code. |
||
|
||
auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1; | ||
auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1; | ||
|
||
// Output size after padding | ||
auto output_h = (input_h + 2 * padding_h - dilated_kernel_h) / stride_h + 1; | ||
auto output_w = (input_w + 2 * padding_w - dilated_kernel_w) / stride_w + 1; | ||
|
||
// Result output size | ||
auto N = input_b; | ||
auto K = input_c * kernel_h * kernel_w; | ||
auto L = output_h * output_w; | ||
tvm::Array<tvm::PrimExpr> output_shape; | ||
output_shape.push_back(N); | ||
output_shape.push_back(K); | ||
output_shape.push_back(L); | ||
|
||
auto pad_value = tvm::tir::make_const(data->dtype, 0); | ||
|
||
auto l = [&](tvm::Array<tvm::tir::Var> args) { | ||
tvm::tir::Var n = args[0]; | ||
tvm::tir::Var k = args[1]; | ||
tvm::tir::Var l = args[2]; | ||
|
||
tvm::Array<tvm::PrimExpr> indices; | ||
tvm::Array<tvm::PrimExpr> condition; | ||
|
||
indices.push_back(n); // B, souce batch | ||
|
||
// source chanel s_c = k / kernel_h / kernel_w | ||
tvm::PrimExpr s_c = indexdiv(indexdiv(k, kernel_h), kernel_w); | ||
indices.push_back(s_c); // C, source channel | ||
|
||
// (k / kernel_w) % kernel_h | ||
// stride_h * (l / output_w) + dilation_h * h_offset, | ||
tvm::PrimExpr h_offset = indexmod(indexdiv(k, kernel_w), kernel_h); | ||
tvm::PrimExpr s_h = stride_h * indexdiv(l, output_w) + dilation_h * h_offset - padding_h; | ||
indices.push_back(s_h); // H, souce height | ||
condition.push_back(s_h >= 0); | ||
condition.push_back(s_h < input_h); | ||
|
||
// k % kernel_w; | ||
// stride_w * (l % output_w) + dilation_w * w_offset, | ||
tvm::PrimExpr w_offset = indexmod(k, kernel_w); | ||
tvm::PrimExpr s_w = stride_w * indexmod(l, output_w) + dilation_w * w_offset - padding_w; | ||
indices.push_back(s_w); // W, source width | ||
condition.push_back(s_w >= 0); | ||
condition.push_back(s_w < input_w); | ||
|
||
return tvm::if_then_else( | ||
foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); }, | ||
const_true(1), condition), | ||
data(indices), pad_value); | ||
}; | ||
|
||
return tvm::te::compute(output_shape, l, name, tag); | ||
} | ||
} // namespace topi | ||
} // namespace tvm | ||
#endif // TVM_TOPI_NN_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1875,7 +1875,9 @@ def Bool(self, inputs, input_types): | |
|
||
def Float(self, inputs, input_types): | ||
assert len(inputs) == 1 | ||
return _op.cast(inputs[0], "float32") | ||
if isinstance(inputs[0], _expr.Expr): | ||
return inputs[0] | ||
return float(inputs[0]) | ||
masahi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def bitwise_not(self, inputs, input_types): | ||
data = inputs[0] | ||
|
@@ -2329,6 +2331,30 @@ def flip(self, inputs, input_types): | |
axis = inputs[1] | ||
return _op.transform.reverse(data, axis=axis[0]) | ||
|
||
def grid_sampler(self, inputs, input_types): | ||
data = inputs[0] | ||
grid = inputs[1] | ||
|
||
# Torch grid shape is like [batch, out_height, out_width, 2], but | ||
# TVM grid is [batch, 2, out_height, out_width], so here grid need to be converted | ||
grid = _op.transform.transpose(grid, axes=[0, 3, 1, 2]) | ||
return _op.image.grid_sample(data, grid, method="bilinear", layout="NCHW") | ||
|
||
def im2col(self, inputs, input_types): | ||
r""" | ||
Torch F.unfold set kerenl_size, dilation, padding, stride as pairs before calling im2col | ||
but it brokern TVM "if condition expression", so please USE torch._C._nn.im2col instead | ||
of F.unfold and make sure giving paired parameters. Please reference test_forward_im2col | ||
in file tests/python/frontend/pytorch/test_forward.py. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand this sentences. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From [https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#unfold]:
You can find:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed usefulness comments in source code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First of all, please fix the typo and grammar error, otherwise the comments do not make sense. I don't understand what is "if condition expression". Please revisit this problem assuming you don't need to support script mode. |
||
""" | ||
data = inputs[0] | ||
kernel_size = inputs[1] | ||
dilation = inputs[2] | ||
padding = inputs[3] | ||
stride = inputs[4] | ||
|
||
return _op.nn.im2col(data, kernel_size, dilation, padding, stride) | ||
|
||
# Operator mappings | ||
def create_convert_map(self): | ||
self.convert_map = { | ||
|
@@ -2545,6 +2571,8 @@ def create_convert_map(self): | |
"aten::nll_loss": self.nll_loss, | ||
"aten::nll_loss2d": self.nll_loss, | ||
"aten::flip": self.flip, | ||
"aten::grid_sampler": self.grid_sampler, | ||
"aten::im2col": self.im2col, | ||
} | ||
|
||
def update_convert_map(self, custom_map): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1308,3 +1308,6 @@ def dilate_shape_func(attrs, inputs, _): | |
reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) | ||
reg.register_shape_func("nn.softmax", False, elemwise_shape_func) | ||
reg.register_shape_func("nn.relu", False, elemwise_shape_func) | ||
|
||
|
||
reg.register_broadcast_schedule("nn.im2col") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should not be broadcast if you set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we change kOpaque to kInjective ? Frankly, I do not understand it. please teach me. Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's make it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, we have change kOpaque to kInjective, and test passed. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3646,3 +3646,38 @@ def batch_to_space_nd(data, block_shape, crops): | |
""" | ||
|
||
return _make.batch_to_space_nd(data, block_shape, crops) | ||
|
||
|
||
def im2col(data, kernel_size, dilation, padding, stride): | ||
r"""im2col. | ||
|
||
This operator convert 4-D NCHW data with im2col for PyTorch unfold. | ||
|
||
Please reference | ||
`https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold` | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PyTorch specific stuff shouldn't be in this file. Please explain what this function does instead. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you click "[SOURCE]" link, you will find
torch.nn.Unfold forward will all F.unfold, F.unfold will call im2col. If you call F.unfold or nn.Unfold forward, from_pytorch will give you a exception. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed usefulness comments in source code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please just remove references to PyTorch and describe what it does. This is not frontend code. |
||
Parameters | ||
---------- | ||
data : tvm.relay.Expr | ||
The input data is 4-D NCHW format. | ||
|
||
kernel_size : Tuple[int] | ||
Specified the slide window height and width. | ||
|
||
dilation : Tuple[int] | ||
Specifies the dilation rate for slide window. | ||
|
||
padding : Tuple[int] | ||
Specifies the padding size with top/bottom and left/right. | ||
|
||
strides : Tuple[int] | ||
Specifies the strides for slide window. | ||
|
||
Returns | ||
------- | ||
result : tvm.relay.Expr | ||
The computed result, 3-D NKL format, K is "C * kernel_height * kernel_width". | ||
|
||
""" | ||
|
||
return _make.im2col(data, kernel_size, dilation, padding, stride) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -272,5 +272,86 @@ RELAY_REGISTER_OP("nn.mirror_pad") | |
.add_type_rel("MirrorPad", MirrorPadRel) | ||
.set_attr<TOpPattern>("TOpPattern", kInjective); | ||
|
||
Array<te::Tensor> Im2colCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, | ||
const Type& out_type) { | ||
const auto* param = attrs.as<Im2colAttrs>(); | ||
ICHECK(param != nullptr); | ||
|
||
return Array<te::Tensor>{ | ||
topi::im2col(inputs[0], param->kernel_size, param->dilation, param->padding, param->stride)}; | ||
} | ||
|
||
bool Im2colRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, | ||
const TypeReporter& reporter) { | ||
// types: [input, output] | ||
ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output"; | ||
|
||
masahi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const auto* input = types[0].as<TensorTypeNode>(); | ||
if (input == nullptr) return false; | ||
|
||
if (input->shape.size() != 4) { | ||
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) | ||
<< "Im2lossRel: input data should be 4 dimensions, NxCxHxW."); | ||
return false; | ||
} | ||
|
||
const Im2colAttrs* param = attrs.as<Im2colAttrs>(); | ||
if (param == nullptr) return false; | ||
|
||
// Calculate output shape | ||
auto kernel_h = tvm::cast(tvm::DataType::Int(32), param->kernel_size[0]); | ||
auto kernel_w = tvm::cast(tvm::DataType::Int(32), param->kernel_size[1]); | ||
auto dilation_h = tvm::cast(tvm::DataType::Int(32), param->dilation[0]); | ||
auto dilation_w = tvm::cast(tvm::DataType::Int(32), param->dilation[1]); | ||
auto padding_h = tvm::cast(tvm::DataType::Int(32), param->padding[0]); | ||
auto padding_w = tvm::cast(tvm::DataType::Int(32), param->padding[1]); | ||
auto stride_h = tvm::cast(tvm::DataType::Int(32), param->stride[0]); | ||
auto stride_w = tvm::cast(tvm::DataType::Int(32), param->stride[1]); | ||
Comment on lines
+307
to
+314
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that we can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We checked source code and found that tvm::cast could deal with more complex situation than tir::as_const_int. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand why you need all these casts. kernel_size etc should be already int. Supporting script mode is not an acceptable excuse. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just copy from PadCompute. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, let me paste the source code.
|
||
auto dilated_kernel_h = (kernel_h - 1) * dilation_h + 1; | ||
auto dilated_kernel_w = (kernel_w - 1) * dilation_w + 1; | ||
// Output size after padding | ||
auto output_h = (input->shape[2] + 2 * padding_h - dilated_kernel_h) / stride_h + 1; | ||
auto output_w = (input->shape[3] + 2 * padding_w - dilated_kernel_w) / stride_w + 1; | ||
|
||
tvm::Array<tvm::PrimExpr> output_shape; | ||
output_shape.push_back(input->shape[0]); // N | ||
output_shape.push_back(input->shape[1] * kernel_h * kernel_w); // K | ||
output_shape.push_back(output_h * output_w); // L | ||
|
||
// assign output type | ||
reporter->Assign(types[1], TensorType(output_shape, input->dtype)); | ||
|
||
return true; | ||
} | ||
|
||
// Handler to create a call to the im2col op used by front-end FFI | ||
Expr MakeIm2col(Expr data, Array<IndexExpr> kernel_size, Array<IndexExpr> dilation, | ||
Array<IndexExpr> padding, Array<IndexExpr> stride) { | ||
auto attrs = make_object<Im2colAttrs>(); | ||
|
||
attrs->kernel_size = std::move(kernel_size); | ||
attrs->dilation = std::move(dilation); | ||
attrs->padding = std::move(padding); | ||
attrs->stride = std::move(stride); | ||
|
||
static const Op& op = Op::Get("nn.im2col"); | ||
return Call(op, {data}, Attrs(attrs), {}); | ||
} | ||
|
||
TVM_REGISTER_NODE_TYPE(Im2colAttrs); | ||
TVM_REGISTER_GLOBAL("relay.op.nn._make.im2col").set_body_typed(MakeIm2col); | ||
|
||
RELAY_REGISTER_OP("nn.im2col") | ||
.describe(R"code(Im2col for 4-D NCHW tensor. | ||
|
||
)code" TVM_ADD_FILELINE) | ||
.set_attrs_type<Im2colAttrs>() | ||
.set_num_inputs(1) | ||
.add_argument("data", "Tensor", "Input data.") | ||
.set_support_level(2) | ||
.add_type_rel("Im2col", Im2colRel) | ||
.set_attr<TOpPattern>("TOpPattern", kOpaque) | ||
.set_attr<FTVMCompute>("FTVMCompute", Im2colCompute); | ||
|
||
} // namespace relay | ||
} // namespace tvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IndexExpr should be
Integer
. If you make this change you can remove all the casts.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry let give a little explain.
So RELAY_REGISTER_OP ==> Im2colCompute ==> im2col and Im2colAttrs, we almost have no choice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see why you can't use
Array<Integer>
. There is no need to followConv2DAttrs
, if all attributes in im2col ops are integer constant, there is no point usingArray<IndexExpr>
and it just makes your implementation unnecessarily complicated.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I see. will modify.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @masahi , @wyc-ruiker ,
First all, thank you help us improve PR. Sounds there is one import issue left, I would give my understanding.
Second, please remember
using IndexExpr = ::tvm::PrimExpr;
1. Conclusion
Changing
Array<IndexExpr>
toArray<Integer>
in Im2colAttrs is not good.2. Function Patch
inline tvm::te::Tensor im2col(const tvm::te::Tensor& data, ...)
{
auto input_b = data->shape[0];
...
auto kernel_h = tvm::cast(tvm::DataType::Int(32), kernel_size[0]);
...
tvm::Arraytvm::PrimExpr output_shape;
output_shape.push_back(N);
...
auto pad_value = tvm::tir::make_const(data->dtype, 0);
auto l = [&](tvm::Arraytvm::tir::Var args) {
...
tvm::PrimExpr s_c = indexdiv(indexdiv(k, kernel_h), kernel_w);
indices.push_back(s_c); // C, source channel
...
return tvm::if_then_else(..., data(indices), pad_value);
};
return tvm::te::compute(output_shape, l, name, tag);
}
3. Cause analysis
The shape of input data is
Array<PrimExpr>
;inline tvm::te::Tensor im2col(const tvm::te::Tensor& data
output shape is also
Array<PrimExpr>
.TVM_DLL Tensor compute(Array<PrimExpr> shape, ...);
3.1) Many expression in lambda function such as indices, are
Array<PrimExpr>
, NOTArray<Integer>
.3.2) Some function parameters are PrimExpr.
For an example: prototye of indexmod is:
If we give Integer parameters to indexdiv, only got compile fatal warning and errors ...
4. Test Result
We try to change
Array<IndexExpr>
toArray<Integer>
, modified related source code,fighting again and again, using many methods, finally got many and many building(compile) errors.
5. Inference
tvm::cast
is good choice thanas_const_int
for parsing parameters.The reason comes from their prototype
Following above results, that is why we can't use
Array<Integer>
.