Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import builtins
import math
import numbers
import re
import warnings
from typing import TYPE_CHECKING, overload
Expand Down Expand Up @@ -1043,7 +1044,7 @@ def get_slice(

def full_like(
x: paddle.Tensor,
fill_value: bool | float,
fill_value: Numeric | str,
dtype: DTypeLike | None = None,
*,
device: PlaceLike | None = None,
Expand All @@ -1057,9 +1058,10 @@ def full_like(

Args:
x(Tensor): The input tensor which specifies shape and data type. The data type can be bool, float16, float32, float64, int32, int64.
fill_value(bool|float|int): The value to fill the tensor with. Note: this value shouldn't exceed the range of the output data type.
fill_value(Scalar|Tensor): The value to fill the tensor with. Note: this value shouldn't exceed the range of the output data type.
If ``fill_value`` is an Tensor, it should be an 0-D Tensor which represents a scalar.
dtype(np.dtype|str, optional): The data type of output. The data type can be one
of bool, float16, float32, float64, int32, int64. The default value is None, which means the output
of bool, float16, float32, float64, int32, int64, complex64, complex128. The default value is None, which means the output
data type is the same as input.
device(PlaceLike|None, optional): The desired device of returned tensor.
if None, uses the current device for the default tensor type (see paddle.device.set_device()).
Expand All @@ -1081,6 +1083,15 @@ def full_like(
[[2. 2. 2.]
[2. 2. 2.]]
"""
# Include str type check to handle string numeric values like "0.5" that occur in CI tests.
# The compatible method for fliud operators, may be it can be removed in the future.
if not isinstance(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不在支持bool类型,是否会对套件库的兼容性产生影响?

Copy link
Contributor Author

@ooooo-create ooooo-create Jul 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python 的 bool 类型是支持的~
图片

numpy.bool 目前本身是不支持的,现在的判断是会拦截的,应该不会有影响,就是报错位置提前了
图片

fill_value,
(numbers.Number, str, core.eager.Tensor, Variable, paddle.pir.Value),
):
raise TypeError(
f"The fill_value should be int, float, bool, complex, np.number, string numeric value or Tensor, but received {type(fill_value)}."
)

if dtype is None:
dtype = x.dtype
Expand Down Expand Up @@ -1635,7 +1646,7 @@ def _check_attr(attr, message):
@ParamAliasDecorator({"shape": ["size"]})
def full(
shape: ShapeLike,
fill_value: bool | float | paddle.Tensor,
fill_value: Numeric | str,
dtype: DTypeLike | None = None,
*,
out: paddle.Tensor | None = None,
Expand All @@ -1656,10 +1667,10 @@ def full(
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
Alias: ``size``.
fill_value(bool|float|int|Tensor): The constant value used to initialize the Tensor to be created.
fill_value(Scalar|Tensor): The constant value used to initialize the Tensor to be created.
If ``fill_value`` is an Tensor, it should be an 0-D Tensor which represents a scalar.
dtype(np.dtype|str, optional): Data type of the output Tensor
which can be float16, float32, float64, int32, int64, if dtype is `None`, the data
which can be float16, float32, float64, int32, int64, complex64, complex128. If dtype is `None`, the data
type of created Tensor is `float32`.
out(Tensor, optional): The output tensor.
device(PlaceLike|None, optional): The desired device of returned tensor.
Expand Down Expand Up @@ -1707,6 +1718,15 @@ def full(
[2. 2.]
[2. 2.]]
"""
# Include str type check to handle string numeric values like "0.5" that occur in CI tests.
# The compatible method for fliud operators, may be it can be removed in the future.
if not isinstance(
fill_value,
(numbers.Number, str, core.eager.Tensor, Variable, paddle.pir.Value),
):
raise TypeError(
f"The fill_value should be int, float, bool, complex, np.number, string numeric values or Tensor, but received {type(fill_value)}."
)

if dtype is None:
if isinstance(fill_value, (bool)):
Expand Down
83 changes: 81 additions & 2 deletions test/legacy_test/test_full_like_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
from op_test import OpTest, convert_float_to_uint16
from utils import dygraph_guard, static_guard

import paddle
import paddle.framework.dtype as dtypes
Expand All @@ -41,7 +42,7 @@ def fill_any_like_wrapper(x, value, out_dtype=None, name=None):
return paddle.full_like(x, value, tmp_dtype, name=name)


class TestFullOp(unittest.TestCase):
class TestFullLikeOp(unittest.TestCase):
"""Test fill_any_like op(whose API is full_like) for attr out."""

def test_attr_tensor_API(self):
Expand Down Expand Up @@ -94,7 +95,8 @@ def test_full_like_fill_inf(self):
paddle.enable_static()


class TestFullOpError(unittest.TestCase):
class TestFullLikeOpError(unittest.TestCase):

def test_errors(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
Expand All @@ -114,6 +116,33 @@ def test_errors(self):
dtype='uint4',
)

def test_fill_value_errors(self):
with dygraph_guard():
# The fill_value must be one of [int, float, bool, complex, Tensor, np.number].
self.assertRaises(
TypeError,
paddle.full_like,
x=paddle.to_tensor([1.0, 2.0]),
fill_value=np.array([1.0], dtype=np.float32),
dtype="float32",
)

self.assertRaises(
TypeError,
paddle.full_like,
x=paddle.to_tensor([1.0, 2.0]),
fill_value=[1.0],
dtype="float32",
)

self.assertRaises(
TypeError,
paddle.full_like,
x=paddle.to_tensor([1.0, 2.0]),
fill_value=np.bool_(True),
dtype="bool",
)


class TestFullLikeOp1(OpTest):
# test basic
Expand Down Expand Up @@ -198,6 +227,16 @@ def test_skip_data_transform(self):
paddle.enable_static()


class TestFullLikeOp5(TestFullLikeOp1):
def init_data(self):
self.fill_value = True
self.shape = [10, 10]
self.dtype = np.bool

def if_enable_cinn(self):
pass


class TestFullLikeFP16Op(TestFullLikeOp1):
def init_data(self):
self.fill_value = 6666
Expand Down Expand Up @@ -268,5 +307,45 @@ def test_full_like_kernel_gpu_zero_size(self):
paddle.enable_static()


class TestFullLikeWithTensorValue(unittest.TestCase):
def test_dygraph_api(self):
with dygraph_guard():
base_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
value_np = np.array([5.0], dtype=np.float32)
base_tensor = paddle.to_tensor(base_np)
value_tensor = paddle.to_tensor(value_np)
result = paddle.full_like(base_tensor, value_tensor)
expected = np.full_like(base_np, value_np)
np.testing.assert_array_equal(result.numpy(), expected)

def test_static_api(self):
with static_guard():
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
base_tensor = paddle.static.data(
name='base_tensor', dtype='float32', shape=[2, 2]
)
value_tensor = paddle.static.data(
name='value_tensor', dtype='float32', shape=[1]
)
result = paddle.full_like(base_tensor, value_tensor)

place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

base_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
value_np = np.array([5.0], dtype=np.float32)

res = exe.run(
train_program,
feed={'base_tensor': base_np, 'value_tensor': value_np},
fetch_list=[result],
)

expected = np.full_like(base_np, value_np)
np.testing.assert_array_equal(res[0], expected)


if __name__ == "__main__":
unittest.main()
28 changes: 28 additions & 0 deletions test/legacy_test/test_full_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from utils import dygraph_guard

import paddle
from paddle import base
Expand Down Expand Up @@ -444,6 +445,33 @@ def test_shape_tensor_list_dtype():
self.assertRaises(TypeError, test_shape_tensor_list_dtype)
paddle.disable_static()

def test_fill_value_errors(self):
with dygraph_guard():
# The fill_value must be one of [int, float, bool, complex, np.number, Tensor].
self.assertRaises(
TypeError,
paddle.full,
shape=[1],
dtype="float32",
fill_value=np.array([1.0], dtype=np.float32),
)

self.assertRaises(
TypeError,
paddle.full,
shape=[1],
dtype="float32",
fill_value=[1.0],
)

self.assertRaises(
TypeError,
paddle.full,
shape=[1],
dtype="bool",
fill_value=np.bool_(True),
)


if __name__ == "__main__":
unittest.main()