
fast layer norm has non-deterministic problem #56100

Closed
zhaoyinglia opened this issue Aug 9, 2023 · 6 comments · Fixed by #56435

@zhaoyinglia
Contributor

zhaoyinglia commented Aug 9, 2023

Describe the Bug

Hi @jeng1220, we found that GPT3-1.3B's loss is non-deterministic after #55639.
To reproduce, run the following in PaddleNLP twice and compare the loss values.

export FLAGS_embedding_deterministic=1
export FLAGS_cudnn_deterministic=1

python ./tools/train.py \
    -c ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B_dp8.yaml \
    -o Model.hidden_dropout_prob=0 \
    -o Model.attention_probs_dropout_prob=0 \
    -o Model.use_recompute=False \
    -o Global.local_batch_size=1 \
    -o Global.micro_batch_size=1 \
    -o Distributed.dp_degree=1 \
    -o Distributed.mp_degree=1 \
    -o Distributed.pp_degree=1 \
    -o Distributed.sharding.sharding_degree=1 \
    -o Distributed.sharding.sharding_stage=1 \
    -o Engine.mix_precision.enable=False \
    -o Engine.max_steps=5 \
    -o Engine.eval_freq=100000 \
    -o Engine.save_load.output_dir=""

There is also a unit test that can reproduce the incorrect result, though it only fails intermittently.

import paddle
import numpy as np

paddle.seed(1234)
np.random.seed(1234)

arr = np.random.random([1, 1024, 2048])
tensor = paddle.to_tensor(arr, dtype="float32")
norm = paddle.nn.LayerNorm(2048, epsilon=1e-5)

print(np.sum(np.abs(np.array(norm(tensor)))))
print(np.sum(np.abs(np.array(norm(tensor)))))
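Since a single pair of forward passes only rarely exposes the race, a slightly extended sketch of the unit test (not from the original report; the iteration count of 100 is an arbitrary choice) repeats the identical forward pass and reports the first bitwise mismatch:

import numpy as np
import paddle

paddle.seed(1234)
np.random.seed(1234)

arr = np.random.random([1, 1024, 2048])
tensor = paddle.to_tensor(arr, dtype="float32")
norm = paddle.nn.LayerNorm(2048, epsilon=1e-5)

# Reference result from the first forward pass.
reference = np.array(norm(tensor))

# Repeat the identical forward pass; any bitwise difference indicates
# non-deterministic behaviour in the layer-norm kernel.
for i in range(100):
    out = np.array(norm(tensor))
    if not np.array_equal(out, reference):
        print(f"mismatch on iteration {i}: "
              f"max abs diff = {np.abs(out - reference).max()}")
        break
else:
    print("no mismatch observed in 100 iterations")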

Additional Supplementary Information

No response

@jeng1220
Collaborator

jeng1220 commented Aug 9, 2023

I will take a look tomorrow afternoon.

@jeng1220
Collaborator

jeng1220 commented Aug 11, 2023

The compute-sanitizer reports a data race. I need to dig further into the problem...

compute-sanitizer --tool racecheck  python foo.py
========= COMPUTE-SANITIZER
...
========= Error: Race reported between Read access at 0x1010 in /home/scratch.rjeng_sw/baidu/paddle/paddle/paddle/phi/kernels/funcs/layer_norm_impl.cu.h:275:void phi::funcs::fast_ln_fwd_kernel<float, float, float, (int)4, (int)2, (int)2, (int)16, (int)2048, (int)32, (int)64, (int)128, (int)2, (int)256, (int)8>(int, int, float, const T1 *, const T3 *, const T3 *, T2 *, T2 *, T1 *)
=========     and Write access at 0x1510 in /home/scratch.rjeng_sw/baidu/paddle/paddle/paddle/phi/kernels/funcs/layer_norm_impl.cu.h:300:void phi::funcs::fast_ln_fwd_kernel<float, float, float, (int)4, (int)2, (int)2, (int)16, (int)2048, (int)32, (int)64, (int)128, (int)2, (int)256, (int)8>(int, int, float, const T1 *, const T3 *, const T3 *, T2 *, T2 *, T1 *) [4096 hazards]
=========
========= Error: Race reported between Read access at 0xfc0 in /home/scratch.rjeng_sw/baidu/paddle/paddle/paddle/phi/kernels/funcs/layer_norm_impl.cu.h:270:void phi::funcs::fast_ln_fwd_kernel<float, float, float, (int)4, (int)2, (int)2, (int)16, (int)2048, (int)32, (int)64, (int)128, (int)2, (int)256, (int)8>(int, int, float, const T1 *, const T3 *, const T3 *, T2 *, T2 *, T1 *)
=========     and Write access at 0xff0 in /home/scratch.rjeng_sw/baidu/paddle/paddle/paddle/phi/kernels/funcs/layer_norm_impl.cu.h:272:void phi::funcs::fast_ln_fwd_kernel<float, float, float, (int)4, (int)2, (int)2, (int)16, (int)2048, (int)32, (int)64, (int)128, (int)2, (int)256, (int)8>(int, int, float, const T1 *, const T3 *, const T3 *, T2 *, T2 *, T1 *) [2048 hazards]
=========
========= Error: Race reported between Read access at 0x1530 in /home/scratch.rjeng_sw/baidu/paddle/paddle/paddle/phi/kernels/funcs/layer_norm_impl.cu.h:307:void phi::funcs::fast_ln_fwd_kernel<float, float, float, (int)4, (int)2, (int)2, (int)16, (int)2048, (int)32, (int)64, (int)128, (int)2, (int)256, (int)8>(int, int, float, const T1 *, const T3 *, const T3 *, T2 *, T2 *, T1 *)
=========     and Write access at 0x1590 in /home/scratch.rjeng_sw/baidu/paddle/paddle/paddle/phi/kernels/funcs/layer_norm_impl.cu.h:309:void phi::funcs::fast_ln_fwd_kernel<float, float, float, (int)4, (int)2, (int)2, (int)16, (int)2048, (int)32, (int)64, (int)128, (int)2, (int)256, (int)8>(int, int, float, const T1 *, const T3 *, const T3 *, T2 *, T2 *, T1 *) [2048 hazards]

@jeng1220
Collaborator

This issue doesn't happen on V100 but does on A100, which is why CI didn't catch it (CI only has V100).

@zhaoyinglia
Contributor Author

zhaoyinglia commented Aug 14, 2023

> This issue doesn't happen on V100 but does on A100, which is why CI didn't catch it (CI only has V100).

This issue also happens in my environment (V100 & CUDA 11.2). CI uses V100 & CUDA 12, but I haven't tested it in the CI environment.

@zhiqiu
Contributor

zhiqiu commented Aug 14, 2023

> The compute-sanitizer reports a data race. I need to dig further into the problem...


I suggest we disable fast_layer_norm temporarily, and re-enable it after the fix.

@jeng1220
Collaborator

@onecatcn, @zhaoyinglia ^^^
