Skip to content

Commit

Permalink
fix coverage prob
Browse files Browse the repository at this point in the history
  • Loading branch information
pangyoki committed Sep 28, 2020
1 parent 7464df5 commit 3cc41b2
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
36 changes: 36 additions & 0 deletions python/paddle/fluid/tests/unittests/test_multinomial_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import unittest
import paddle
import paddle.fluid as fluid
from op_test import OpTest
import numpy as np

Expand Down Expand Up @@ -159,6 +160,32 @@ def test_dygraph3(self):
"replacement is False. categories can't be sampled repeatedly")
paddle.enable_static()

def test_static(self):
paddle.enable_static()
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
x = fluid.data('x', shape=[4], dtype='float32')
out = paddle.multinomial(x, num_samples=100000, replacement=True)

place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)

exe.run(startup_program)
x_np = np.random.rand(4).astype('float32')
out = exe.run(train_program, feed={'x': x_np}, fetch_list=[out])

sample_prob = np.unique(out, return_counts=True)[1].astype("float32")
sample_prob /= sample_prob.sum()

prob = x_np / x_np.sum(axis=-1, keepdims=True)
self.assertTrue(
np.allclose(
sample_prob, prob, rtol=0, atol=0.01),
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))

"""
def test_replacement_error(self):
def test_error():
Expand All @@ -170,5 +197,14 @@ def test_error():
"""


class TestMultinomialAlias(unittest.TestCase):
def test_alias(self):
paddle.disable_static()
x = paddle.rand([4])
paddle.multinomial(x, num_samples=10, replacement=True)
paddle.tensor.multinomial(x, num_samples=10, replacement=True)
paddle.tensor.random.multinomial(x, num_samples=10, replacement=True)


if __name__ == "__main__":
unittest.main()
20 changes: 10 additions & 10 deletions python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define random functions
# TODO: define random functions

from ..fluid import core
from ..fluid.framework import in_dygraph_mode, Variable, convert_np_dtype_to_dtype_
Expand Down Expand Up @@ -40,18 +40,18 @@ def bernoulli(x, name=None):
This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution.
The input ``x`` is a tensor with probabilities for generating the random binary number.
Each element in ``x`` should be in [0, 1], and the out is generated by:
.. math::
out_i ~ Bernoulli (x_i)
Args:
x(Tensor): A tensor with probabilities for generating the random binary number. The data type
x(Tensor): A tensor with probabilities for generating the random binary number. The data type
should be float32, float64.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Returns:
Tensor: A Tensor filled with random binary number with the same shape and dtype as ``x``.
Examples:
Expand Down Expand Up @@ -80,7 +80,7 @@ def bernoulli(x, name=None):

helper = LayerHelper("randint", **locals())
out = helper.create_variable_for_type_inference(
dtype=x.dtype) # maybe set out to int32 ?
dtype=x.dtype) # maybe set out to int32 ?
helper.append_op(
type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={})
return out
Expand Down Expand Up @@ -122,7 +122,7 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
# [[3 3 1 1 0]
# [0 0 0 0 1]]
out2 = paddle.multinomial(x, num_samples=5)
# out2 = paddle.multinomial(x, num_samples=5)
# OutOfRangeError: When replacement is False, number of samples
# should be less than non-zero categories
Expand Down Expand Up @@ -176,7 +176,7 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
Returns:
Tensor: A Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``.
distribution, with ``shape`` and ``dtype``.
"""
op_type_for_check = 'gaussian/standard_normal/randn/normal'
seed = 0
Expand Down Expand Up @@ -417,7 +417,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
Expand Down Expand Up @@ -505,7 +505,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Returns:
Tensor: A Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``.
Expand Down Expand Up @@ -615,7 +615,7 @@ def randperm(n, dtype="int64", name=None):
out2 = paddle.randperm(7, 'int32')
# [1, 6, 2, 0, 4, 3, 5] # random
"""
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
Expand Down

0 comments on commit 3cc41b2

Please sign in to comment.