Skip to content

Commit

Permalink
Update search.py
Browse files Browse the repository at this point in the history
  • Loading branch information
BrilliantYuKaimin committed Jun 2, 2022
1 parent 2483b82 commit c61ac9e
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,13 +567,13 @@ def where(condition, x=None, y=None, name=None):
\begin{cases}
x_i, & \text{if} \ condition_i \ \text{is} \ True \\
y_i, & \text{if} \ condition_i \ \text{is} \ False \\
\end{cases}
\end{cases}.
Notes:
``paddle.where(condition)`` is identical to ``paddle.nonzero(condition, as_tuple=True)``.
``numpy.where(condition)`` is identical to ``paddle.nonzero(condition, as_tuple=True)``, please refer to :ref:`api_tensor_search_nonzero`.
Args:
condition (Tensor): The condition to choose x or y. When True(nonzero), yield x, otherwise yield y.
condition (Tensor): The condition to choose x or y. When True (nonzero), yield x, otherwise yield y.
x (Tensor|scalar, optional): A Tensor or scalar to choose when the condition is True with data type of float32, float64, int32 or int64. Either both or neither of x and y should be given.
y (Tensor|scalar, optional): A Tensor or scalar to choose when the condition is False with data type of float32, float64, int32 or int64. Either both or neither of x and y should be given.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Expand All @@ -583,22 +583,22 @@ def where(condition, x=None, y=None, name=None):
Examples:
.. code-block:: python
:name:where-example
:name:where-example
import paddle
import paddle
x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])
out = paddle.where(x>1, x, y)
x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])
out = paddle.where(x>1, x, y)
print(out)
#out: [1.0, 1.0, 3.2, 1.2]
print(out)
#out: [1.0, 1.0, 3.2, 1.2]
out = paddle.where(x>1)
print(out)
#out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
# [[2],
# [3]]),)
out = paddle.where(x>1)
print(out)
#out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
# [[2],
# [3]]),)
"""
if np.isscalar(x):
x = paddle.full([1], x, np.array([x]).dtype.name)
Expand Down

1 comment on commit c61ac9e

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.