-
Notifications
You must be signed in to change notification settings - Fork 232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BN] Enable NHWC in OCL #3399
base: develop
Are you sure you want to change the base?
[BN] Enable NHWC in OCL #3399
Conversation
bghimireamd
commented
Nov 20, 2024
•
edited
Loading
edited
- Enable NHWC for Batch norm forward infer
- Initialize the driver and gtest with similar range of values.
if((!(n < 3) && | ||
!((in_nhw < 33554432 && in_cstride > 1024) || | ||
((n >= 256) && (in_cstride > 60) && bfpmixparm) || ((in_cstride > 512) && bfpmixparm)) && | ||
!(in_cstride <= 512)) || | ||
!((n > 768) && (in_cstride > 150) && bfp32parm)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's barely readable and probably redundant condition.
May I ask you to do the math and simplify it? (at least replace !(n < 3)
with (n >= 3)
, but there are more simplifications possible)
if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) && | ||
!(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not that complex as the condition from src/solver/batchnorm/forward_spatial_multiple.cpp, but still can be simplified, since it contains redundant statements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One example of the transformations:
X == in_nhw < (32 * 1024 * 1024)
Y == in_cstride > 1024
Z == in_cstride > 512
!(X & Y) & !(X & Z) & !(!Z)
(!X | !Y) & (!X | !Z) & Z
(!X | !Y) & (!X & Z | !Z & Z) // !Z & Z -> false
(!X | !Y) & !X & Z
!X & Z | !Y & !X & Z
!X & Z & (1 | !Y) // (1 | !Y) -> true
!X & Z
So it basically means return (in_nhw >= (32 * 1024 * 1024)) && (in_cstride > 512);
Which probably can be simplified more, since in_nhw
is n * in_cstride
and we know for sure that in_cstride
must be greater than 512.
If there are any doubts about those transformations, here is a proof (tested in excel, lol):
Using that fact that if Y is true, then Z must always be true we can even exclude few cases:
X | Y | Z | old | new | result |
---|---|---|---|---|---|
0 | 0 | 0 | FALSE | FALSE | TRUE |
1 | 0 | 0 | FALSE | FALSE | TRUE |
0 | 0 | 1 | TRUE | TRUE | TRUE |
1 | 0 | 1 | FALSE | FALSE | TRUE |
0 | 1 | 1 | TRUE | TRUE | TRUE |
1 | 1 | 1 | FALSE | FALSE | TRUE |
@@ -38,9 +38,34 @@ namespace solver { | |||
|
|||
namespace batchnorm { | |||
|
|||
bool BNBwdIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem) | |||
{ | |||
int n, c, h, w; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int n, c, h, w; | |
size_t n, c, h, w; |
unsigned int in_cstride = h * w; | ||
unsigned int in_nhw = n * in_cstride; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unsigned int in_cstride = h * w; | |
unsigned int in_nhw = n * in_cstride; | |
size_t in_cstride = h * w; | |
size_t in_nhw = n * in_cstride; |
int n, c, h, w; | ||
std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths()); | ||
|
||
unsigned int in_cstride = h * w; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But actually you can avoid all the prevois int
-related comment by the following code:
int n, c, h, w; | |
std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths()); | |
unsigned int in_cstride = h * w; | |
auto [n, c, h, w] = tien<4>(problem.GetXDesc().GetLengths()); | |
auto in_cstride = problem.GetXDesc().GetStrides()[1]; |
int n, c, h, w; | ||
std::tie(n, c, h, w) = tien<4>(xDesc.GetLengths()); | ||
unsigned int in_cstride = h * w; | ||
unsigned int in_nhw = n * in_cstride; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See the comments from src/solver/batchnorm/backward_spatial_multiple.cpp
if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) && | ||
!(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512)) | ||
{ | ||
return true; | ||
} | ||
else | ||
return false; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) && | |
!(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512)) | |
{ | |
return true; | |
} | |
else | |
return false; | |
} | |
return !(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) && | |
!(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512); |