-
Notifications
You must be signed in to change notification settings - Fork 298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
解决已有issues共性问题,包括逐条推理不一致、score nan、半精度预测不一致等问题 #291
Comments
测试下来,不能解决报错:RuntimeError: probability tensor contains either inf, nan or element < 0 |
13B的代码和7B的不一样,13B的代码修改完后仍不能解决推理不一致问题 |
补充一下信息,在bf16/fp16/INT4/INT8下不同的batchsize结果不同,fp32下无此问题 |
我试了下,好像还是有问题。我用单条数据人为添加n个pad token以及attention mask来模拟batch生成时候的形式。 |
部分情况下,model会被转成fp32,但input还是半精度,这时候就会有type mismatch。 可以在234那边增加两行 if(attention_mask.dtype != query_states.dtype):
attention_mask = attention_mask.to(query_states.dtype) 改完看起来是这样的 with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
if(attention_mask.dtype != query_states.dtype):
attention_mask = attention_mask.to(query_states.dtype)
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask) |
问题代码在于百川2源码attention计算中,attention mask和 qkv的精度不一致导致,问题代码:https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/0cc6a61c06cd0734270151109d07cf86ef0ace53/modeling_baichuan.py#L234
attention_mask 在经过self._prepare_decoder_attention_mask 函数后会被cast成 self.embed_tokens 的dtype也就是float32,在使用
半精度操作时 (比如使用transformer Seq2SeqTrainingArguments 中的bf16,或者accelerator配置的fp16等等), float32的最小值-3.4028e+38 会变成 -inf。当[-inf, ... -inf] 全是 -inf的向量经过softmax时结果会是nan。由此会出现诸多问题。
解决方案,将attention_mask提前cast成和qkv一致的精度即可,避免进入softmax前是-inf:
看了下百川1自己实现的scaled_dot_product_attention 里写的反而是对着的...
The text was updated successfully, but these errors were encountered: