From 975bd8873bd07bc15a5003410c95e3b3c0c846d0 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Mon, 19 Oct 2020 15:20:42 +0800 Subject: [PATCH] Fix error message of multinomial op (#27946) * fix multinomial doc * fix multinomial error message * little doc change * fix Categorical class doc * optimize format of error message * fix CPU Kernel error message format * fix isinf and isnan error in WindowsOPENBLAS CI * delete inf and nan * add manual_seed in sample code * little error message change * change error message to InvalidArgument * add full point for error message and add manual_seed in CPU environment --- paddle/fluid/operators/multinomial_op.cc | 15 ++ paddle/fluid/operators/multinomial_op.cu | 9 + paddle/fluid/operators/multinomial_op.h | 29 +-- python/paddle/distribution.py | 197 ++++++++++-------- .../tests/unittests/test_multinomial_op.py | 36 ++++ python/paddle/tensor/random.py | 65 +++--- 6 files changed, 213 insertions(+), 138 deletions(-) diff --git a/paddle/fluid/operators/multinomial_op.cc b/paddle/fluid/operators/multinomial_op.cc index 94c9fc2d9742b..165d402342162 100644 --- a/paddle/fluid/operators/multinomial_op.cc +++ b/paddle/fluid/operators/multinomial_op.cc @@ -53,12 +53,27 @@ class MultinomialOp : public framework::OperatorWithKernel { auto x_dim = ctx->GetInputDim("X"); int64_t x_rank = x_dim.size(); + PADDLE_ENFORCE_GT(x_rank, 0, + platform::errors::InvalidArgument( + "The number of dimensions of the input probability " + "distribution should be > 0, but got %d.", + x_rank)); + PADDLE_ENFORCE_LE(x_rank, 2, + platform::errors::InvalidArgument( + "The number of dimensions of the input probability " + "distribution should be <= 2, but got %d.", + x_rank)); + std::vector out_dims(x_rank); for (int64_t i = 0; i < x_rank - 1; i++) { out_dims[i] = x_dim[i]; } int64_t num_samples = ctx->Attrs().Get("num_samples"); + PADDLE_ENFORCE_GT( + num_samples, 0, + platform::errors::InvalidArgument( + "The number of samples should be > 0, but got %d.", num_samples)); out_dims[x_rank - 1] = num_samples; ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); diff --git a/paddle/fluid/operators/multinomial_op.cu b/paddle/fluid/operators/multinomial_op.cu index 2762f0ce9bd46..92f7c992ed976 100644 --- a/paddle/fluid/operators/multinomial_op.cu +++ b/paddle/fluid/operators/multinomial_op.cu @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/multinomial_op.h" +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/transform.h" namespace paddle { @@ -31,6 +32,14 @@ __global__ void NormalizeProbability(T* norm_probs, const T* in_data, T* sum_rows) { int id = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + PADDLE_ENFORCE( + in_data[id] >= 0.0, + "The input of multinomial distribution should be >= 0, but got %f.", + in_data[id]); + PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0, + "The sum of one multinomial distribution probability should " + "be > 0, but got %f.", + sum_rows[blockIdx.y]); norm_probs[id] = in_data[id] / sum_rows[blockIdx.y]; } diff --git a/paddle/fluid/operators/multinomial_op.h b/paddle/fluid/operators/multinomial_op.h index 420d2cd11e37d..14cfbd268389e 100644 --- a/paddle/fluid/operators/multinomial_op.h +++ b/paddle/fluid/operators/multinomial_op.h @@ -44,28 +44,29 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data, int64_t num_zeros = 0; for (int64_t j = 0; j < num_categories; j++) { prob_value = in_data[i * num_categories + j]; - PADDLE_ENFORCE_GE( - prob_value, 0.0, - platform::errors::OutOfRange( - "The input of multinomial distribution should be >= 0")); - PADDLE_ENFORCE_EQ((std::isinf(static_cast(prob_value)) || - std::isnan(static_cast(prob_value))), - false, platform::errors::OutOfRange( - "The input of multinomial distribution " - "shoud not be infinity or NaN")); + PADDLE_ENFORCE_GE(prob_value, 0.0, + platform::errors::InvalidArgument( + "The input of multinomial distribution " + "should be >= 0, but got %f.", + prob_value)); + probs_sum += prob_value; if (prob_value == 0) { num_zeros += 1; } cumulative_probs[j] = probs_sum; } - PADDLE_ENFORCE_GT(probs_sum, 0.0, platform::errors::OutOfRange( - "The sum of input should not be 0")); + PADDLE_ENFORCE_GT(probs_sum, 0.0, + platform::errors::InvalidArgument( + "The sum of one multinomial distribution " + "probability should be > 0, but got %f.", + probs_sum)); PADDLE_ENFORCE_EQ( (replacement || (num_categories - num_zeros >= num_samples)), true, - platform::errors::OutOfRange("When replacement is False, number of " - "samples should be less than non-zero " - "categories")); + platform::errors::InvalidArgument( + "When replacement is False, number of " + "samples should be less than non-zero " + "categories.")); for (int64_t j = 0; j < num_categories; j++) { cumulative_probs[j] /= probs_sum; diff --git a/python/paddle/distribution.py b/python/paddle/distribution.py index 63a94a11f07bf..9133751a5309f 100644 --- a/python/paddle/distribution.py +++ b/python/paddle/distribution.py @@ -662,48 +662,54 @@ class Categorical(Distribution): Args: logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64. + name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Examples: .. code-block:: python - import paddle - from paddle.distribution import Categorical + import paddle + from paddle.distribution import Categorical - x = paddle.rand([6]) - print(x.numpy()) - # [0.32564053, 0.99334985, 0.99034804, - # 0.09053693, 0.30820143, 0.19095989] - y = paddle.rand([6]) - print(y.numpy()) - # [0.6365463 , 0.7278677 , 0.90260243, - # 0.5226815 , 0.35837543, 0.13981032] + paddle.manual_seed(100) # on CPU device + x = paddle.rand([6]) + print(x.numpy()) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] - cat = Categorical(x) - cat2 = Categorical(y) + paddle.manual_seed(200) # on CPU device + y = paddle.rand([6]) + print(y.numpy()) + # [0.77663314 0.90824795 0.15685187 + # 0.04279523 0.34468332 0.7955718 ] - cat.sample([2,3]) - # [[5, 1, 1], - # [0, 1, 2]] + cat = Categorical(x) + cat2 = Categorical(y) - cat.entropy() - # [1.71887] + paddle.manual_seed(1000) # on CPU device + cat.sample([2,3]) + # [[0, 0, 5], + # [3, 4, 5]] - cat.kl_divergence(cat2) - # [0.0278455] + cat.entropy() + # [1.77528] - value = paddle.to_tensor([2,1,3]) - cat.probs(value) - # [0.341613 0.342648 0.03123] + cat.kl_divergence(cat2) + # [0.071952] - cat.log_prob(value) - # [-1.07408 -1.07105 -3.46638] + value = paddle.to_tensor([2,1,3]) + cat.probs(value) + # [0.00608027 0.108298 0.269656] + + cat.log_prob(value) + # [-5.10271 -2.22287 -1.31061] """ def __init__(self, logits, name=None): """ Args: - logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32 or float64. + logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64. + name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. """ if not in_dygraph_mode(): check_type(logits, 'logits', (np.ndarray, tensor.Variable, list), @@ -727,27 +733,29 @@ def sample(self, shape): """Generate samples of the specified shape. Args: - shape (list): Shape of the generated samples. + shape (list): Shape of the generated samples. Returns: - Tensor: A tensor with prepended dimensions shape. + Tensor: A tensor with prepended dimensions shape. Examples: - .. code-block:: python + .. code-block:: python - import paddle - from paddle.distribution import Categorical + import paddle + from paddle.distribution import Categorical - x = paddle.rand([6]) - print(x.numpy()) - # [0.32564053, 0.99334985, 0.99034804, - # 0.09053693, 0.30820143, 0.19095989] + paddle.manual_seed(100) # on CPU device + x = paddle.rand([6]) + print(x.numpy()) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] - cat = Categorical(x) + cat = Categorical(x) - cat.sample([2,3]) - # [[5, 1, 1], - # [0, 1, 2]] + paddle.manual_seed(1000) # on CPU device + cat.sample([2,3]) + # [[0, 0, 5], + # [3, 4, 5]] """ name = self.name + '_sample' @@ -775,28 +783,31 @@ def kl_divergence(self, other): other (Categorical): instance of Categorical. The data type is float32. Returns: - Variable: kl-divergence between two Categorical distributions. + Tensor: kl-divergence between two Categorical distributions. Examples: - .. code-block:: python + .. code-block:: python - import paddle - from paddle.distribution import Categorical + import paddle + from paddle.distribution import Categorical + + paddle.manual_seed(100) # on CPU device + x = paddle.rand([6]) + print(x.numpy()) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] - x = paddle.rand([6]) - print(x.numpy()) - # [0.32564053, 0.99334985, 0.99034804, - # 0.09053693, 0.30820143, 0.19095989] - y = paddle.rand([6]) - print(y.numpy()) - # [0.6365463 , 0.7278677 , 0.90260243, - # 0.5226815 , 0.35837543, 0.13981032] + paddle.manual_seed(200) # on CPU device + y = paddle.rand([6]) + print(y.numpy()) + # [0.77663314 0.90824795 0.15685187 + # 0.04279523 0.34468332 0.7955718 ] - cat = Categorical(x) - cat2 = Categorical(y) + cat = Categorical(x) + cat2 = Categorical(y) - cat.kl_divergence(cat2) - # [0.0278455] + cat.kl_divergence(cat2) + # [0.071952] """ name = self.name + '_kl_divergence' @@ -823,23 +834,24 @@ def entropy(self): """Shannon entropy in nats. Returns: - Variable: Shannon entropy of Categorical distribution. The data type is float32. + Tensor: Shannon entropy of Categorical distribution. The data type is float32. Examples: - .. code-block:: python + .. code-block:: python - import paddle - from paddle.distribution import Categorical + import paddle + from paddle.distribution import Categorical - x = paddle.rand([6]) - print(x.numpy()) - # [0.32564053, 0.99334985, 0.99034804, - # 0.09053693, 0.30820143, 0.19095989] + paddle.manual_seed(100) # on CPU device + x = paddle.rand([6]) + print(x.numpy()) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] - cat = Categorical(x) + cat = Categorical(x) - cat.entropy() - # [1.71887] + cat.entropy() + # [1.77528] """ name = self.name + '_entropy' @@ -864,27 +876,28 @@ def probs(self, value): with ``logits. That is, ``value[:-1] = logits[:-1]``. Args: - value (Tensor): The input tensor represents the selected category index. + value (Tensor): The input tensor represents the selected category index. Returns: - Tensor: probability according to the category index. + Tensor: probability according to the category index. Examples: - .. code-block:: python + .. code-block:: python - import paddle - from paddle.distribution import Categorical + import paddle + from paddle.distribution import Categorical - x = paddle.rand([6]) - print(x.numpy()) - # [0.32564053, 0.99334985, 0.99034804, - # 0.09053693, 0.30820143, 0.19095989] + paddle.manual_seed(100) # on CPU device + x = paddle.rand([6]) + print(x.numpy()) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] - cat = Categorical(x) + cat = Categorical(x) - value = paddle.to_tensor([2,1,3]) - cat.probs(value) - # [0.341613 0.342648 0.03123] + value = paddle.to_tensor([2,1,3]) + cat.probs(value) + # [0.00608027 0.108298 0.269656] """ name = self.name + '_probs' @@ -929,28 +942,28 @@ def log_prob(self, value): """Log probabilities of the given category. Refer to ``probs`` method. Args: - value (Tensor): The input tensor represents the selected category index. + value (Tensor): The input tensor represents the selected category index. Returns: - Tensor: Log probability. + Tensor: Log probability. Examples: - .. code-block:: python - - import paddle - from paddle.distribution import Categorical + .. code-block:: python - x = paddle.rand([6]) - print(x.numpy()) - # [0.32564053, 0.99334985, 0.99034804, - # 0.09053693, 0.30820143, 0.19095989] + import paddle + from paddle.distribution import Categorical - cat = Categorical(x) + paddle.manual_seed(100) # on CPU device + x = paddle.rand([6]) + print(x.numpy()) + # [0.5535528 0.20714243 0.01162981 + # 0.51577556 0.36369765 0.2609165 ] - value = paddle.to_tensor([2,1,3]) + cat = Categorical(x) - cat.log_prob(value) - # [-1.07408 -1.07105 -3.46638] + value = paddle.to_tensor([2,1,3]) + cat.log_prob(value) + # [-5.10271 -2.22287 -1.31061] """ name = self.name + '_log_prob' diff --git a/python/paddle/fluid/tests/unittests/test_multinomial_op.py b/python/paddle/fluid/tests/unittests/test_multinomial_op.py index 7cca7738efd05..db4978930e049 100644 --- a/python/paddle/fluid/tests/unittests/test_multinomial_op.py +++ b/python/paddle/fluid/tests/unittests/test_multinomial_op.py @@ -17,12 +17,14 @@ import unittest import paddle import paddle.fluid as fluid +from paddle.fluid import core from op_test import OpTest import numpy as np class TestMultinomialOp(OpTest): def setUp(self): + paddle.enable_static() self.op_type = "multinomial" self.init_data() self.inputs = {"X": self.input_np} @@ -175,5 +177,39 @@ def test_alias(self): paddle.tensor.random.multinomial(x, num_samples=10, replacement=True) +class TestMultinomialError(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_num_sample(self): + def test_num_sample_less_than_0(): + x = paddle.rand([4]) + paddle.multinomial(x, num_samples=-2) + + self.assertRaises(core.EnforceNotMet, test_num_sample_less_than_0) + + def test_replacement_False(self): + def test_samples_larger_than_categories(): + x = paddle.rand([4]) + paddle.multinomial(x, num_samples=5, replacement=False) + + self.assertRaises(core.EnforceNotMet, + test_samples_larger_than_categories) + + def test_input_probs_dim(self): + def test_dim_larger_than_2(): + x = paddle.rand([2, 3, 3]) + paddle.multinomial(x) + + self.assertRaises(core.EnforceNotMet, test_dim_larger_than_2) + + def test_dim_less_than_1(): + x_np = np.random.random([]) + x = paddle.to_tensor(x_np) + paddle.multinomial(x) + + self.assertRaises(core.EnforceNotMet, test_dim_less_than_1) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index de7bb6f164ec5..eb9750bcc3957 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -57,19 +57,19 @@ def bernoulli(x, name=None): Examples: .. code-block:: python - import paddle - - paddle.disable_static() + import paddle - x = paddle.rand([2, 3]) - print(x.numpy()) - # [[0.11272584 0.3890902 0.7730957 ] - # [0.10351662 0.8510418 0.63806665]] + paddle.manual_seed(100) # on CPU device + x = paddle.rand([2,3]) + print(x.numpy()) + # [[0.5535528 0.20714243 0.01162981] + # [0.51577556 0.36369765 0.2609165 ]] - out = paddle.bernoulli(x) - print(out.numpy()) - # [[0. 0. 1.] - # [0. 0. 1.]] + paddle.manual_seed(200) # on CPU device + out = paddle.bernoulli(x) + print(out.numpy()) + # [[0. 0. 0.] + # [1. 1. 0.]] """ @@ -108,28 +108,29 @@ def multinomial(x, num_samples=1, replacement=False, name=None): Examples: .. code-block:: python - import paddle - - paddle.disable_static() - - x = paddle.rand([2,4]) - print(x.numpy()) - # [[0.7713825 0.4055941 0.433339 0.70706886] - # [0.9223313 0.8519825 0.04574518 0.16560672]] - - out1 = paddle.multinomial(x, num_samples=5, replacement=True) - print(out1.numpy()) - # [[3 3 1 1 0] - # [0 0 0 0 1]] - - # out2 = paddle.multinomial(x, num_samples=5) - # OutOfRangeError: When replacement is False, number of samples - # should be less than non-zero categories + import paddle - out3 = paddle.multinomial(x, num_samples=3) - print(out3.numpy()) - # [[0 2 3] - # [0 1 3]] + paddle.manual_seed(100) # on CPU device + x = paddle.rand([2,4]) + print(x.numpy()) + # [[0.5535528 0.20714243 0.01162981 0.51577556] + # [0.36369765 0.2609165 0.18905126 0.5621971 ]] + + paddle.manual_seed(200) # on CPU device + out1 = paddle.multinomial(x, num_samples=5, replacement=True) + print(out1.numpy()) + # [[3 3 0 0 0] + # [3 3 3 1 0]] + + # out2 = paddle.multinomial(x, num_samples=5) + # InvalidArgumentError: When replacement is False, number of samples + # should be less than non-zero categories + + paddle.manual_seed(300) # on CPU device + out3 = paddle.multinomial(x, num_samples=3) + print(out3.numpy()) + # [[3 0 1] + # [3 1 0]] """