
fix gpt2 train loss Nan problem by add a line __syncthreads in BlockR… #33658

Merged 1 commit into PaddlePaddle:develop on Jun 22, 2021

Conversation

zhiboniu
Contributor

@zhiboniu zhiboniu commented Jun 18, 2021

PR types

Bug fixes

PR changes

OPs

Describe

Background:
During GPT-2 training, the loss became unstable, failed to converge, and eventually turned into NaN.

Investigation findings:
1) Training was normal on P40 but abnormal on V100.
2) Adding a single log print made training normal; without the log print, training was abnormal.
3) Training was normal with the original linear-add reduction, but abnormal with BlockReduceSum.

The failure was finally fixed by removing the static shared memory array `shared` and adding one `__syncthreads()` call. `__syncthreads()` was also added after two other BlockReduceSum call sites to improve reliability.

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@ZHUI
Collaborator

ZHUI commented Jun 18, 2021

GPT train loss has been going NaN since PR #33420.

d_scale_partial = BlockReduceSum<U>(d_scale_partial);
d_bias_partial = BlockReduceSum<U>(d_bias_partial);
__shared__ U shared_scale[32];
__shared__ U shared_bias[32];
Contributor


Would replacing 32 with kMaxBlockDim/lwarpSize save more shared memory?

Contributor Author


Discussed offline: the overall shared memory space is large enough, so 32 can stay.

ForFishes
ForFishes previously approved these changes Jun 18, 2021
Member

@ForFishes ForFishes left a comment


LGTM

jeff41404
jeff41404 previously approved these changes Jun 18, 2021
ZHUI
ZHUI previously approved these changes Jun 21, 2021
Collaborator

@ZHUI ZHUI left a comment


LGTM

@zhiboniu zhiboniu dismissed stale reviews from ZHUI, jeff41404, and ForFishes via 8e3bdd3 June 21, 2021 07:49
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment


LGTM

@XiaoguangHu01 XiaoguangHu01 merged commit 687571f into PaddlePaddle:develop Jun 22, 2021
@zhiboniu zhiboniu deleted the develop_gpt branch September 27, 2021 10:02

5 participants