Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gather op bug #19168

Merged
merged 4 commits into from
Aug 14, 2019
Merged

Conversation

chengduoZH
Copy link
Contributor

@chengduoZH chengduoZH commented Aug 13, 2019

Related issue #19085.
The return of DDim::size() is the rank info, this is to say, if the tensor is empty, the tensor.dims().size() is 1.

DDim() : rank_(1) { dim_[0] = 0; }

test=develop
@@ -51,7 +51,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1 ||
PADDLE_ENFORCE((index.dims().size() == 1 && index.dims()[0] > 0) ||
(index.dims().size() == 2 && index.dims()[1] == 1));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (index.dims().size() == 1)
  PADDLE_ENFORCE_GT(index.dims()[0], 0, "error_msg");
else if (index.dims().size() == 2)
  PADDLE_ENFORCE_EQ(index.dims()[1], 1, "error_msg");

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -41,6 +41,7 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_GT(index->numel(), 0, "The index is empty.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error msg is too simple. The index of gather_op should not be empty?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

test=develop
Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个文件中其他PADDLE_ENFORCE也需要改掉,不然CI过不了。

PADDLE_ENFORCE_GT(index.dims()[0], 0, "The index should not be empty.");
} else if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1, "The index should be 1-D.");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 有了56-60行,52,54,55行可以去掉了。
  2. 57行,59行的报错信息需要增强。举例:
The index of gather_op should not be empty when index.dims().size() == 1.
The index of gather_op should be 1-D when index.dims().size() == 2.

1-D是1还是1维?

test=develop
@chengduoZH chengduoZH requested a review from kuke August 14, 2019 07:38
if (index.dims().size() == 1) {
PADDLE_ENFORCE_GT(index.dims()[0], 0,
"The index of gather_op should not be empty when "
"index.dims().size() == 1.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think index.dims().size() == 1 isn't a good description, maybe when the dimension size of index equals 1. is better

} else if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(
index.dims()[1], 1,
"If index.dims().size() is 2, the seconde dim should be 1.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seconde -> second, maybe the whole sentence can be polish to If the dimension size of index is 2, the second dimension should be equal to 2.

test=develop
Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chengduoZH chengduoZH merged commit b5ba801 into PaddlePaddle:develop Aug 14, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants