-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove paddle.where in cross_tropy_loss #38456
Remove paddle.where in cross_tropy_loss #38456
Conversation
✅ This PR's description meets the template requirements! |
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以描述下删掉的这些check的功能和删掉之后有啥影响吗?为什么之前需要加上这个checking,现在直接删掉的时候,出现这些case是有其它代码会检查吗?
去掉这个检查后,如果labels范围超出边界,则会在gather_nd这里报错 def test_LabelValue_ExceedMax():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64") # hard label
label_data[0] = 100
weight_data = paddle.rand([100]) # provide weight
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
Error: /paddle/paddle/fluid/operators/gather.cu.h:62 Assertion `index_value >= 0 && index_value < input_dims[j]` failed. The index is out of bounds, please check whether the dimensions of index and input meet the requirements. It should be less than [100] and greater than or equal to 0, but received [0] def test_LabelValue_ExceedMin():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64") # hard label
label_data[0] = -1
weight_data = paddle.rand([100]) # provide weight
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
Error: /paddle/paddle/fluid/operators/gather.cu.h:62 Assertion `index_value >= 0 && index_value < input_dims[j]` failed. The index is out of bounds, please check whether the dimensions of index and input meet the requirements. It should be less than [100] and greater than or equal to 0, but received [0] |
这样是不是就不容易判断出是cross_entropy的报错了啊,有python栈可以看出是cross_entropy的问题吗? |
Sorry to inform you that 056394c's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
raise ValueError( | ||
"Target({}) is out of class_dimension's upper bound({})". | ||
format(invalid_label[0], input.shape[axis] - 1)) | ||
valid_label = paddle.cast( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前讨论过在weight is not None
的条件下,才需要判断label的合法性。需要把检查移动到weight is not None
中吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
由于softmax_with_cross_entropy
中没有对labels进行充分的合法性检查,所以这里在hard label条件下就判断label的合法性,无论是否用了weight
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
LGTM |
PR types
Function optimization
PR changes
APIs
Describe