28 changes: 28 additions & 0 deletions python/paddle/__init__.py
@@ -641,6 +641,34 @@
to_dlpack,
)


class _TensorMethodOrModule:
def __init__(self):
import paddle.tensor as tensor_module

from .tensor.creation import tensor as tensor_api

self.module = tensor_module
self.method = tensor_api

def __call__(self, *args, **kwargs):
return self.method(*args, **kwargs)

def __getattr__(self, name):
return getattr(self.module, name)

def __repr__(self):
return repr(self.method)

def __str__(self):
return str(self.method)

def __dir__(self):
return dir(self.module)


tensor = _TensorMethodOrModule() # noqa: F811

# CINN has to set a flag to include a lib
Contributor

Some of the other methods of the `_TensorMethodOrModule` class should also be implemented, in line with Python's built-in behavior, to make the experience better: for example, `type`, `print`, and `repr` should make it look like a function, while `dir()` should make it look like a module.

Contributor Author

Done
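A minimal, self-contained sketch of the callable-module pattern requested above, using hypothetical stand-in names (`_tensor_api`, `_CallableModule`) rather than real Paddle internals: calling the object dispatches to the function, while attribute access and `dir()` fall through to the wrapped module.

import types

def _tensor_api(data):
    """Stand-in for paddle.tensor.creation.tensor."""
    return f"Tensor({data!r})"

# A throwaway module standing in for the paddle.tensor submodule.
_tensor_module = types.ModuleType("tensor")
_tensor_module.zeros = lambda shape: f"zeros({shape})"

class _CallableModule:
    def __init__(self, module, method):
        self.module = module
        self.method = method

    def __call__(self, *args, **kwargs):
        # Calling the object dispatches to the function.
        return self.method(*args, **kwargs)

    def __getattr__(self, name):
        # Attributes not found on the wrapper fall through to the module.
        return getattr(self.module, name)

    def __repr__(self):
        # repr() and print() make it look like the function ...
        return repr(self.method)

    def __dir__(self):
        # ... while dir() makes it look like the module.
        return dir(self.module)

tensor = _CallableModule(_tensor_module, _tensor_api)
print(tensor([1, 2]))          # Tensor([1, 2])
print(tensor.zeros((2, 2)))    # zeros((2, 2))
print(repr(tensor))            # <function _tensor_api at 0x...>
print("zeros" in dir(tensor))  # True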

Contributor

new_tensor_api -> tensor_api

Contributor Author

Done

if is_compiled_with_cinn():
import os
157 changes: 127 additions & 30 deletions python/paddle/tensor/creation.py
@@ -63,7 +63,7 @@

__all__ = []

_warned_in_to_tensor = False
_warned_in_tensor = False


def _complex_to_real_dtype(dtype: DTypeLike) -> DTypeLike:
@@ -877,7 +877,129 @@ def _to_tensor_static(
return output


@ParamAliasDecorator({"place": ["device"]})
def tensor(
data: TensorLike | NestedNumericSequence,
dtype: DTypeLike | None = None,
device: PlaceLike | None = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> paddle.Tensor:
r"""
Constructs a ``paddle.Tensor`` from ``data`` ,
which can be scalar, tuple, list, numpy\.ndarray, paddle\.Tensor.

If ``data`` is already a Tensor, a copy will be performed and a new tensor will be returned.
If you only want to change its stop_gradient property, set ``Tensor.stop_gradient`` directly.

.. code-block:: text

We use the dtype conversion rules following this:
Keep dtype
np.number ───────────► paddle.Tensor
(0-D Tensor)
default_dtype
Python Number ───────────────► paddle.Tensor
(0-D Tensor)
Keep dtype
np.ndarray ───────────► paddle.Tensor

Args:
data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor.
dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
'complex64' , 'complex128'. Default: None, infers dtype from ``data``
except for Python float numbers, which get their dtype from ``paddle.get_default_dtype`` .
device(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate the Tensor. Can be
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, which means the global place. If ``device`` is a
string, it can be ``cpu``, ``gpu:x``, ``cuda:x`` or ``gpu_pinned``, where ``x`` is the index of the GPU.
requires_grad(bool, optional): Whether autograd should record operations on the returned tensor, i.e. the inverse of ``stop_gradient``. Default: False.
pin_memory(bool, optional): If set, the returned tensor is allocated in pinned memory; only supported for GPU and XPU places. Default: False.

Returns:
Tensor: A Tensor constructed from ``data`` .

Examples:
.. code-block:: python

>>> import paddle

>>> type(paddle.tensor(1))
<class 'paddle.Tensor'>

>>> paddle.tensor(1)
Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
1)

>>> x = paddle.tensor(1, requires_grad=True)
>>> print(x)
Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=False,
1)

>>> paddle.tensor(x) # A new tensor will be created with default stop_gradient=True
Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
1)

>>> paddle.tensor([[0.1, 0.2], [0.3, 0.4]], device=paddle.CPUPlace(), requires_grad=True)
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.10000000, 0.20000000],
[0.30000001, 0.40000001]])

>>> type(paddle.tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64'))
<class 'paddle.Tensor'>

>>> paddle.tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64')
Tensor(shape=[2, 2], dtype=complex64, place=Place(cpu), stop_gradient=True,
[[(1+1j), (2+0j)],
[(3+2j), (4+0j)]])
"""
if isinstance(device, str) and "cuda" in device:
device = device.replace("cuda", "gpu")
stop_gradient = not requires_grad
place = _get_paddle_place(device)
if place is None:
place = _current_expected_place_()
if pin_memory and not isinstance(
Contributor

Here it is `not isinstance`, while below it is `isinstance`?

Please add unit tests for `paddle.tensor`. Although `to_tensor` calls `tensor`, the two APIs still differ in some ways, e.g. the `pin_memory` and `device` parameters.

Contributor Author (zeroRains, Aug 14, 2025)

The `not isinstance` here checks whether `place` is already a pinned place; if it is, there is no need to check further. The `isinstance` checks below determine whether the current place is a GPU or XPU place.

OK, I will add the unit tests.

place, (core.CUDAPinnedPlace, core.XPUPinnedPlace)
):
if isinstance(place, core.CUDAPlace):
place = core.CUDAPinnedPlace()
elif isinstance(place, core.XPUPlace):
place = core.XPUPinnedPlace()
else:
raise RuntimeError(f"Pinning memory is not supported for {place}.")

if in_dynamic_mode():
is_tensor = paddle.is_tensor(data)
if not is_tensor and hasattr(data, "__cuda_array_interface__"):
if not core.is_compiled_with_cuda():
raise RuntimeError(
"PaddlePaddle is not compiled with CUDA, but trying to create a Tensor from a CUDA array."
)
tensor = core.tensor_from_cuda_array_interface(data)
if pin_memory:
tensor = tensor.pin_memory()
else:
if is_tensor:
global _warned_in_tensor
if not _warned_in_tensor:
warnings.warn(
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach(), "
"rather than paddle.to_tensor(sourceTensor).",
stacklevel=2,
)
_warned_in_tensor = True
tensor = _to_tensor_non_static(data, dtype, place, stop_gradient)
return tensor
# call assign for static graph
else:
re_exp = re.compile(r'[(](.+?)[)]', re.DOTALL)
place_str = re.findall(re_exp, str(place))[0]
with paddle.static.device_guard(place_str):
tensor = _to_tensor_static(data, dtype, stop_gradient)
return tensor
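
A hedged usage sketch of the place resolution above, assuming a CUDA build of Paddle: the ``cuda`` device alias is rewritten to ``gpu``, ``pin_memory`` promotes a GPU place to the pinned allocator, and ``pin_memory`` on a CPU place raises.

import paddle

# "cuda:0" is rewritten to "gpu:0" before place resolution.
x = paddle.tensor([1.0, 2.0], device="cuda:0")
print(x.place)  # Place(gpu:0)

# pin_memory promotes a GPU place to the pinned allocator.
y = paddle.tensor([1.0, 2.0], device="gpu", pin_memory=True)
print(y.place)  # Place(gpu_pinned)

# pin_memory is rejected for CPU places.
try:
    paddle.tensor([1.0, 2.0], device="cpu", pin_memory=True)
except RuntimeError as e:
    print(e)  # Pinning memory is not supported for Place(cpu).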


def to_tensor(
data: TensorLike | NestedNumericSequence,
dtype: DTypeLike | None = None,
@@ -957,34 +1079,9 @@ def to_tensor(
[[(1+1j), (2+0j)],
[(3+2j), (4+0j)]])
"""
place = _get_paddle_place(place)
if place is None:
place = _current_expected_place_()
if in_dynamic_mode():
is_tensor = paddle.is_tensor(data)
if not is_tensor and hasattr(data, "__cuda_array_interface__"):
if not core.is_compiled_with_cuda():
raise RuntimeError(
"PaddlePaddle is not compiled with CUDA, but trying to create a Tensor from a CUDA array."
)
return core.tensor_from_cuda_array_interface(data)
if is_tensor:
global _warned_in_to_tensor
if not _warned_in_to_tensor:
warnings.warn(
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach(), "
"rather than paddle.to_tensor(sourceTensor).",
stacklevel=2,
)
_warned_in_to_tensor = True
return _to_tensor_non_static(data, dtype, place, stop_gradient)

# call assign for static graph
else:
re_exp = re.compile(r'[(](.+?)[)]', re.DOTALL)
place_str = re.findall(re_exp, str(place))[0]
with paddle.static.device_guard(place_str):
return _to_tensor_static(data, dtype, stop_gradient)
return tensor(
data, dtype=dtype, device=place, requires_grad=not stop_gradient
)
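
With this change ``to_tensor`` becomes a thin wrapper over ``tensor``. A quick sketch of the expected equivalence, mirroring the unit test added below:

import numpy as np
import paddle

arr = np.array([[0.1, 0.2], [0.3, 0.4]], dtype="float32")

a = paddle.to_tensor(arr, dtype="float32", place="cpu", stop_gradient=False)
b = paddle.tensor(arr, dtype="float32", device="cpu", requires_grad=True)

np.testing.assert_array_equal(a.numpy(), b.numpy())
assert a.place == b.place and a.dtype == b.dtype
assert a.stop_gradient is False and b.stop_gradient is False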


class MmapStorage(paddle.base.core.MmapStorage):
111 changes: 73 additions & 38 deletions test/legacy_test/test_eager_tensor.py
@@ -377,51 +377,86 @@ def test_to_tensor_attributes(self):
self.assertEqual(var.dtype, paddle.float32)
self.assertEqual(var.type, core.VarDesc.VarType.DENSE_TENSOR)

def test_to_tensor_param_alias(self):
"""Test paddle.to_tensor parameter mapping ("place": ["device"])."""
# 1. Test equivalence of place and device parameters
tensor_place = paddle.to_tensor(self.array, place=paddle.CPUPlace())
tensor_device = paddle.to_tensor(self.array, device=paddle.CPUPlace())
def test_tensor_pin_memory_and_device(self):
if core.is_compiled_with_cuda():
tensor_res = paddle.tensor(
self.array, device="gpu", pin_memory=True
)
self.assertEqual(tensor_res.place, core.CUDAPinnedPlace())

np.testing.assert_array_equal(
tensor_device.numpy(), tensor_place.numpy()
)
self.assertEqual(tensor_device.place, tensor_place.place)

# 2. Test conflict between place and device (should raise KeyError)
with self.assertRaises(ValueError) as context:
paddle.to_tensor(
self.array,
place=paddle.CPUPlace(),
device=paddle.CPUPlace(), # Conflict
tensor_cuda = paddle.tensor(self.array, device="cuda:0")
self.assertEqual(tensor_cuda.place, paddle.CUDAPlace(0))

tensor_pin = paddle.tensor(self.array, device="gpu_pinned")
self.assertEqual(tensor_pin.place, core.CUDAPinnedPlace())

if core.is_compiled_with_xpu():
tensor_res = paddle.tensor(
self.array, device="xpu", pin_memory=True
)
self.assertEqual(tensor_res.place, core.XPUPinnedPlace())

tensor_pin = paddle.tensor(self.array, device="xpu_pinned")
self.assertEqual(tensor_pin.place, core.XPUPinnedPlace())

with self.assertRaises(RuntimeError) as context:
paddle.tensor(
self.array, device="cpu", pin_memory=True  # not supported on CPU
)
self.assertIn(
"Cannot specify both 'place' and its alias 'device'",
"Pinning memory is not supported",
str(context.exception),
)

# 3. Test dtype and stop_gradient consistency
tensor1 = paddle.to_tensor(
self.array, dtype="float32", device=paddle.CPUPlace()
def test_tensor_and_to_tensor(self):
"""
Test that paddle.tensor is equivalent to paddle.to_tensor.
"""
tensor_res = paddle.tensor(
self.array, dtype="float32", device="cpu", requires_grad=True
)
tensor2 = paddle.to_tensor(
self.array, dtype="float32", place=paddle.CPUPlace()
tensor_target = paddle.to_tensor(
self.array, dtype="float32", place="cpu", stop_gradient=False
)

self.assertEqual(tensor1.dtype, tensor2.dtype)
self.assertEqual(tensor1.dtype, paddle.float32)
self.assertTrue(tensor1.stop_gradient)
self.assertEqual(tensor1.stop_gradient, tensor2.stop_gradient)

# 4. Test cross-device compatibility (CPU/GPU)
for device in [paddle.CPUPlace()] + (
[paddle.CUDAPlace(0)] if core.is_compiled_with_cuda() else []
):
tensor_device = paddle.to_tensor(self.array, device=device)
tensor_place = paddle.to_tensor(self.array, place=device)

self.assertEqual(tensor_device.place, tensor_place.place)
self.assertEqual(tensor_device.place, device)
np.testing.assert_array_equal(tensor_res.numpy(), tensor_target.numpy())
self.assertEqual(tensor_res.place, tensor_target.place)
self.assertEqual(tensor_res.place, core.CPUPlace())
self.assertEqual(tensor_res.dtype, tensor_target.dtype)
self.assertEqual(tensor_res.dtype, paddle.float32)
self.assertEqual(tensor_res.stop_gradient, tensor_target.stop_gradient)
self.assertEqual(tensor_res.stop_gradient, False)

def test_tensor_module(self):
"""
Test that paddle.tensor is usable both as an API and as a module.
"""
tensor_api = paddle.tensor(self.array, dtype="float32")
tensor_module = paddle.tensor.creation.tensor(
self.array, dtype="float32"
)
np.testing.assert_array_equal(tensor_api.numpy(), tensor_module.numpy())
self.assertEqual(tensor_api.place, tensor_module.place)
self.assertEqual(tensor_api.dtype, tensor_module.dtype)
self.assertEqual(tensor_api.stop_gradient, tensor_module.stop_gradient)

def test_tensor_method_or_module(self):
"""
Test the dunder methods of _TensorMethodOrModule.
"""
# __repr__
ori_repr = repr(paddle.tensor.creation.tensor)
now_repr = repr(paddle.tensor)
self.assertEqual(ori_repr, now_repr)

# __str__
ori_str = str(paddle.tensor.creation.tensor)
now_str = str(paddle.tensor)
self.assertEqual(ori_str, now_str)

# __dir__
api_dir = dir(paddle.tensor.creation.tensor)
module_dir = dir(paddle.tensor)
self.assertGreater(len(module_dir), len(api_dir))

def test_list_to_tensor(self):
array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]]
@@ -1348,7 +1383,7 @@ def test_to_tensor_from___cuda_array_interface__(self):
):
x = paddle.to_tensor([1, 2, 3])
paddle.to_tensor(x)
flag = paddle.tensor.creation._warned_in_to_tensor
flag = paddle.tensor.creation._warned_in_tensor
self.assertTrue(flag)

def test_dlpack_device(self):