Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.15】 为 Paddle 新增 Tensor.to() 以及 Layer.astype() API -part #58244

Merged
merged 16 commits into from
Nov 9, 2023
124 changes: 124 additions & 0 deletions python/paddle/base/dygraph/tensor_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,129 @@ def transform(t, device, dtype, blocking):
warnings.filterwarnings("ignore", category=UserWarning)
return transform(self, device, dtype, blocking)

@framework.dygraph_only
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion. A paddle.dtype and place
are inferred from the arguments of ``self.to(*args, **kwargs)``.There are
three ways to call `to`:

1. to(dtype, blocking=True)
2. to(device, dtype=None, blocking=True)
3. to(other, blocking=True)

Returns:
Tensor: self

Examples:
.. code-block:: python

>>> import paddle
>>> tensorx = paddle.to_tensor([1,2,3])
>>> print(tensorx)
Tensor(shape=[3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[1, 2, 3])

>>> tensorx = tensorx.to("cpu")
>>> print(tensorx.place)
Place(cpu)

>>> tensorx = tensorx.to("float32")
>>> print(tensorx.dtype)
paddle.float32

>>> tensorx = tensorx.to("gpu", "int16")
>>> print(tensorx)
Tensor(shape=[3], dtype=int16, place=Place(gpu:0), stop_gradient=True,
[1, 2, 3])
>>> tensor2 = paddle.to_tensor([4,5,6])
>>> tensor2
Tensor(shape=[3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[4, 5, 6])
>>> tensor2 = tensor2.to(tensorx)
>>> print(tensor2)
Tensor(shape=[3], dtype=int16, place=Place(gpu:0), stop_gradient=True,
[4, 5, 6])
"""
device = None
dtype = None
blocking = None
size_args = len(args)
size_kwargs = len(kwargs)

def get_device_dtype_from_tensor(other):
if other is not None:
device = str(other.place)[6:-1]
dtype = other.dtype
return device, dtype
else:
return None, None

if size_args + size_kwargs > 3 or size_args + size_kwargs == 0:
raise TypeError(
"to() received too mant arguments - expected one of:\n \
* (Union[str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace(), paddle.XPUPlace(), paddle.CustomPlace()] \
device, Union[str, paddle.dtype, numpy.dtype] dtype, bool blocking)\n \
* (Union[str, paddle.dtype, numpy.dtype] dtype, bool blocking)\n \
* (paddle.Tensor other, bool blocking) "
)
valid_keys = {"device", "dtype", "blocking", "other"}
valid_dtypes = [
"bfloat16",
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint8",
"complex64",
"complex128",
"bool",
]
invalid_keys = set(kwargs.keys()) - valid_keys
if len(invalid_keys) != 0:
raise TypeError(
"to() got an unexpected keyword argument "
+ list(invalid_keys)[0]
)
if size_args > 0:
if isinstance(args[0], paddle.Tensor):
device, dtype = get_device_dtype_from_tensor(args[0])
if size_args == 2:
blocking = args[1]
else:
blocking = kwargs.get("blocking", None)
elif (
isinstance(args[0], (paddle.dtype, np.dtype))
or isinstance(args[0], str)
and args[0].lower() in valid_dtypes
):
dtype = args[0]
if size_args == 2:
blocking = args[1]
else:
blocking = kwargs.get("blocking", None)
else:
device = args[0]
if size_args == 2:
dtype = args[1]
elif size_args == 3:
dtype, blocking = args[1], args[2]
else:
dtype = kwargs.get("dtype", None)
blocking = kwargs.get("blocking", None)
else:
device = kwargs.get("device", None)
dtype = kwargs.get("dtype", None)
blocking = kwargs.get("blocking", None)
if device is None and dtype is None:
device, dtype = get_device_dtype_from_tensor(
kwargs.get("other", None)
)
return self._to(device, dtype, blocking)

@property
def grad(self):
"""
Expand Down Expand Up @@ -1020,6 +1143,7 @@ def coalesce(self, name=None):
("item", item),
("__setitem__", __setitem__),
("_to", _to),
("to", to),
("values", values),
("to_dense", to_dense),
("to_sparse_coo", to_sparse_coo),
Expand Down
80 changes: 80 additions & 0 deletions python/paddle/nn/layer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,86 @@ def parameters(self, include_sublayers=True):
]
return ret

def astype(self, dtype=None):
"""

Casts all parameters and buffers to dtype and then return the Layer.

Parameters:
dtype(str|paddle.dtype|numpy.dtype): target data type of layer.
If set str, it can be "bool", "bfloat16", "float16", "float32", "float64",
"int8", "int16", "int32", "int64", "uint8", "complex64", "complex128".
Default: None

Returns:
Layer, self

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.nn as nn
>>> weight_attr = paddle.ParamAttr(name="weight",initializer=paddle.nn.initializer.Constant(value=1.5))
>>> bias_attr = paddle.ParamAttr(name="bias",initializer=paddle.nn.initializer.Constant(value=2.5))

>>> linear = paddle.nn.Linear(2, 2, weight_attr=weight_attr, bias_attr=bias_attr).to(device="cpu",dtype="float32")
>>> print(linear)
Linear(in_features=2, out_features=2, dtype=float32)
>>> print(linear.parameters())
[Parameter containing:
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
[[1.50000000, 1.50000000],
[1.50000000, 1.50000000]]), Parameter containing:
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=False,
[2.50000000, 2.50000000])]

>>> linear=linear.astype("int8")
>>> print(linear)
Linear(in_features=2, out_features=2, dtype=paddle.int8)
>>> print(linear.parameters())
[Parameter containing:
Tensor(shape=[2, 2], dtype=int8, place=Place(cpu), stop_gradient=False,
[[1, 1],
[1, 1]]), Parameter containing:
Tensor(shape=[2], dtype=int8, place=Place(cpu), stop_gradient=False,
[2, 2])]

"""
valid_dtypes = [
"bfloat16",
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint8",
"complex64",
"complex128",
"bool",
]
if (
isinstance(dtype, (paddle.dtype, np.dtype))
or type(dtype) is str
and dtype in valid_dtypes
):
if isinstance(dtype, (str, np.dtype)):
dtype = framework.convert_np_dtype_to_dtype_(dtype)
self._dtype = dtype
for layer in self.sublayers():
layer._dtype = dtype
for _, param in self.named_parameters(include_sublayers=True):
param._to(None, dtype)
for _, buffer in self.named_buffers(include_sublayers=True):
buffer.to(None, dtype)
return self
else:
raise ValueError(
"dtype value error, must be 'bfloat16', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', 'bool', or paddle.dtype, numpy.dtype, but recieve "
+ str(dtype)
)

def children(self):
"""

Expand Down
147 changes: 147 additions & 0 deletions test/legacy_test/test_Tensor_to.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle import base


class TensorToTest(unittest.TestCase):
def test_Tensor_to_dtype(self):
tensorx = paddle.to_tensor([1, 2, 3])
valid_dtypes = [
"bfloat16",
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint8",
"complex64",
"complex128",
"bool",
]
for dtype in valid_dtypes:
tensorx = tensorx.to(dtype)
typex_str = str(tensorx.dtype)
self.assertTrue(typex_str, "paddle." + dtype)

def test_Tensor_to_device(self):
tensorx = paddle.to_tensor([1, 2, 3])
places = ["cpu"]
if base.core.is_compiled_with_cuda():
places.append("gpu:0")
places.append("gpu")

for place in places:
tensorx = tensorx.to(place)
placex_str = str(tensorx.place)
if place == "gpu":
self.assertTrue(placex_str, "Place(" + place + ":0)")
else:
self.assertTrue(placex_str, "Place(" + place + ")")

def test_Tensor_to_device_dtype(self):
tensorx = paddle.to_tensor([1, 2, 3])
places = ["cpu"]
if base.core.is_compiled_with_cuda():
places.append("gpu:0")
places.append("gpu")
valid_dtypes = [
"bfloat16",
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint8",
"complex64",
"complex128",
"bool",
]
for dtype in valid_dtypes:
for place in places:
tensorx = tensorx.to(place, dtype)
placex_str = str(tensorx.place)
if place == "gpu":
self.assertTrue(placex_str, "Place(" + place + ":0)")
else:
self.assertTrue(placex_str, "Place(" + place + ")")
typex_str = str(tensorx.dtype)
self.assertTrue(typex_str, "paddle." + dtype)

def test_Tensor_to_blocking(self):
tensorx = paddle.to_tensor([1, 2, 3])
tensorx = tensorx.to("cpu", "int32", False)
placex_str = str(tensorx.place)
self.assertTrue(placex_str, "Place(cpu)")
typex_str = str(tensorx.dtype)
self.assertTrue(typex_str, "paddle.int32")
tensor2 = paddle.to_tensor([4, 5, 6])
tensor2 = tensor2.to(tensorx, False)
place2_str = str(tensor2.place)
self.assertTrue(place2_str, "Place(cpu)")
type2_str = str(tensor2.dtype)
self.assertTrue(type2_str, "paddle.int32")
tensor2 = tensor2.to("float16", False)
type2_str = str(tensor2.dtype)
self.assertTrue(type2_str, "paddle.float16")

def test_Tensor_to_other(self):
tensor1 = paddle.to_tensor([1, 2, 3], dtype="int8", place="cpu")
tensor2 = paddle.to_tensor([1, 2, 3])
tensor2 = tensor2.to(tensor1)
self.assertTrue(tensor2.dtype, tensor1.dtype)
self.assertTrue(type(tensor2.place), type(tensor1.place))

def test_kwargs(self):
tensorx = paddle.to_tensor([1, 2, 3])
tensorx = tensorx.to(device="cpu", dtype="int8", blocking=True)
placex_str = str(tensorx.place)
self.assertTrue(placex_str, "Place(cpu)")
typex_str = str(tensorx.dtype)
self.assertTrue(typex_str, "paddle.int8")
tensor2 = paddle.to_tensor([4, 5, 6])
tensor2 = tensor2.to(other=tensorx)
place2_str = str(tensor2.place)
self.assertTrue(place2_str, "Place(cpu)")
type2_str = str(tensor2.dtype)
self.assertTrue(type2_str, "paddle.int8")

def test_error(self):
tensorx = paddle.to_tensor([1, 2, 3])
# device value error
try:
tensorx = tensorx.to("error_device")
except Exception as error:
self.assertIsInstance(error, ValueError)
# to many augments
try:
tensorx = tensorx.to("cpu", "int32", False, "test_aug")
except Exception as error:
self.assertIsInstance(error, TypeError)
# invalid key
try:
tensorx = tensorx.to("cpu", "int32", test_key=False)
except Exception as error:
self.assertIsInstance(error, TypeError)


if __name__ == '__main__':
unittest.main()
Loading