Skip to content

Commit

Permalink
delete probs' default value/ fix sample/rsample's param
Browse files Browse the repository at this point in the history
  • Loading branch information
dasenCoding committed Apr 21, 2023
1 parent 7439580 commit c3cdd69
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/paddle/distribution/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class Geometric(distribution.Distribution):
# [1.41421354])
"""

def __init__(self, probs=None):
def __init__(self, probs):
if isinstance(probs, (numbers.Real, paddle.Tensor, framework.Variable)):
if isinstance(probs, numbers.Real):
probs = paddle.full(
Expand Down Expand Up @@ -187,7 +187,7 @@ def sample(self, shape=()):
"""Sample from Geometric distribution with sample shape.
Args:
shape (Sequence[int], optional): Sample shape.
shape (tuple(int)): Sample shape.
Returns:
Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Expand All @@ -214,7 +214,7 @@ def rsample(self, shape=()):
"""Generate samples of the specified shape.
Args:
shape(tuple): The shape of generated samples.
shape(tuple(int)): The shape of generated samples.
Returns:
Tensor: A sample tensor that fits the Geometric distribution.
Expand Down

0 comments on commit c3cdd69

Please sign in to comment.