diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index b00d3c1f8a7443..3ebe1ebb0fdddc 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -641,6 +641,34 @@ to_dlpack, ) + +class _TensorMethodOrModule: + def __init__(self): + import paddle.tensor as tensor_module + + from .tensor.creation import tensor as tensor_api + + self.module = tensor_module + self.method = tensor_api + + def __call__(self, *args, **kwargs): + return self.method(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self.module, name) + + def __repr__(self): + return repr(self.method) + + def __str__(self): + return str(self.method) + + def __dir__(self): + return dir(self.module) + + +tensor = _TensorMethodOrModule() # noqa: F811 + # CINN has to set a flag to include a lib if is_compiled_with_cinn(): import os diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index b7dfff2198b8b0..a22a0742ab87ec 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -63,7 +63,7 @@ __all__ = [] -_warned_in_to_tensor = False +_warned_in_tensor = False def _complex_to_real_dtype(dtype: DTypeLike) -> DTypeLike: @@ -877,7 +877,129 @@ def _to_tensor_static( return output -@ParamAliasDecorator({"place": ["device"]}) +def tensor( + data: TensorLike | NestedNumericSequence, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> paddle.Tensor: + r""" + Constructs a ``paddle.Tensor`` from ``data`` , + which can be scalar, tuple, list, numpy\.ndarray, paddle\.Tensor. + + If the ``data`` is already a Tensor, copy will be performed and return a new tensor. + If you only want to change stop_gradient property, please call ``Tensor.stop_gradient = stop_gradient`` directly. + + .. code-block:: text + + We use the dtype conversion rules following this: + Keep dtype + np.number ───────────► paddle.Tensor + (0-D Tensor) + default_dtype + Python Number ───────────────► paddle.Tensor + (0-D Tensor) + Keep dtype + np.ndarray ───────────► paddle.Tensor + + Args: + data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor. + Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor. + dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' , + 'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8', + 'complex64' , 'complex128'. Default: None, infers dtype from ``data`` + except for python float number which gets dtype from ``get_default_type`` . + device(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be + CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``device`` is + string, It can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs. + requires_grad(bool, optional): Whether to block the gradient propagation of Autograd. Default: False. + pin_memory(bool, optional): If set, return tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False + + Returns: + Tensor: A Tensor constructed from ``data`` . + + Examples: + .. code-block:: python + + >>> import paddle + + >>> type(paddle.tensor(1)) + + + >>> paddle.tensor(1) + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, + 1) + + >>> x = paddle.tensor(1, requires_grad=True) + >>> print(x) + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=False, + 1) + + >>> paddle.tensor(x) # A new tensor will be created with default stop_gradient=True + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, + 1) + + >>> paddle.tensor([[0.1, 0.2], [0.3, 0.4]], device=paddle.CPUPlace(), requires_grad=True) + Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=False, + [[0.10000000, 0.20000000], + [0.30000001, 0.40000001]]) + + >>> type(paddle.tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64')) + + + >>> paddle.tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64') + Tensor(shape=[2, 2], dtype=complex64, place=Place(cpu), stop_gradient=True, + [[(1+1j), (2+0j)], + [(3+2j), (4+0j)]]) + """ + if isinstance(device, str) and "cuda" in device: + device = device.replace("cuda", "gpu") + stop_gradient = not requires_grad + place = _get_paddle_place(device) + if place is None: + place = _current_expected_place_() + if pin_memory and not isinstance( + place, (core.CUDAPinnedPlace, core.XPUPinnedPlace) + ): + if isinstance(place, core.CUDAPlace): + place = core.CUDAPinnedPlace() + elif isinstance(place, core.XPUPlace): + place = core.XPUPinnedPlace() + else: + raise RuntimeError(f"Pinning memory is not supported for {place}.") + + if in_dynamic_mode(): + is_tensor = paddle.is_tensor(data) + if not is_tensor and hasattr(data, "__cuda_array_interface__"): + if not core.is_compiled_with_cuda(): + raise RuntimeError( + "PaddlePaddle is not compiled with CUDA, but trying to create a Tensor from a CUDA array." + ) + tensor = core.tensor_from_cuda_array_interface(data) + if pin_memory: + tensor = tensor.pin_memory() + else: + if is_tensor: + global _warned_in_tensor + if not _warned_in_tensor: + warnings.warn( + "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach(), " + "rather than paddle.to_tensor(sourceTensor).", + stacklevel=2, + ) + _warned_in_tensor = True + tensor = _to_tensor_non_static(data, dtype, place, stop_gradient) + return tensor + # call assign for static graph + else: + re_exp = re.compile(r'[(](.+?)[)]', re.DOTALL) + place_str = re.findall(re_exp, str(place))[0] + with paddle.static.device_guard(place_str): + tensor = _to_tensor_static(data, dtype, stop_gradient) + return tensor + + def to_tensor( data: TensorLike | NestedNumericSequence, dtype: DTypeLike | None = None, @@ -957,34 +1079,9 @@ def to_tensor( [[(1+1j), (2+0j)], [(3+2j), (4+0j)]]) """ - place = _get_paddle_place(place) - if place is None: - place = _current_expected_place_() - if in_dynamic_mode(): - is_tensor = paddle.is_tensor(data) - if not is_tensor and hasattr(data, "__cuda_array_interface__"): - if not core.is_compiled_with_cuda(): - raise RuntimeError( - "PaddlePaddle is not compiled with CUDA, but trying to create a Tensor from a CUDA array." - ) - return core.tensor_from_cuda_array_interface(data) - if is_tensor: - global _warned_in_to_tensor - if not _warned_in_to_tensor: - warnings.warn( - "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach(), " - "rather than paddle.to_tensor(sourceTensor).", - stacklevel=2, - ) - _warned_in_to_tensor = True - return _to_tensor_non_static(data, dtype, place, stop_gradient) - - # call assign for static graph - else: - re_exp = re.compile(r'[(](.+?)[)]', re.DOTALL) - place_str = re.findall(re_exp, str(place))[0] - with paddle.static.device_guard(place_str): - return _to_tensor_static(data, dtype, stop_gradient) + return tensor( + data, dtype=dtype, device=place, requires_grad=not stop_gradient + ) class MmapStorage(paddle.base.core.MmapStorage): diff --git a/test/legacy_test/test_eager_tensor.py b/test/legacy_test/test_eager_tensor.py index 8b2ce5991034fd..8768de64169d98 100644 --- a/test/legacy_test/test_eager_tensor.py +++ b/test/legacy_test/test_eager_tensor.py @@ -377,51 +377,86 @@ def test_to_tensor_attributes(self): self.assertEqual(var.dtype, paddle.float32) self.assertEqual(var.type, core.VarDesc.VarType.DENSE_TENSOR) - def test_to_tensor_param_alias(self): - """Test paddle.to_tensor parameter mapping ("place": ["device"]).""" - # 1. Test equivalence of place and device parameters - tensor_place = paddle.to_tensor(self.array, place=paddle.CPUPlace()) - tensor_device = paddle.to_tensor(self.array, device=paddle.CPUPlace()) + def test_tensor_pin_memory_and_device(self): + if core.is_compiled_with_cuda(): + tensor_res = paddle.tensor( + self.array, device="gpu", pin_memory=True + ) + self.assertEqual(tensor_res.place, core.CUDAPinnedPlace()) - np.testing.assert_array_equal( - tensor_device.numpy(), tensor_place.numpy() - ) - self.assertEqual(tensor_device.place, tensor_place.place) - - # 2. Test conflict between place and device (should raise KeyError) - with self.assertRaises(ValueError) as context: - paddle.to_tensor( - self.array, - place=paddle.CPUPlace(), - device=paddle.CPUPlace(), # Conflict + tensor_cuda = paddle.tensor(self.array, device="cuda:0") + self.assertEqual(tensor_cuda.place, paddle.CUDAPlace(0)) + + tensor_pin = paddle.tensor(self.array, device="gpu_pinned") + self.assertEqual(tensor_pin.place, core.CUDAPinnedPlace()) + + if core.is_compiled_with_xpu(): + tensor_res = paddle.tensor( + self.array, device="xpu", pin_memory=True + ) + self.assertEqual(tensor_res.place, core.XPUPinnedPlace()) + + tensor_pin = paddle.tensor(self.array, device="xpu_pinned") + self.assertEqual(tensor_pin.place, core.XPUPinnedPlace()) + + with self.assertRaises(RuntimeError) as context: + paddle.tensor( + self.array, device="cpu", pin_memory=True # no support ) self.assertIn( - "Cannot specify both 'place' and its alias 'device'", + "Pinning memory is not supported", str(context.exception), ) - # 3. Test dtype and stop_gradient consistency - tensor1 = paddle.to_tensor( - self.array, dtype="float32", device=paddle.CPUPlace() + def test_tensor_and_to_tensor(self): + """ + test tensor equal to to_tensor + """ + tensor_res = paddle.tensor( + self.array, dtype="float32", device="cpu", requires_grad=True ) - tensor2 = paddle.to_tensor( - self.array, dtype="float32", place=paddle.CPUPlace() + tensor_target = paddle.to_tensor( + self.array, dtype="float32", place="cpu", stop_gradient=False ) - - self.assertEqual(tensor1.dtype, tensor2.dtype) - self.assertEqual(tensor1.dtype, paddle.float32) - self.assertTrue(tensor1.stop_gradient) - self.assertEqual(tensor1.stop_gradient, tensor2.stop_gradient) - - # 4. Test cross-device compatibility (CPU/GPU) - for device in [paddle.CPUPlace()] + ( - [paddle.CUDAPlace(0)] if core.is_compiled_with_cuda() else [] - ): - tensor_device = paddle.to_tensor(self.array, device=device) - tensor_place = paddle.to_tensor(self.array, place=device) - - self.assertEqual(tensor_device.place, tensor_place.place) - self.assertEqual(tensor_device.place, device) + np.testing.assert_array_equal(tensor_res.numpy(), tensor_target.numpy()) + self.assertEqual(tensor_res.place, tensor_target.place) + self.assertEqual(tensor_res.place, core.CPUPlace()) + self.assertEqual(tensor_res.dtype, tensor_target.dtype) + self.assertEqual(tensor_res.dtype, paddle.float32) + self.assertEqual(tensor_res.stop_gradient, tensor_target.stop_gradient) + self.assertEqual(tensor_res.stop_gradient, False) + + def test_tensor_module(self): + """ + test paddle.tensor usable as an API and a module + """ + tensor_api = paddle.tensor(self.array, dtype="float32") + tensor_module = paddle.tensor.creation.tensor( + self.array, dtype="float32" + ) + np.testing.assert_array_equal(tensor_api.numpy(), tensor_module.numpy()) + self.assertEqual(tensor_api.place, tensor_module.place) + self.assertEqual(tensor_api.dtype, tensor_module.dtype) + self.assertEqual(tensor_api.stop_gradient, tensor_module.stop_gradient) + + def test_tensor_method_or_module(self): + """ + test the class method + """ + # __rerp__ + ori_repr = repr(paddle.tensor.creation.tensor) + now_repr = repr(paddle.tensor) + self.assertEqual(ori_repr, now_repr) + + # __str__ + ori_str = str(paddle.tensor.creation.tensor) + now_str = str(paddle.tensor) + self.assertEqual(ori_str, now_str) + + # __dir__ + api_dir = dir(paddle.tensor.creation.tensor) + module_dir = dir(paddle.tensor) + self.assertGreater(len(module_dir), len(api_dir)) def test_list_to_tensor(self): array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]] @@ -1348,7 +1383,7 @@ def test_to_tensor_from___cuda_array_interface__(self): ): x = paddle.to_tensor([1, 2, 3]) paddle.to_tensor(x) - flag = paddle.tensor.creation._warned_in_to_tensor + flag = paddle.tensor.creation._warned_in_tensor self.assertTrue(flag) def test_dlpack_device(self):