-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix gather op bug #19168
Fix gather op bug #19168
Conversation
test=develop
paddle/fluid/operators/gather.cu.h
Outdated
@@ -51,7 +51,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, | |||
const Tensor& index, Tensor* output) { | |||
// PADDLE_ENFORCE(platform::is_gpu_place(place)); | |||
// check index of shape 1-D | |||
PADDLE_ENFORCE(index.dims().size() == 1 || | |||
PADDLE_ENFORCE((index.dims().size() == 1 && index.dims()[0] > 0) || | |||
(index.dims().size() == 2 && index.dims()[1] == 1)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (index.dims().size() == 1)
PADDLE_ENFORCE_GT(index.dims()[0], 0, "error_msg");
else if (index.dims().size() == 2)
PADDLE_ENFORCE_EQ(index.dims()[1], 1, "error_msg");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/fluid/operators/gather_op.cu
Outdated
@@ -41,6 +41,7 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> { | |||
paddle::framework::DataTypeToString(index_type), | |||
paddle::framework::DataTypeToString(framework::proto::VarType::INT32), | |||
paddle::framework::DataTypeToString(framework::proto::VarType::INT64)); | |||
PADDLE_ENFORCE_GT(index->numel(), 0, "The index is empty."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error msg is too simple. The index of gather_op should not be empty?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
test=develop
20a2af2
to
163dd3c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个文件中其他PADDLE_ENFORCE也需要改掉,不然CI过不了。
PADDLE_ENFORCE_GT(index.dims()[0], 0, "The index should not be empty."); | ||
} else if (index.dims().size() == 2) { | ||
PADDLE_ENFORCE_EQ(index.dims()[1], 1, "The index should be 1-D."); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 有了56-60行,52,54,55行可以去掉了。
- 57行,59行的报错信息需要增强。举例:
The index of gather_op should not be empty when index.dims().size() == 1.
The index of gather_op should be 1-D when index.dims().size() == 2.
1-D是1还是1维?
test=develop
fe06260
to
8e4ceb9
Compare
paddle/fluid/operators/gather.cu.h
Outdated
if (index.dims().size() == 1) { | ||
PADDLE_ENFORCE_GT(index.dims()[0], 0, | ||
"The index of gather_op should not be empty when " | ||
"index.dims().size() == 1."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think index.dims().size() == 1
isn't a good description, maybe when the dimension size of index equals 1.
is better
paddle/fluid/operators/gather.cu.h
Outdated
} else if (index.dims().size() == 2) { | ||
PADDLE_ENFORCE_EQ( | ||
index.dims()[1], 1, | ||
"If index.dims().size() is 2, the seconde dim should be 1."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seconde
-> second
, maybe the whole sentence can be polish to If the dimension size of index is 2, the second dimension should be equal to 2.
test=develop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Related issue #19085.
The return of
DDim::size()
is the rank info, this is to say, if thetensor
is empty, thetensor.dims().size()
is 1.Paddle/paddle/fluid/framework/ddim.h
Lines 61 to 62 in 0019eb3