Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon 7th] 修复 vctkernie_sat 训练时出现的类型提升问题 #3943

Merged
merged 2 commits into from
Dec 9, 2024

Conversation

megemini
Copy link
Contributor

@megemini megemini commented Dec 9, 2024

PR types

Bug fixes

PR changes

Others

Describe

修复 vctkernie_sat 训练时出现的类型提升问题:

aistudio@jupyter-942478-8626068:~/PaddleSpeech/examples/vctk/ernie_sat$ CUDA_VISIBLE_DEVICES=0,1 ./local/train.sh conf/default.yaml ./output
...

Traceback (most recent call last):
  File "/home/aistudio/.local/lib/python3.8/site-packages/paddle/distributed/spawn.py", line 385, in _func_wrapper
    result = func(*args)
  File "/home/aistudio/PaddleSpeech/paddlespeech/t2s/exps/ernie_sat/train.py", line 167, in train_sp
    trainer.run()
  File "/home/aistudio/PaddleSpeech/paddlespeech/t2s/training/trainer.py", line 203, in run
    six.reraise(*exc_info)
  File "/usr/lib/python3/dist-packages/six.py", line 703, in reraise
    raise value
  File "/home/aistudio/PaddleSpeech/paddlespeech/t2s/training/trainer.py", line 149, in run
    update()
  File "/home/aistudio/PaddleSpeech/paddlespeech/t2s/training/updaters/standard_updater.py", line 110, in update
    self.update_core(batch)
  File "/home/aistudio/PaddleSpeech/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py", line 70, in update_core
    mlm_loss, text_mlm_loss = self.criterion(
  File "/home/aistudio/.local/lib/python3.8/site-packages/paddle/nn/layer/layers.py", line 1531, in __call__
    return self.forward(*inputs, **kwargs)
  File "/home/aistudio/PaddleSpeech/paddlespeech/t2s/modules/losses.py", line 1121, in forward
    mlm_loss = paddle.sum((loss * paddle.reshape(
TypeError: (InvalidType) Type promotion only support calculations between floating-point numbers and between complex and real numbers. But got different data type x: float32, y: bool. (at ../paddle/phi/common/type_promotion.h:234)

@zxcd @Liyulingyue @GreatV @enkilee @yinfan98

Copy link

paddle-bot bot commented Dec 9, 2024

Thanks for your contribution!

@mergify mergify bot added the T2S label Dec 9, 2024
@@ -1115,7 +1115,8 @@ def forward(
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
mlm_loss = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
mlm_loss_pos.astype(loss.dtype),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convert mlm_loss_pos dtype first to reduce duplicate codes.

@megemini megemini requested a review from zxcd December 9, 2024 12:16
Copy link
Collaborator

@zxcd zxcd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zxcd zxcd merged commit e4038b4 into PaddlePaddle:develop Dec 9, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants