-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI] Allow conv definition to have custom kernel layout #11936
[TOPI] Allow conv definition to have custom kernel layout #11936
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks pretty good, but I don't see any tests for different kernel_layouts. Could you add some?
@tkonolige There is no good ways to test numerical results for arbitrary combinations of layouts due to lack of schedule. Therefore, I added some tests to check |
out_dtype="float32", | ||
), | ||
], | ||
[A, W, topi.nn.conv2d_nhwc(A, W, stride, padding, dilation, "float32")], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
topi.nn.conv2d_nhwc
just calls topi.nn.conv2d
under the hood, so this test isn't really checking anything.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
topi.nn.conv2d_nhwc
is already covered by existing test cases, but I didn't find any test cases invoke topi.nn.conv2d
so we just need to check they are the same
Under the hood, both topi.nn.conv2d_nhwc
and topi.nn.conv2d
calls topi.nn.conv
and cover the code path with explicit kernel layout.
@vinx13 I think you can use |
c8daa85
to
1f1b428
Compare
python/tvm/topi/nn/conv2d.py
Outdated
@@ -57,16 +57,18 @@ | |||
) | |||
|
|||
|
|||
def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=None): | |||
def conv2d( | |||
input, filter, strides, padding, dilation, data_layout="NCHC", kernel_layout="", out_dtype=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe data_layout="NCHW" ?
layout of data | ||
|
||
kernel_layout : Optional[str] | ||
layout of kernel. If unspecified, use default layout inferred from data_layout. "OHWI" if | ||
data_layout == "NCHW", "HWIO" if data_layout == "NHWC". |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A little bit confused to me, shouldn’t "OIHW" & "NCHW" a default pair? I can also see this pattern at line 248.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Thanks for adding the tests!
cf01ebb
to
249be49
Compare
249be49
to
a9eb3d4
Compare
* [TOPI] Allow conv definition to have custom kernel layout * add tests * fix * fix
* [TOPI] Allow conv definition to have custom kernel layout * add tests * fix * fix
This PR added an optional parameter
kernel_layout
toconv
. This allows arbitrary data and kernel layout combination to be lowered from Relay, and scheduled by auto / meta schedule.cc @junrushao1994 @masahi @tkonolige