[BN] Enable NHWC in OCL #3399

Open
wants to merge 13 commits into
base: develop

@bghimireamd bghimireamd commented Nov 20, 2024

  • Enable NHWC layout for batch norm forward inference.
  • Initialize the driver and gtest inputs with a similar range of values.

Comment on lines 55 to 59
if((!(n < 3) &&
    !((in_nhw < 33554432 && in_cstride > 1024) ||
      ((n >= 256) && (in_cstride > 60) && bfpmixparm) || ((in_cstride > 512) && bfpmixparm)) &&
    !(in_cstride <= 512)) ||
   !((n > 768) && (in_cstride > 150) && bfp32parm))

This condition is barely readable and probably contains redundant terms.
May I ask you to do the math and simplify it? At the very least, replace !(n < 3) with (n >= 3), but more simplifications are possible.
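
As a starting point for that request, here is a rough sketch of one possible simplification. It relies on the observation that !(in_cstride <= 512) makes (in_cstride > 60) always true inside the first group, so the (n >= 256) term is absorbed by the (in_cstride > 512) && bfpmixparm term. The function names, the parameter types, and the choice to treat in_nhw as an independent input are assumptions made for illustration, not code from the PR.

```cpp
#include <cassert>
#include <cstddef>

// Condition as quoted from the diff (parameter types are assumed).
bool OriginalCondition(std::size_t n, std::size_t in_cstride, std::size_t in_nhw,
                       bool bfpmixparm, bool bfp32parm)
{
    return (!(n < 3) &&
            !((in_nhw < 33554432 && in_cstride > 1024) ||
              ((n >= 256) && (in_cstride > 60) && bfpmixparm) ||
              ((in_cstride > 512) && bfpmixparm)) &&
            !(in_cstride <= 512)) ||
           !((n > 768) && (in_cstride > 150) && bfp32parm);
}

// Hypothetical simplified form: once in_cstride > 512 is required,
// both bfpmixparm terms collapse to a plain !bfpmixparm.
bool SimplifiedCondition(std::size_t n, std::size_t in_cstride, std::size_t in_nhw,
                         bool bfpmixparm, bool bfp32parm)
{
    const bool small_and_wide = in_nhw < 33554432 && in_cstride > 1024;
    const bool first  = n >= 3 && in_cstride > 512 && !bfpmixparm && !small_and_wide;
    const bool second = !(n > 768 && in_cstride > 150 && bfp32parm);
    return first || second;
}

// Compares both forms on every combination of values around each threshold.
bool ConditionsAgreeOnBoundaries()
{
    const std::size_t ns[]       = {0, 2, 3, 255, 256, 768, 769};
    const std::size_t cstrides[] = {0, 60, 61, 150, 151, 512, 513, 1024, 1025};
    const std::size_t nhws[]     = {0, 33554431, 33554432};
    const bool flags[]           = {false, true};
    for(std::size_t n : ns)
        for(std::size_t cs : cstrides)
            for(std::size_t nhw : nhws)
                for(bool mix : flags)
                    for(bool fp32 : flags)
                        if(OriginalCondition(n, cs, nhw, mix, fp32) !=
                           SimplifiedCondition(n, cs, nhw, mix, fp32))
                            return false;
    return true;
}
```

Because every comparison in the condition is against a fixed threshold, checking the values on both sides of each threshold covers all distinct truth assignments.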

Comment on lines 49 to 50
if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) &&
   !(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512))

It's not as complex as the condition from src/solver/batchnorm/forward_spatial_multiple.cpp, but it can still be simplified, since it contains redundant terms.

@CAHEK7 CAHEK7 Nov 21, 2024

One example of the transformations:

X == in_nhw < (32 * 1024 * 1024)
Y == in_cstride > 1024
Z == in_cstride > 512

!(X & Y) & !(X & Z) & !(!Z)
(!X | !Y) & (!X | !Z) & Z
(!X | !Y) & (!X & Z | !Z & Z) // !Z & Z -> false
(!X | !Y) & !X & Z
!X & Z | !Y & !X & Z
!X & Z & (1 | !Y) // (1 | !Y) -> true
!X & Z

So the whole condition reduces to return (in_nhw >= (32 * 1024 * 1024)) && (in_cstride > 512);
This can probably be simplified further, since in_nhw is n * in_cstride and we already know that in_cstride must be greater than 512.

If there are any doubts about those transformations, here is a proof (tested in Excel, lol).
Using the fact that Z must be true whenever Y is true, we can even exclude a few cases:

X  Y  Z  old    new    result
0  0  0  FALSE  FALSE  TRUE
1  0  0  FALSE  FALSE  TRUE
0  0  1  TRUE   TRUE   TRUE
1  0  1  FALSE  FALSE  TRUE
0  1  1  TRUE   TRUE   TRUE
1  1  1  FALSE  FALSE  TRUE
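
The same truth table can also be checked mechanically. The sketch below is one way to do it, with OldForm, NewForm and FormsAgree as hypothetical names chosen for illustration. It enumerates all eight combinations of X, Y and Z, including the two that cannot occur in practice (Y true with Z false), and compares the original expression against the reduced one.

```cpp
#include <cassert>

// X: in_nhw < 32 * 1024 * 1024
// Y: in_cstride > 1024
// Z: in_cstride > 512
constexpr bool OldForm(bool x, bool y, bool z)
{
    return !(x && y) && !(x && z) && !(!z);
}

// Reduced form from the derivation above; Y no longer matters.
constexpr bool NewForm(bool x, bool /*y*/, bool z)
{
    return !x && z;
}

// Walks all 2^3 assignments and reports whether both forms always agree.
constexpr bool FormsAgree()
{
    for(int bits = 0; bits < 8; ++bits)
    {
        const bool x = (bits & 1) != 0;
        const bool y = (bits & 2) != 0;
        const bool z = (bits & 4) != 0;
        if(OldForm(x, y, z) != NewForm(x, y, z))
            return false;
    }
    return true;
}

// Fails to compile if the two forms ever disagree (needs C++14 or later).
static_assert(FormsAgree(), "old and new conditions must be equivalent");
```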

@@ -38,9 +38,34 @@ namespace solver {

namespace batchnorm {

bool BNBwdIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem)
{
int n, c, h, w;

Suggested change
- int n, c, h, w;
+ size_t n, c, h, w;

Comment on lines 46 to 47
unsigned int in_cstride = h * w;
unsigned int in_nhw = n * in_cstride;

Suggested change
- unsigned int in_cstride = h * w;
- unsigned int in_nhw = n * in_cstride;
+ size_t in_cstride = h * w;
+ size_t in_nhw = n * in_cstride;
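
The comment doesn't state the motivation, but one plausible reason for preferring a wider type is wrap-around: with 32-bit arithmetic, n * h * w silently wraps once the product reaches 2^32, and a wrapped value would then pass the in_nhw < (32 * 1024 * 1024) check. The sketch below illustrates this with fixed-width types; Nhw32 and Nhw64 are hypothetical helpers, and it assumes unsigned int is 32 bits and size_t is 64 bits on the target platform.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors `unsigned int in_nhw = n * in_cstride;` with a 32-bit unsigned type.
// Unsigned overflow is well defined: the result is reduced modulo 2^32.
constexpr std::uint32_t Nhw32(std::uint32_t n, std::uint32_t h, std::uint32_t w)
{
    return n * (h * w);
}

// Same computation with a 64-bit type, standing in for size_t.
constexpr std::uint64_t Nhw64(std::uint64_t n, std::uint64_t h, std::uint64_t w)
{
    return n * (h * w);
}

// n = 64, h = w = 8192 gives exactly 2^32 elements per channel slice set.
static_assert(Nhw32(64, 8192, 8192) == 0, "32-bit product wraps to zero");
static_assert(Nhw64(64, 8192, 8192) == 4294967296ULL, "64-bit product is exact");
```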

Comment on lines 43 to 46
int n, c, h, w;
std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths());

unsigned int in_cstride = h * w;

@CAHEK7 CAHEK7 Nov 21, 2024

But actually you can avoid all the previous int-related comments with the following code:

Suggested change
- int n, c, h, w;
- std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths());
- unsigned int in_cstride = h * w;
+ auto [n, c, h, w] = tien<4>(problem.GetXDesc().GetLengths());
+ auto in_cstride = problem.GetXDesc().GetStrides()[1];
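
To make the idea concrete, here is a small standalone sketch of the structured-binding pattern. FirstFour is a hypothetical stand-in for tien<4> (its real signature is not shown in this thread), and PackedNchwStrides is an illustrative helper, not a MIOpen function. The sketch assumes a fully packed tensor stored in NCHW order, which is the case where the channel stride equals h * w; whether that also holds for other layouts depends on how the descriptor reports strides, which the comment doesn't cover.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <tuple>

// Hypothetical stand-in for tien<4>: returns the first four lengths as a tuple.
std::tuple<std::size_t, std::size_t, std::size_t, std::size_t>
FirstFour(const std::array<std::size_t, 4>& v)
{
    return std::make_tuple(v[0], v[1], v[2], v[3]);
}

// Strides of a fully packed tensor stored in NCHW order (assumption).
std::array<std::size_t, 4> PackedNchwStrides(const std::array<std::size_t, 4>& lens)
{
    // Structured bindings take their types from the tuple, so n, c, h, w
    // are all std::size_t here and no int/unsigned conversions are needed.
    const auto [n, c, h, w] = FirstFour(lens);
    (void)n; // the batch size does not affect the strides
    return {c * h * w, h * w, w, 1};
}

// Channel stride read from the strides instead of recomputing h * w.
std::size_t ChannelStride(const std::array<std::size_t, 4>& lens)
{
    return PackedNchwStrides(lens)[1];
}
```

Structured bindings require C++17, which the suggested change above already relies on.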

Comment on lines 46 to 49
int n, c, h, w;
std::tie(n, c, h, w) = tien<4>(xDesc.GetLengths());
unsigned int in_cstride = h * w;
unsigned int in_nhw = n * in_cstride;

See the comments from src/solver/batchnorm/backward_spatial_multiple.cpp

Comment on lines 49 to 56
if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) &&
   !(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512))
{
    return true;
}
else
    return false;
}

Suggested change
- if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) &&
-    !(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512))
- {
-     return true;
- }
- else
-     return false;
- }
+ return !(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) &&
+        !(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512);
