-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add a batch norm inference kernel. #3309
Conversation
cudnn lib有bug,在cudnn 5.1上 n > 1024时出错,可以使用下面代码验证:
|
paddle/cuda/src/hl_batch_norm.cu
Outdated
size_t height, | ||
size_t width) { | ||
dim3 block(256, 1); | ||
dim3 grid(1, batchSize); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gird(batchSize, 1)
is better,Maximum x-dimension is 2^32 - 1, Maximum y dimension is 65536.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/cuda/src/hl_batch_norm.cu
Outdated
size_t channel, | ||
size_t height, | ||
size_t width) { | ||
const int tid = blockIdx.x * blockDim.x + threadIdx.x; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blockIdx.x * blockDim.x
can be removed, blockIdx.x is always equal 0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/cuda/src/hl_batch_norm.cu
Outdated
const int num = channel * height * width; | ||
const int batch = blockIdx.y; | ||
for (int i = tid; i < num; i += blockDim.x) { | ||
const int c = (i / (height * width)) % channel; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can remove % channel
, i / (height * width) is smaller than the channel.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
movingVar, | ||
EPS); | ||
if (batchSize > 1024) { | ||
// there is a bug in cudnn library when the batch size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some places say this is a limitation of CUDNN, not bug.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modify the comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Fix #929