Skip to content

Commit 25ea9e6

Browse files
committed
add more tests
1 parent b8010bb commit 25ea9e6

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

python/paddle/tensor/creation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def full_like(
10831083
(numbers.Number, str, core.eager.Tensor, Variable, paddle.pir.Value),
10841084
):
10851085
raise TypeError(
1086-
f"The fill_value should be scalar or Tensor, but received {type(fill_value)}."
1086+
f"The fill_value should be int, float, bool, complex, np.number, string numeric value or Tensor, but received {type(fill_value)}."
10871087
)
10881088

10891089
if dtype is None:
@@ -1592,7 +1592,7 @@ def full(
15921592
(numbers.Number, str, core.eager.Tensor, Variable, paddle.pir.Value),
15931593
):
15941594
raise TypeError(
1595-
f"The fill_value should be scalar or Tensor, but received {type(fill_value)}."
1595+
f"The fill_value should be int, float, bool, complex, np.number, string numeric values or Tensor, but received {type(fill_value)}."
15961596
)
15971597

15981598
if dtype is None:

test/legacy_test/test_full_like_op.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,21 +118,29 @@ def test_errors(self):
118118

119119
def test_fill_value_errors(self):
120120
with dygraph_guard():
121-
# The fill_value must be one of [int, float, bool, complex, Tensor].
121+
# The fill_value must be one of [int, float, bool, complex, Tensor, np.number].
122122
self.assertRaises(
123123
TypeError,
124-
paddle.full,
125-
shape=[1],
126-
dtype="float32",
124+
paddle.full_like,
125+
x=paddle.to_tensor([1.0, 2.0]),
127126
fill_value=np.array([1.0], dtype=np.float32),
127+
dtype="float32",
128128
)
129129

130130
self.assertRaises(
131131
TypeError,
132-
paddle.full,
133-
shape=[1],
134-
dtype="float32",
132+
paddle.full_like,
133+
x=paddle.to_tensor([1.0, 2.0]),
135134
fill_value=[1.0],
135+
dtype="float32",
136+
)
137+
138+
self.assertRaises(
139+
TypeError,
140+
paddle.full_like,
141+
x=paddle.to_tensor([1.0, 2.0]),
142+
fill_value=np.bool_(True),
143+
dtype="bool",
136144
)
137145

138146

@@ -219,6 +227,16 @@ def test_skip_data_transform(self):
219227
paddle.enable_static()
220228

221229

230+
class TestFullLikeOp5(TestFullLikeOp1):
231+
def init_data(self):
232+
self.fill_value = True
233+
self.shape = [10, 10]
234+
self.dtype = np.bool
235+
236+
def if_enable_cinn(self):
237+
pass
238+
239+
222240
class TestFullLikeFP16Op(TestFullLikeOp1):
223241
def init_data(self):
224242
self.fill_value = 6666

test/legacy_test/test_full_op.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def test_shape_tensor_list_dtype():
406406

407407
def test_fill_value_errors(self):
408408
with dygraph_guard():
409-
# The fill_value must be one of [int, float, bool, complex, Tensor].
409+
# The fill_value must be one of [int, float, bool, complex, np.number, Tensor].
410410
self.assertRaises(
411411
TypeError,
412412
paddle.full,
@@ -423,6 +423,14 @@ def test_fill_value_errors(self):
423423
fill_value=[1.0],
424424
)
425425

426+
self.assertRaises(
427+
TypeError,
428+
paddle.full,
429+
shape=[1],
430+
dtype="bool",
431+
fill_value=np.bool_(True),
432+
)
433+
426434

427435
if __name__ == "__main__":
428436
unittest.main()

0 commit comments

Comments
 (0)