Fix error message of multinomial op #27946

Merged: 13 commits merged into PaddlePaddle:develop on Oct 19, 2020
Conversation

@pangyoki (Contributor) commented on Oct 14, 2020

PR types

Bug fixes

PR changes

OPs

Describe

paddle.multinomial(x, num_samples=1, replacement=False, name=None); see PR #27219.
This PR improves the error messages in several special situations.

Bugs found in QA testing that need better error messages:

  • num_samples <= 0
    Raise an InvalidArgument error and tell users that the number of samples should be > 0.

  • dimension of the input x: dim_x <= 0 or dim_x > 2
    Raise an InvalidArgument error and tell users that the input probability distribution should be 1- or 2-dimensional.

Cases where the error behavior of the CUDA kernel was inconsistent with the CPU kernel (see the sketch after this list):

  • an element of x is < 0
    The probability distribution is computed from x, so its elements cannot be negative.
    In the CUDA kernel, enforce x >= 0 and tell users that the input of the multinomial distribution should be >= 0.

  • all elements of x are 0
    The sum of the elements of x should be > 0.
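The sketch below shows how these cases are expected to surface at the Python API level. It is a minimal, illustrative example, not code from this PR: the exact exception types and message wording depend on the Paddle build, and the negative-value and all-zero checks are enforced inside the CPU/CUDA kernels.

    import paddle

    x = paddle.rand([6])

    # num_samples <= 0: the op raises an error saying the number of
    # samples should be > 0.
    try:
        paddle.multinomial(x, num_samples=0)
    except Exception as e:  # exact exception type depends on the build
        print(e)

    # Input with more than 2 dimensions: the op raises an error saying the
    # probability distribution should be 1- or 2-dimensional.
    try:
        paddle.multinomial(paddle.rand([2, 3, 4]), num_samples=1)
    except Exception as e:
        print(e)

    # Negative probabilities and all-zero rows are rejected inside the
    # kernel, so CPU and CUDA report comparable errors.
    try:
        paddle.multinomial(paddle.to_tensor([0.0, 0.0, 0.0]), num_samples=1)
    except Exception as e:
        print(e)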

In addition, this PR fixes the multinomial and bernoulli Python API docs (adds the missing indentation), fixes the Categorical class's doc (adds indentation, adds the name attribute, and changes Variable to Tensor), and adds paddle.manual_seed to the sample code.

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@zhiqiu (Contributor) left a comment

@@ -53,12 +53,18 @@ class MultinomialOp : public framework::OperatorWithKernel {

auto x_dim = ctx->GetInputDim("X");
int64_t x_rank = x_dim.size();
PADDLE_ENFORCE_EQ(
x_rank > 0 && x_rank <= 2, true,
@zhiqiu (Contributor) commented on Oct 14, 2020:
Use PADDLE_ENFORCE_GT and PADDLE_ENFORCE_LE instead; do not combine two checks.

@pangyoki (Contributor, Author):
done

std::vector<int64_t> out_dims(x_rank);
for (int64_t i = 0; i < x_rank - 1; i++) {
out_dims[i] = x_dim[i];
}

int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
PADDLE_ENFORCE_GT(num_samples, 0, platform::errors::OutOfRange(
"Number of samples should be > 0"));
@zhiqiu (Contributor):
Suggested change
"Number of samples should be > 0"));
"The number of samples should be > 0, but got %d.", num_samples));

@pangyoki (Contributor, Author):
done

@@ -31,6 +32,14 @@ __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) {
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
PADDLE_ENFORCE(in_data[id] >= 0.0,
"The input of multinomial distribution should be >= 0");
@zhiqiu (Contributor):
Same as above: print the actual data.

@pangyoki (Contributor, Author):
done

PADDLE_ENFORCE(in_data[id] >= 0.0,
"The input of multinomial distribution should be >= 0");
PADDLE_ENFORCE(
!std::isinf(static_cast<double>(in_data[id])) &&
@zhiqiu (Contributor):
Please do not combine several logical expressions in one ENFORCE.

@pangyoki (Contributor, Author):
done

@pangyoki mentioned this pull request on Oct 14, 2020
"1 or 2 dimension, but got %d",
x_rank));
PADDLE_ENFORCE_LE(x_rank, 2, platform::errors::PreconditionNotMet(
"Input probability distribution should be "
@zhiqiu (Contributor):
Suggested change
"Input probability distribution should be "
"The number of dimensions of the input probability distribution should be <= 2, but got %d."

Similar for the others.

@pangyoki (Contributor, Author):
done

in_data[id] >= 0.0,
"The input of multinomial distribution should be >= 0, but got %f",
in_data[id]);
PADDLE_ENFORCE(in_data[id] != INFINITY,
@zhiqiu (Contributor):
Is there any special reason for adding the INF/NaN checks here? Otherwise, I do not think they are really necessary: a number not being NaN or INF should hold almost everywhere, and checking it everywhere may slow down the system.

@pangyoki (Contributor, Author):
I have removed the INF/NaN checks.

PADDLE_ENFORCE(in_data[id] != NAN,
"The input of multinomial distribution shoud not be NaN");
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0,
"The sum of input should not be 0");
@zhiqiu (Contributor):
Do you mean > 0 here?

@pangyoki (Contributor, Author):
Yes, it is > 0. Here "> 0" means the same as "not 0", because values < 0 have already been forbidden by the earlier check.
I have changed the description from "not be 0" to "> 0".

Comment on lines 52 to 59
PADDLE_ENFORCE_EQ(
std::isinf(static_cast<double>(prob_value)), false,
platform::errors::OutOfRange(
"The input of multinomial distribution should be >= 0"));
PADDLE_ENFORCE_EQ((std::isinf(static_cast<double>(prob_value)) ||
std::isnan(static_cast<double>(prob_value))),
false, platform::errors::OutOfRange(
"The input of multinomial distribution "
"shoud not be infinity or NaN"));
"The input of multinomial distribution shoud not be infinity"));
PADDLE_ENFORCE_EQ(
std::isnan(static_cast<double>(prob_value)), false,
platform::errors::OutOfRange(
"The input of multinomial distribution shoud not be NaN"));
@zhiqiu (Contributor):
Same as above.

@pangyoki (Contributor, Author):
done

# 0.09053693, 0.30820143, 0.19095989]
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
@zhiqiu (Contributor):
Better to add paddle.manual_seed(xx) here; otherwise users cannot reproduce the same random output as in your sample code.
Same for all the other examples.

@pangyoki (Contributor, Author):
done
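A minimal sketch of the suggested pattern, assuming the 2.0-era paddle.manual_seed API; the seed value 100 is arbitrary and the printed numbers are illustrative rather than the actual docstring output:

    import paddle

    # Fix the global random seed so readers can reproduce the numbers
    # printed in the docstring example.
    paddle.manual_seed(100)

    x = paddle.rand([6])
    print(x.numpy())

    out = paddle.multinomial(x, num_samples=2, replacement=True)
    print(out.numpy())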

@@ -31,6 +32,14 @@ __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) {
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
PADDLE_ENFORCE(
in_data[id] >= 0.0,
Contributor:
Please add a period at the end of the error messages consistently; in this PR, some messages have one and some do not.

@pangyoki (Contributor, Author):
done

@chenwhql (Contributor) left a comment:
LGTM for PADDLE_ENFORCE

@zhiqiu (Contributor) left a comment:
LGTM

@jzhang533 (Contributor) left a comment:
lgtm

@zhiqiu merged commit 975bd88 into PaddlePaddle:develop on Oct 19, 2020