Skip to content

Commit

Permalink
add type check
Browse files Browse the repository at this point in the history
  • Loading branch information
yeliang2258 committed Sep 14, 2021
1 parent b32e913 commit e65a65e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/paddle/fluid/tests/unittests/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ def test_dynamic_api_float(self):
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()

def test_assert(self):
def test_dynamic_api_string(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, "1.0")
paddle.enable_static()

self.assertRaises(TypeError, test_dynamic_api_string)

def test_dynamic_api_bool(self):
if self.op_type == "equal":
paddle.disable_static()
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ def equal(x, y, name=None):
result1 = paddle.equal(x, y)
print(result1) # result1 = [True False False]
"""
if not isinstance(y, (int, bool, float, Variable)):
raise TypeError(
"Type of input args must be float, bool, int or Tensor, but received type {}".
format(type(y)))
if not isinstance(y, Variable):
y = full(shape=[1], dtype=x.dtype, fill_value=y)

Expand Down

1 comment on commit e65a65e

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.